Add configurable default LLM provider and model preferences

Co-authored-by: mblanke <9078342+mblanke@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-12-03 17:39:37 +00:00
parent c4eaf1718a
commit 91b4697403
4 changed files with 138 additions and 28 deletions

View File

@@ -7,3 +7,10 @@ ANTHROPIC_API_KEY=
# Ollama Configuration
OLLAMA_BASE_URL=http://ollama:11434
# Default LLM Provider and Model
# These are used when no explicit provider/model is specified in API requests
# Can be changed via API: POST /api/llm/preferences
DEFAULT_LLM_PROVIDER=ollama
DEFAULT_LLM_MODEL=llama3.2
# Available providers: ollama, ollama-local, ollama-network, openai, anthropic

View File

@@ -434,8 +434,9 @@ Returns: { title, description, tips, example }
```
POST /api/llm/chat
Body: { message: string, session_id?: string, context?: string }
Body: { message: string, session_id?: string, context?: string, provider?: string, model?: string }
Returns: { message: string, success: boolean }
Note: If provider/model not specified, uses default preferences
GET /api/llm/autocomplete?partial_text=...&context_type=...
Returns: { suggestions: [...] }
@@ -443,8 +444,21 @@ Returns: { suggestions: [...] }
POST /api/llm/explain
Body: { item: string, item_type?: string, context?: {...} }
Returns: { explanation: string, item_type: string }
GET /api/llm/preferences
Returns: { current: { provider: string, model: string }, available_providers: [...] }
POST /api/llm/preferences
Body: { provider: string, model: string }
Returns: { status: string, provider: string, model: string, message: string }
```
**LLM Provider Selection:**
- Set default LLM provider and model via environment variables: `DEFAULT_LLM_PROVIDER`, `DEFAULT_LLM_MODEL`
- Change defaults at runtime via `/api/llm/preferences` endpoint
- Override per-request by specifying `provider` and `model` in request body
- Available providers: `ollama`, `ollama-local`, `ollama-network`, `openai`, `anthropic`
### Config Validation
```

View File

@@ -29,6 +29,8 @@ services:
environment:
- LLM_ROUTER_URL=http://strikepackage-llm-router:8000
- KALI_EXECUTOR_URL=http://strikepackage-kali-executor:8002
- DEFAULT_LLM_PROVIDER=${DEFAULT_LLM_PROVIDER:-ollama}
- DEFAULT_LLM_MODEL=${DEFAULT_LLM_MODEL:-llama3.2}
depends_on:
- llm-router
- kali-executor

View File

@@ -32,10 +32,18 @@ app.add_middleware(
LLM_ROUTER_URL = os.getenv("LLM_ROUTER_URL", "http://strikepackage-llm-router:8000")
KALI_EXECUTOR_URL = os.getenv("KALI_EXECUTOR_URL", "http://strikepackage-kali-executor:8002")
# Default LLM Configuration (can be overridden via environment or API)
DEFAULT_LLM_PROVIDER = os.getenv("DEFAULT_LLM_PROVIDER", "ollama")
DEFAULT_LLM_MODEL = os.getenv("DEFAULT_LLM_MODEL", "llama3.2")
# In-memory storage (use Redis in production)
tasks: Dict[str, Any] = {}
sessions: Dict[str, Dict] = {}
scan_results: Dict[str, Any] = {}
llm_preferences: Dict[str, Any] = {
"provider": DEFAULT_LLM_PROVIDER,
"model": DEFAULT_LLM_MODEL
}
# ============== Models ==============
@@ -50,22 +58,27 @@ class ChatRequest(BaseModel):
message: str
session_id: Optional[str] = None
context: Optional[str] = None
provider: str = "ollama"
model: str = "llama3.2"
provider: Optional[str] = None # None means use default
model: Optional[str] = None # None means use default
class PhaseChatRequest(BaseModel):
message: str
phase: str
provider: str = "ollama"
model: str = "llama3.2"
provider: Optional[str] = None # None means use default
model: Optional[str] = None # None means use default
findings: List[Dict[str, Any]] = []
class AttackChainRequest(BaseModel):
findings: List[Dict[str, Any]]
provider: str = "ollama"
model: str = "llama3.2"
provider: Optional[str] = None # None means use default
model: Optional[str] = None # None means use default
class LLMPreferencesRequest(BaseModel):
provider: str
model: str
class CommandRequest(BaseModel):
@@ -335,7 +348,7 @@ async def health_check():
@app.post("/chat")
async def security_chat(request: ChatRequest):
"""Chat with security-focused AI assistant"""
"""Chat with security-focused AI assistant - uses default LLM preferences if not specified"""
messages = [
{
"role": "system",
@@ -357,8 +370,8 @@ vulnerabilities and defenses."""
response = await client.post(
f"{LLM_ROUTER_URL}/chat",
json={
"provider": request.provider,
"model": request.model,
"provider": request.provider or llm_preferences["provider"],
"model": request.model or llm_preferences["model"],
"messages": messages,
"temperature": 0.7,
"max_tokens": 2048
@@ -376,7 +389,7 @@ vulnerabilities and defenses."""
@app.post("/chat/phase")
async def phase_aware_chat(request: PhaseChatRequest):
"""Phase-aware chat with context from current pentest phase"""
"""Phase-aware chat with context from current pentest phase - uses default LLM preferences if not specified"""
phase_prompt = PHASE_PROMPTS.get(request.phase, PHASE_PROMPTS["recon"])
# Build context from findings if available
@@ -400,8 +413,8 @@ async def phase_aware_chat(request: PhaseChatRequest):
response = await client.post(
f"{LLM_ROUTER_URL}/chat",
json={
"provider": request.provider,
"model": request.model,
"provider": request.provider or llm_preferences["provider"],
"model": request.model or llm_preferences["model"],
"messages": messages,
"temperature": 0.7,
"max_tokens": 2048
@@ -515,8 +528,8 @@ Only return valid JSON."""
response = await client.post(
f"{LLM_ROUTER_URL}/chat",
json={
"provider": request.provider,
"model": request.model,
"provider": request.provider or llm_preferences["provider"],
"model": request.model or llm_preferences["model"],
"messages": messages,
"temperature": 0.3,
"max_tokens": 2048
@@ -919,7 +932,7 @@ def parse_gobuster_output(output: str) -> Dict[str, Any]:
@app.post("/ai-scan")
async def ai_assisted_scan(request: ChatRequest, background_tasks: BackgroundTasks):
"""Use AI to determine and run appropriate scan."""
"""Use AI to determine and run appropriate scan - uses default LLM preferences if not specified"""
# Get AI suggestion
messages = [
{"role": "system", "content": SECURITY_PROMPTS["command_assist"]},
@@ -931,8 +944,8 @@ async def ai_assisted_scan(request: ChatRequest, background_tasks: BackgroundTas
response = await client.post(
f"{LLM_ROUTER_URL}/chat",
json={
"provider": request.provider,
"model": request.model,
"provider": request.provider or llm_preferences["provider"],
"model": request.model or llm_preferences["model"],
"messages": messages,
"temperature": 0.3,
"max_tokens": 1024
@@ -995,8 +1008,8 @@ async def run_analysis(task_id: str, request: SecurityAnalysisRequest):
response = await client.post(
f"{LLM_ROUTER_URL}/chat",
json={
"provider": "ollama",
"model": "llama3.2",
"provider": llm_preferences["provider"],
"model": llm_preferences["model"],
"messages": [
{"role": "system", "content": prompt},
{"role": "user", "content": f"Analyze target: {request.target}\nOptions: {request.options}"}
@@ -1060,7 +1073,7 @@ async def list_tools():
@app.post("/suggest-command")
async def suggest_command(request: ChatRequest):
"""Get AI-suggested security commands based on context"""
"""Get AI-suggested security commands based on context - uses default LLM preferences if not specified"""
messages = [
{
"role": "system",
@@ -1083,8 +1096,8 @@ Only suggest commands for legitimate security testing purposes."""
response = await client.post(
f"{LLM_ROUTER_URL}/chat",
json={
"provider": request.provider,
"model": request.model,
"provider": request.provider or llm_preferences["provider"],
"model": request.model or llm_preferences["model"],
"messages": messages,
"temperature": 0.3,
"max_tokens": 1024
@@ -1231,18 +1244,18 @@ async def llm_chat_help(
message: str,
session_id: Optional[str] = None,
context: Optional[str] = None,
provider: str = "ollama",
model: str = "llama3.2"
provider: Optional[str] = None,
model: Optional[str] = None
):
"""LLM-powered chat help"""
"""LLM-powered chat help - uses default preferences if provider/model not specified"""
try:
from . import llm_help
result = await llm_help.chat_completion(
message=message,
session_id=session_id,
context=context,
provider=provider,
model=model
provider=provider or llm_preferences["provider"],
model=model or llm_preferences["model"]
)
return result
except Exception as e:
@@ -1398,6 +1411,80 @@ async def send_push_notification(
raise HTTPException(status_code=500, detail=f"Push notification error: {str(e)}")
# ============== LLM Preferences ==============
@app.get("/api/llm/preferences")
async def get_llm_preferences():
"""
Get current default LLM provider and model preferences.
Returns:
Dictionary with provider, model, and available options
"""
try:
# Get available providers from LLM router
async with httpx.AsyncClient() as client:
response = await client.get(f"{LLM_ROUTER_URL}/providers", timeout=10.0)
available_providers = response.json() if response.status_code == 200 else []
return {
"current": {
"provider": llm_preferences["provider"],
"model": llm_preferences["model"]
},
"available_providers": available_providers,
"description": "Current default LLM provider and model. These are used when no explicit provider/model is specified in API requests."
}
except Exception as e:
return {
"current": {
"provider": llm_preferences["provider"],
"model": llm_preferences["model"]
},
"available_providers": [],
"error": str(e)
}
@app.post("/api/llm/preferences")
async def set_llm_preferences(request: LLMPreferencesRequest):
"""
Set default LLM provider and model preferences.
Args:
request: LLMPreferencesRequest with provider and model
Returns:
Updated preferences
"""
# Validate provider is available
try:
async with httpx.AsyncClient() as client:
response = await client.get(f"{LLM_ROUTER_URL}/providers", timeout=10.0)
if response.status_code == 200:
available_providers = response.json()
provider_names = [p["name"] for p in available_providers]
if request.provider not in provider_names:
raise HTTPException(
status_code=400,
detail=f"Provider '{request.provider}' not available. Available: {provider_names}"
)
except httpx.ConnectError:
# LLM router not available, proceed anyway
pass
# Update preferences
llm_preferences["provider"] = request.provider
llm_preferences["model"] = request.model
return {
"status": "updated",
"provider": llm_preferences["provider"],
"model": llm_preferences["model"],
"message": f"Default LLM set to {request.provider}/{request.model}"
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8001)