Add configurable default LLM provider and model preferences
Co-authored-by: mblanke <9078342+mblanke@users.noreply.github.com>
@@ -7,3 +7,10 @@ ANTHROPIC_API_KEY=
 # Ollama Configuration
 OLLAMA_BASE_URL=http://ollama:11434
 
+# Default LLM Provider and Model
+# These are used when no explicit provider/model is specified in API requests
+# Can be changed via API: POST /api/llm/preferences
+DEFAULT_LLM_PROVIDER=ollama
+DEFAULT_LLM_MODEL=llama3.2
+# Available providers: ollama, ollama-local, ollama-network, openai, anthropic
+
FEATURES.md
@@ -434,8 +434,9 @@ Returns: { title, description, tips, example }
 
 ```
 POST /api/llm/chat
-Body: { message: string, session_id?: string, context?: string }
+Body: { message: string, session_id?: string, context?: string, provider?: string, model?: string }
 Returns: { message: string, success: boolean }
+Note: If provider/model not specified, uses default preferences
 
 GET /api/llm/autocomplete?partial_text=...&context_type=...
 Returns: { suggestions: [...] }
@@ -443,8 +444,21 @@ Returns: { suggestions: [...] }
 POST /api/llm/explain
 Body: { item: string, item_type?: string, context?: {...} }
 Returns: { explanation: string, item_type: string }
+
+GET /api/llm/preferences
+Returns: { current: { provider: string, model: string }, available_providers: [...] }
+
+POST /api/llm/preferences
+Body: { provider: string, model: string }
+Returns: { status: string, provider: string, model: string, message: string }
 ```
 
+**LLM Provider Selection:**
+- Set default LLM provider and model via environment variables: `DEFAULT_LLM_PROVIDER`, `DEFAULT_LLM_MODEL`
+- Change defaults at runtime via `/api/llm/preferences` endpoint
+- Override per-request by specifying `provider` and `model` in request body
+- Available providers: `ollama`, `ollama-local`, `ollama-network`, `openai`, `anthropic`
+
 ### Config Validation
 
 ```
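As a usage sketch of the precedence described above (per-request override beats the runtime preference, which beats the environment default), here is a minimal httpx client. The base URL and the `gpt-4o-mini` model name are assumptions for illustration, not part of the commit:

```python
# Sketch: exercising the preferences API documented above.
# Base URL is an assumption; adjust to your deployment.
import httpx

BASE = "http://localhost:8001"

# Read the current defaults and which providers the router reports.
prefs = httpx.get(f"{BASE}/api/llm/preferences").json()
print(prefs["current"], prefs["available_providers"])

# Change the runtime default; requests that omit provider/model will use it.
httpx.post(f"{BASE}/api/llm/preferences",
           json={"provider": "ollama", "model": "llama3.2"})

# A per-request override wins over both the runtime and environment defaults.
httpx.post(f"{BASE}/api/llm/chat",
           json={"message": "hello",
                 "provider": "openai", "model": "gpt-4o-mini"})  # model name illustrative
```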
@@ -29,6 +29,8 @@ services:
     environment:
       - LLM_ROUTER_URL=http://strikepackage-llm-router:8000
       - KALI_EXECUTOR_URL=http://strikepackage-kali-executor:8002
+      - DEFAULT_LLM_PROVIDER=${DEFAULT_LLM_PROVIDER:-ollama}
+      - DEFAULT_LLM_MODEL=${DEFAULT_LLM_MODEL:-llama3.2}
     depends_on:
       - llm-router
       - kali-executor
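The `${DEFAULT_LLM_PROVIDER:-ollama}` form is Compose's shell-style variable substitution: the service inherits the value from the host environment when one is set and otherwise falls back to the literal default, so the stack starts with working defaults even without a `.env` file.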
@@ -32,10 +32,18 @@ app.add_middleware(
 LLM_ROUTER_URL = os.getenv("LLM_ROUTER_URL", "http://strikepackage-llm-router:8000")
 KALI_EXECUTOR_URL = os.getenv("KALI_EXECUTOR_URL", "http://strikepackage-kali-executor:8002")
 
+# Default LLM Configuration (can be overridden via environment or API)
+DEFAULT_LLM_PROVIDER = os.getenv("DEFAULT_LLM_PROVIDER", "ollama")
+DEFAULT_LLM_MODEL = os.getenv("DEFAULT_LLM_MODEL", "llama3.2")
+
 # In-memory storage (use Redis in production)
 tasks: Dict[str, Any] = {}
 sessions: Dict[str, Dict] = {}
 scan_results: Dict[str, Any] = {}
+llm_preferences: Dict[str, Any] = {
+    "provider": DEFAULT_LLM_PROVIDER,
+    "model": DEFAULT_LLM_MODEL
+}
 
 
 # ============== Models ==============
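The module-level `llm_preferences` dict is process-local, so a restart or a second worker loses or diverges from a runtime preference change; the code's own comment suggests Redis for production. A minimal sketch of that variant, with an assumed `REDIS_URL` and key name, might look like:

```python
# Sketch: Redis-backed preferences that survive restarts and are shared
# across workers. REDIS_URL and the key name are illustrative assumptions.
import json
import os

import redis

r = redis.Redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0"))
PREFS_KEY = "llm:preferences"  # hypothetical key

def get_llm_prefs() -> dict:
    raw = r.get(PREFS_KEY)
    if raw is None:
        # Fall back to the environment-derived defaults.
        return {"provider": os.getenv("DEFAULT_LLM_PROVIDER", "ollama"),
                "model": os.getenv("DEFAULT_LLM_MODEL", "llama3.2")}
    return json.loads(raw)

def set_llm_prefs(provider: str, model: str) -> None:
    r.set(PREFS_KEY, json.dumps({"provider": provider, "model": model}))
```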
@@ -50,22 +58,27 @@ class ChatRequest(BaseModel):
     message: str
     session_id: Optional[str] = None
     context: Optional[str] = None
-    provider: str = "ollama"
-    model: str = "llama3.2"
+    provider: Optional[str] = None  # None means use default
+    model: Optional[str] = None  # None means use default
 
 
 class PhaseChatRequest(BaseModel):
     message: str
     phase: str
-    provider: str = "ollama"
-    model: str = "llama3.2"
+    provider: Optional[str] = None  # None means use default
+    model: Optional[str] = None  # None means use default
     findings: List[Dict[str, Any]] = []
 
 
 class AttackChainRequest(BaseModel):
     findings: List[Dict[str, Any]]
-    provider: str = "ollama"
-    model: str = "llama3.2"
+    provider: Optional[str] = None  # None means use default
+    model: Optional[str] = None  # None means use default
 
 
+class LLMPreferencesRequest(BaseModel):
+    provider: str
+    model: str
+
+
 class CommandRequest(BaseModel):
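Throughout the handlers below, these optional fields are resolved with `request.provider or llm_preferences["provider"]`. A tiny illustrative helper (not part of the commit) makes the precedence explicit, including the one caveat of `or`: an empty string falls back just like `None`:

```python
# Sketch: explicit resolution of per-request overrides against runtime defaults.
# Note that `x or default` falls back for "" as well as None, which is usually
# fine for optional string fields but worth knowing.
from typing import Dict, Optional, Tuple

def resolve_llm(provider: Optional[str], model: Optional[str],
                prefs: Dict[str, str]) -> Tuple[str, str]:
    return (provider or prefs["provider"], model or prefs["model"])

# resolve_llm(None, None, {"provider": "ollama", "model": "llama3.2"})
#   -> ("ollama", "llama3.2")
# resolve_llm("openai", None, {"provider": "ollama", "model": "llama3.2"})
#   -> ("openai", "llama3.2")
```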
@@ -335,7 +348,7 @@ async def health_check():
 
 @app.post("/chat")
 async def security_chat(request: ChatRequest):
-    """Chat with security-focused AI assistant"""
+    """Chat with security-focused AI assistant - uses default LLM preferences if not specified"""
     messages = [
         {
             "role": "system",
@@ -357,8 +370,8 @@ vulnerabilities and defenses."""
         response = await client.post(
             f"{LLM_ROUTER_URL}/chat",
             json={
-                "provider": request.provider,
-                "model": request.model,
+                "provider": request.provider or llm_preferences["provider"],
+                "model": request.model or llm_preferences["model"],
                 "messages": messages,
                 "temperature": 0.7,
                 "max_tokens": 2048
@@ -376,7 +389,7 @@ vulnerabilities and defenses."""
 
 @app.post("/chat/phase")
 async def phase_aware_chat(request: PhaseChatRequest):
-    """Phase-aware chat with context from current pentest phase"""
+    """Phase-aware chat with context from current pentest phase - uses default LLM preferences if not specified"""
     phase_prompt = PHASE_PROMPTS.get(request.phase, PHASE_PROMPTS["recon"])
 
     # Build context from findings if available
@@ -400,8 +413,8 @@ async def phase_aware_chat(request: PhaseChatRequest):
         response = await client.post(
             f"{LLM_ROUTER_URL}/chat",
             json={
-                "provider": request.provider,
-                "model": request.model,
+                "provider": request.provider or llm_preferences["provider"],
+                "model": request.model or llm_preferences["model"],
                 "messages": messages,
                 "temperature": 0.7,
                 "max_tokens": 2048
@@ -515,8 +528,8 @@ Only return valid JSON."""
         response = await client.post(
             f"{LLM_ROUTER_URL}/chat",
             json={
-                "provider": request.provider,
-                "model": request.model,
+                "provider": request.provider or llm_preferences["provider"],
+                "model": request.model or llm_preferences["model"],
                 "messages": messages,
                 "temperature": 0.3,
                 "max_tokens": 2048
@@ -919,7 +932,7 @@ def parse_gobuster_output(output: str) -> Dict[str, Any]:
 
 @app.post("/ai-scan")
 async def ai_assisted_scan(request: ChatRequest, background_tasks: BackgroundTasks):
-    """Use AI to determine and run appropriate scan."""
+    """Use AI to determine and run appropriate scan - uses default LLM preferences if not specified"""
     # Get AI suggestion
     messages = [
         {"role": "system", "content": SECURITY_PROMPTS["command_assist"]},
@@ -931,8 +944,8 @@ async def ai_assisted_scan(request: ChatRequest, background_tasks: BackgroundTas
         response = await client.post(
             f"{LLM_ROUTER_URL}/chat",
             json={
-                "provider": request.provider,
-                "model": request.model,
+                "provider": request.provider or llm_preferences["provider"],
+                "model": request.model or llm_preferences["model"],
                 "messages": messages,
                 "temperature": 0.3,
                 "max_tokens": 1024
@@ -995,8 +1008,8 @@ async def run_analysis(task_id: str, request: SecurityAnalysisRequest):
         response = await client.post(
             f"{LLM_ROUTER_URL}/chat",
             json={
-                "provider": "ollama",
-                "model": "llama3.2",
+                "provider": llm_preferences["provider"],
+                "model": llm_preferences["model"],
                 "messages": [
                     {"role": "system", "content": prompt},
                     {"role": "user", "content": f"Analyze target: {request.target}\nOptions: {request.options}"}
@@ -1060,7 +1073,7 @@ async def list_tools():
 
 @app.post("/suggest-command")
 async def suggest_command(request: ChatRequest):
-    """Get AI-suggested security commands based on context"""
+    """Get AI-suggested security commands based on context - uses default LLM preferences if not specified"""
     messages = [
         {
             "role": "system",
@@ -1083,8 +1096,8 @@ Only suggest commands for legitimate security testing purposes."""
         response = await client.post(
             f"{LLM_ROUTER_URL}/chat",
             json={
-                "provider": request.provider,
-                "model": request.model,
+                "provider": request.provider or llm_preferences["provider"],
+                "model": request.model or llm_preferences["model"],
                 "messages": messages,
                 "temperature": 0.3,
                 "max_tokens": 1024
@@ -1231,18 +1244,18 @@ async def llm_chat_help(
     message: str,
     session_id: Optional[str] = None,
     context: Optional[str] = None,
-    provider: str = "ollama",
-    model: str = "llama3.2"
+    provider: Optional[str] = None,
+    model: Optional[str] = None
 ):
-    """LLM-powered chat help"""
+    """LLM-powered chat help - uses default preferences if provider/model not specified"""
     try:
         from . import llm_help
         result = await llm_help.chat_completion(
             message=message,
             session_id=session_id,
             context=context,
-            provider=provider,
-            model=model
+            provider=provider or llm_preferences["provider"],
+            model=model or llm_preferences["model"]
         )
         return result
     except Exception as e:
@@ -1398,6 +1411,80 @@ async def send_push_notification(
         raise HTTPException(status_code=500, detail=f"Push notification error: {str(e)}")
 
 
+# ============== LLM Preferences ==============
+
+@app.get("/api/llm/preferences")
+async def get_llm_preferences():
+    """
+    Get current default LLM provider and model preferences.
+
+    Returns:
+        Dictionary with provider, model, and available options
+    """
+    try:
+        # Get available providers from LLM router
+        async with httpx.AsyncClient() as client:
+            response = await client.get(f"{LLM_ROUTER_URL}/providers", timeout=10.0)
+            available_providers = response.json() if response.status_code == 200 else []
+
+        return {
+            "current": {
+                "provider": llm_preferences["provider"],
+                "model": llm_preferences["model"]
+            },
+            "available_providers": available_providers,
+            "description": "Current default LLM provider and model. These are used when no explicit provider/model is specified in API requests."
+        }
+    except Exception as e:
+        return {
+            "current": {
+                "provider": llm_preferences["provider"],
+                "model": llm_preferences["model"]
+            },
+            "available_providers": [],
+            "error": str(e)
+        }
+
+
+@app.post("/api/llm/preferences")
+async def set_llm_preferences(request: LLMPreferencesRequest):
+    """
+    Set default LLM provider and model preferences.
+
+    Args:
+        request: LLMPreferencesRequest with provider and model
+
+    Returns:
+        Updated preferences
+    """
+    # Validate provider is available
+    try:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(f"{LLM_ROUTER_URL}/providers", timeout=10.0)
+            if response.status_code == 200:
+                available_providers = response.json()
+                provider_names = [p["name"] for p in available_providers]
+                if request.provider not in provider_names:
+                    raise HTTPException(
+                        status_code=400,
+                        detail=f"Provider '{request.provider}' not available. Available: {provider_names}"
+                    )
+    except httpx.ConnectError:
+        # LLM router not available, proceed anyway
+        pass
+
+    # Update preferences
+    llm_preferences["provider"] = request.provider
+    llm_preferences["model"] = request.model
+
+    return {
+        "status": "updated",
+        "provider": llm_preferences["provider"],
+        "model": llm_preferences["model"],
+        "message": f"Default LLM set to {request.provider}/{request.model}"
+    }
+
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8001)
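A hedged smoke test for the new endpoints using FastAPI's TestClient; the `main` module path is an assumption about where `app` and `llm_preferences` live. Note the POST handler deliberately tolerates an unreachable router (`httpx.ConnectError` is swallowed), so this round-trip also passes with the router down:

```python
# Sketch: preferences round-trip test. Module name "main" is assumed.
from fastapi.testclient import TestClient

from main import app, llm_preferences

client = TestClient(app)

def test_preferences_roundtrip():
    resp = client.post("/api/llm/preferences",
                       json={"provider": "ollama", "model": "llama3.2"})
    assert resp.status_code == 200
    assert llm_preferences["provider"] == "ollama"

    current = client.get("/api/llm/preferences").json()["current"]
    assert current == {"provider": "ollama", "model": "llama3.2"}
```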