Add configurable default LLM provider and model preferences

Co-authored-by: mblanke <9078342+mblanke@users.noreply.github.com>

Author: copilot-swe-agent[bot]
Date: 2025-12-03 17:39:37 +00:00
Parent: c4eaf1718a
Commit: 91b4697403

4 changed files with 138 additions and 28 deletions

View File

@@ -7,3 +7,10 @@ ANTHROPIC_API_KEY=
 # Ollama Configuration
 OLLAMA_BASE_URL=http://ollama:11434
+
+# Default LLM Provider and Model
+# These are used when no explicit provider/model is specified in API requests
+# Can be changed via API: POST /api/llm/preferences
+DEFAULT_LLM_PROVIDER=ollama
+DEFAULT_LLM_MODEL=llama3.2
+# Available providers: ollama, ollama-local, ollama-network, openai, anthropic
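The comment above notes that the defaults can also be changed at runtime. A minimal client sketch of that flow, assuming the backend service in this commit is reachable at `localhost:8001` (the port its `uvicorn.run` call binds) and using illustrative provider/model values:

```
import httpx

BASE_URL = "http://localhost:8001"  # assumed backend address; adjust per deployment

# Read the current defaults and whatever providers the LLM router reports
prefs = httpx.get(f"{BASE_URL}/api/llm/preferences", timeout=10.0).json()
print(prefs["current"])  # e.g. {"provider": "ollama", "model": "llama3.2"}

# Change the runtime default; later requests that omit provider/model use this pair.
# Note: the endpoint validates the provider against the router, but not the model.
resp = httpx.post(
    f"{BASE_URL}/api/llm/preferences",
    json={"provider": "openai", "model": "gpt-4o"},  # illustrative values
    timeout=10.0,
)
print(resp.json()["message"])  # "Default LLM set to openai/gpt-4o"
```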

View File

@@ -434,8 +434,9 @@ Returns: { title, description, tips, example }
 ```
 POST /api/llm/chat
-Body: { message: string, session_id?: string, context?: string }
+Body: { message: string, session_id?: string, context?: string, provider?: string, model?: string }
 Returns: { message: string, success: boolean }
+Note: If provider/model not specified, uses default preferences

 GET /api/llm/autocomplete?partial_text=...&context_type=...
 Returns: { suggestions: [...] }
@@ -443,8 +444,21 @@ Returns: { suggestions: [...] }
 POST /api/llm/explain
 Body: { item: string, item_type?: string, context?: {...} }
 Returns: { explanation: string, item_type: string }
+
+GET /api/llm/preferences
+Returns: { current: { provider: string, model: string }, available_providers: [...] }
+
+POST /api/llm/preferences
+Body: { provider: string, model: string }
+Returns: { status: string, provider: string, model: string, message: string }
 ```
+
+**LLM Provider Selection:**
+- Set default LLM provider and model via environment variables: `DEFAULT_LLM_PROVIDER`, `DEFAULT_LLM_MODEL`
+- Change defaults at runtime via `/api/llm/preferences` endpoint
+- Override per-request by specifying `provider` and `model` in request body
+- Available providers: `ollama`, `ollama-local`, `ollama-network`, `openai`, `anthropic`
+
 ### Config Validation
 ```
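To make the precedence concrete, a short sketch of the per-request override documented above, under the same assumption that the backend listens on `localhost:8001`; the Anthropic model name is only illustrative:

```
import httpx

BASE_URL = "http://localhost:8001"  # assumed backend address

# Omitting provider/model falls through to the defaults (environment values,
# or whatever was last set via POST /api/llm/preferences)
r1 = httpx.post(f"{BASE_URL}/api/llm/chat",
                json={"message": "Explain these nmap findings"}, timeout=60.0)

# Passing provider/model overrides the defaults for this call only
r2 = httpx.post(f"{BASE_URL}/api/llm/chat",
                json={"message": "Explain these nmap findings",
                      "provider": "anthropic",
                      "model": "claude-3-5-sonnet-20241022"},  # illustrative
                timeout=60.0)
print(r1.json()["message"])
print(r2.json()["message"])
```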

View File

@@ -29,6 +29,8 @@ services:
     environment:
       - LLM_ROUTER_URL=http://strikepackage-llm-router:8000
       - KALI_EXECUTOR_URL=http://strikepackage-kali-executor:8002
+      - DEFAULT_LLM_PROVIDER=${DEFAULT_LLM_PROVIDER:-ollama}
+      - DEFAULT_LLM_MODEL=${DEFAULT_LLM_MODEL:-llama3.2}
     depends_on:
       - llm-router
       - kali-executor
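The `${VAR:-default}` form makes Compose substitute the host environment value when it is set and fall back to the literal otherwise, which stacks with the in-app fallback. A sketch of the effective resolution order for a single request, using the names from the backend code below:

```
import os

# Layer 3/4: container environment (itself ${DEFAULT_LLM_PROVIDER:-ollama}
# at compose time), then the hard-coded in-app fallback
DEFAULT_LLM_PROVIDER = os.getenv("DEFAULT_LLM_PROVIDER", "ollama")

# Layer 2: runtime preference, seeded from the environment at startup and
# mutable via POST /api/llm/preferences
llm_preferences = {"provider": DEFAULT_LLM_PROVIDER}

def resolve_provider(request_provider):
    # Layer 1: an explicit per-request value always wins
    return request_provider or llm_preferences["provider"]
```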

View File

@@ -32,10 +32,18 @@ app.add_middleware(
 LLM_ROUTER_URL = os.getenv("LLM_ROUTER_URL", "http://strikepackage-llm-router:8000")
 KALI_EXECUTOR_URL = os.getenv("KALI_EXECUTOR_URL", "http://strikepackage-kali-executor:8002")
+
+# Default LLM Configuration (can be overridden via environment or API)
+DEFAULT_LLM_PROVIDER = os.getenv("DEFAULT_LLM_PROVIDER", "ollama")
+DEFAULT_LLM_MODEL = os.getenv("DEFAULT_LLM_MODEL", "llama3.2")

 # In-memory storage (use Redis in production)
 tasks: Dict[str, Any] = {}
 sessions: Dict[str, Dict] = {}
 scan_results: Dict[str, Any] = {}
+llm_preferences: Dict[str, Any] = {
+    "provider": DEFAULT_LLM_PROVIDER,
+    "model": DEFAULT_LLM_MODEL
+}

 # ============== Models ==============
@@ -50,22 +58,27 @@ class ChatRequest(BaseModel):
     message: str
     session_id: Optional[str] = None
     context: Optional[str] = None
-    provider: str = "ollama"
-    model: str = "llama3.2"
+    provider: Optional[str] = None  # None means use default
+    model: Optional[str] = None  # None means use default

 class PhaseChatRequest(BaseModel):
     message: str
     phase: str
-    provider: str = "ollama"
-    model: str = "llama3.2"
+    provider: Optional[str] = None  # None means use default
+    model: Optional[str] = None  # None means use default
     findings: List[Dict[str, Any]] = []

 class AttackChainRequest(BaseModel):
     findings: List[Dict[str, Any]]
-    provider: str = "ollama"
-    model: str = "llama3.2"
+    provider: Optional[str] = None  # None means use default
+    model: Optional[str] = None  # None means use default

+class LLMPreferencesRequest(BaseModel):
+    provider: str
+    model: str
+
 class CommandRequest(BaseModel):
@@ -335,7 +348,7 @@ async def health_check():
 @app.post("/chat")
 async def security_chat(request: ChatRequest):
-    """Chat with security-focused AI assistant"""
+    """Chat with security-focused AI assistant - uses default LLM preferences if not specified"""
     messages = [
         {
             "role": "system",
@@ -357,8 +370,8 @@ vulnerabilities and defenses."""
         response = await client.post(
             f"{LLM_ROUTER_URL}/chat",
             json={
-                "provider": request.provider,
-                "model": request.model,
+                "provider": request.provider or llm_preferences["provider"],
+                "model": request.model or llm_preferences["model"],
                 "messages": messages,
                 "temperature": 0.7,
                 "max_tokens": 2048
@@ -376,7 +389,7 @@ vulnerabilities and defenses."""
 @app.post("/chat/phase")
 async def phase_aware_chat(request: PhaseChatRequest):
-    """Phase-aware chat with context from current pentest phase"""
+    """Phase-aware chat with context from current pentest phase - uses default LLM preferences if not specified"""
     phase_prompt = PHASE_PROMPTS.get(request.phase, PHASE_PROMPTS["recon"])

     # Build context from findings if available
@@ -400,8 +413,8 @@ async def phase_aware_chat(request: PhaseChatRequest):
         response = await client.post(
             f"{LLM_ROUTER_URL}/chat",
             json={
-                "provider": request.provider,
-                "model": request.model,
+                "provider": request.provider or llm_preferences["provider"],
+                "model": request.model or llm_preferences["model"],
                 "messages": messages,
                 "temperature": 0.7,
                 "max_tokens": 2048
@@ -515,8 +528,8 @@ Only return valid JSON."""
         response = await client.post(
             f"{LLM_ROUTER_URL}/chat",
             json={
-                "provider": request.provider,
-                "model": request.model,
+                "provider": request.provider or llm_preferences["provider"],
+                "model": request.model or llm_preferences["model"],
                 "messages": messages,
                 "temperature": 0.3,
                 "max_tokens": 2048
@@ -919,7 +932,7 @@ def parse_gobuster_output(output: str) -> Dict[str, Any]:
 @app.post("/ai-scan")
 async def ai_assisted_scan(request: ChatRequest, background_tasks: BackgroundTasks):
-    """Use AI to determine and run appropriate scan."""
+    """Use AI to determine and run appropriate scan - uses default LLM preferences if not specified"""
     # Get AI suggestion
     messages = [
         {"role": "system", "content": SECURITY_PROMPTS["command_assist"]},
@@ -931,8 +944,8 @@ async def ai_assisted_scan(request: ChatRequest, background_tasks: BackgroundTas
         response = await client.post(
             f"{LLM_ROUTER_URL}/chat",
             json={
-                "provider": request.provider,
-                "model": request.model,
+                "provider": request.provider or llm_preferences["provider"],
+                "model": request.model or llm_preferences["model"],
                 "messages": messages,
                 "temperature": 0.3,
                 "max_tokens": 1024
@@ -995,8 +1008,8 @@ async def run_analysis(task_id: str, request: SecurityAnalysisRequest):
         response = await client.post(
             f"{LLM_ROUTER_URL}/chat",
             json={
-                "provider": "ollama",
-                "model": "llama3.2",
+                "provider": llm_preferences["provider"],
+                "model": llm_preferences["model"],
                 "messages": [
                     {"role": "system", "content": prompt},
                     {"role": "user", "content": f"Analyze target: {request.target}\nOptions: {request.options}"}
@@ -1060,7 +1073,7 @@ async def list_tools():
 @app.post("/suggest-command")
 async def suggest_command(request: ChatRequest):
-    """Get AI-suggested security commands based on context"""
+    """Get AI-suggested security commands based on context - uses default LLM preferences if not specified"""
     messages = [
         {
             "role": "system",
@@ -1083,8 +1096,8 @@ Only suggest commands for legitimate security testing purposes."""
         response = await client.post(
             f"{LLM_ROUTER_URL}/chat",
             json={
-                "provider": request.provider,
-                "model": request.model,
+                "provider": request.provider or llm_preferences["provider"],
+                "model": request.model or llm_preferences["model"],
                 "messages": messages,
                 "temperature": 0.3,
                 "max_tokens": 1024
@@ -1231,18 +1244,18 @@ async def llm_chat_help(
     message: str,
     session_id: Optional[str] = None,
     context: Optional[str] = None,
-    provider: str = "ollama",
-    model: str = "llama3.2"
+    provider: Optional[str] = None,
+    model: Optional[str] = None
 ):
-    """LLM-powered chat help"""
+    """LLM-powered chat help - uses default preferences if provider/model not specified"""
     try:
         from . import llm_help
         result = await llm_help.chat_completion(
             message=message,
             session_id=session_id,
             context=context,
-            provider=provider,
-            model=model
+            provider=provider or llm_preferences["provider"],
+            model=model or llm_preferences["model"]
         )
         return result
     except Exception as e:
@@ -1398,6 +1411,80 @@ async def send_push_notification(
         raise HTTPException(status_code=500, detail=f"Push notification error: {str(e)}")

+
+# ============== LLM Preferences ==============
+
+@app.get("/api/llm/preferences")
+async def get_llm_preferences():
+    """
+    Get current default LLM provider and model preferences.
+
+    Returns:
+        Dictionary with provider, model, and available options
+    """
+    try:
+        # Get available providers from LLM router
+        async with httpx.AsyncClient() as client:
+            response = await client.get(f"{LLM_ROUTER_URL}/providers", timeout=10.0)
+            available_providers = response.json() if response.status_code == 200 else []
+
+        return {
+            "current": {
+                "provider": llm_preferences["provider"],
+                "model": llm_preferences["model"]
+            },
+            "available_providers": available_providers,
+            "description": "Current default LLM provider and model. These are used when no explicit provider/model is specified in API requests."
+        }
+    except Exception as e:
+        return {
+            "current": {
+                "provider": llm_preferences["provider"],
+                "model": llm_preferences["model"]
+            },
+            "available_providers": [],
+            "error": str(e)
+        }
+
+
+@app.post("/api/llm/preferences")
+async def set_llm_preferences(request: LLMPreferencesRequest):
+    """
+    Set default LLM provider and model preferences.
+
+    Args:
+        request: LLMPreferencesRequest with provider and model
+
+    Returns:
+        Updated preferences
+    """
+    # Validate provider is available
+    try:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(f"{LLM_ROUTER_URL}/providers", timeout=10.0)
+            if response.status_code == 200:
+                available_providers = response.json()
+                provider_names = [p["name"] for p in available_providers]
+                if request.provider not in provider_names:
+                    raise HTTPException(
+                        status_code=400,
+                        detail=f"Provider '{request.provider}' not available. Available: {provider_names}"
+                    )
+    except httpx.ConnectError:
+        # LLM router not available, proceed anyway
+        pass
+
+    # Update preferences
+    llm_preferences["provider"] = request.provider
+    llm_preferences["model"] = request.model
+
+    return {
+        "status": "updated",
+        "provider": llm_preferences["provider"],
+        "model": llm_preferences["model"],
+        "message": f"Default LLM set to {request.provider}/{request.model}"
+    }
+
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8001)
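As a quick end-to-end check of the new endpoints, a minimal smoke-test sketch; the import path for `app` is hypothetical, and the assertions assume the DEFAULT_LLM_* variables are unset so the hard-coded fallbacks apply:

```
from fastapi.testclient import TestClient

from backend.main import app  # hypothetical import path; adjust to the real module

client = TestClient(app)

# Defaults are seeded from DEFAULT_LLM_PROVIDER / DEFAULT_LLM_MODEL at startup
prefs = client.get("/api/llm/preferences").json()
assert prefs["current"] == {"provider": "ollama", "model": "llama3.2"}

# Update the runtime default and confirm it sticks; with the LLM router
# unreachable, provider validation is skipped by design (httpx.ConnectError)
resp = client.post("/api/llm/preferences",
                   json={"provider": "ollama", "model": "llama3.1"})
assert resp.json()["status"] == "updated"
assert client.get("/api/llm/preferences").json()["current"]["model"] == "llama3.1"
```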