mirror of
https://github.com/mblanke/StrikePackageGPT.git
synced 2026-03-01 14:20:21 -05:00
Add configurable default LLM provider and model preferences
Co-authored-by: mblanke <9078342+mblanke@users.noreply.github.com>
This commit is contained in:
@@ -7,3 +7,10 @@ ANTHROPIC_API_KEY=
|
|||||||
|
|
||||||
# Ollama Configuration
|
# Ollama Configuration
|
||||||
OLLAMA_BASE_URL=http://ollama:11434
|
OLLAMA_BASE_URL=http://ollama:11434
|
||||||
|
|
||||||
|
# Default LLM Provider and Model
|
||||||
|
# These are used when no explicit provider/model is specified in API requests
|
||||||
|
# Can be changed at runtime via API: POST /api/llm/preferences (held in memory only; a service restart restores the values below)
|
||||||
|
DEFAULT_LLM_PROVIDER=ollama
|
||||||
|
DEFAULT_LLM_MODEL=llama3.2
|
||||||
|
# Available providers: ollama, ollama-local, ollama-network, openai, anthropic
|
||||||
|
|||||||
16
FEATURES.md
16
FEATURES.md
@@ -434,8 +434,9 @@ Returns: { title, description, tips, example }
|
|||||||
|
|
||||||
```
|
```
|
||||||
POST /api/llm/chat
|
POST /api/llm/chat
|
||||||
Body: { message: string, session_id?: string, context?: string }
|
Body: { message: string, session_id?: string, context?: string, provider?: string, model?: string }
|
||||||
Returns: { message: string, success: boolean }
|
Returns: { message: string, success: boolean }
|
||||||
|
Note: If provider or model is not specified, the default LLM preferences are used.
|
||||||
|
|
||||||
GET /api/llm/autocomplete?partial_text=...&context_type=...
|
GET /api/llm/autocomplete?partial_text=...&context_type=...
|
||||||
Returns: { suggestions: [...] }
|
Returns: { suggestions: [...] }
|
||||||
@@ -443,8 +444,21 @@ Returns: { suggestions: [...] }
|
|||||||
POST /api/llm/explain
|
POST /api/llm/explain
|
||||||
Body: { item: string, item_type?: string, context?: {...} }
|
Body: { item: string, item_type?: string, context?: {...} }
|
||||||
Returns: { explanation: string, item_type: string }
|
Returns: { explanation: string, item_type: string }
|
||||||
|
|
||||||
|
GET /api/llm/preferences
|
||||||
|
Returns: { current: { provider: string, model: string }, available_providers: [...], description?: string, error?: string }
|
||||||
|
|
||||||
|
POST /api/llm/preferences
|
||||||
|
Body: { provider: string, model: string }
|
||||||
|
Returns: { status: string, provider: string, model: string, message: string }
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**LLM Provider Selection:**
|
||||||
|
- Set default LLM provider and model via environment variables: `DEFAULT_LLM_PROVIDER`, `DEFAULT_LLM_MODEL`
|
||||||
|
- Change defaults at runtime via the `POST /api/llm/preferences` endpoint (read the current values with `GET /api/llm/preferences`)
|
||||||
|
- Override per-request by specifying `provider` and `model` in request body
|
||||||
|
- Available providers: `ollama`, `ollama-local`, `ollama-network`, `openai`, `anthropic`
|
||||||
|
|
||||||
### Config Validation
|
### Config Validation
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
- LLM_ROUTER_URL=http://strikepackage-llm-router:8000
|
- LLM_ROUTER_URL=http://strikepackage-llm-router:8000
|
||||||
- KALI_EXECUTOR_URL=http://strikepackage-kali-executor:8002
|
- KALI_EXECUTOR_URL=http://strikepackage-kali-executor:8002
|
||||||
|
- DEFAULT_LLM_PROVIDER=${DEFAULT_LLM_PROVIDER:-ollama}
|
||||||
|
- DEFAULT_LLM_MODEL=${DEFAULT_LLM_MODEL:-llama3.2}
|
||||||
depends_on:
|
depends_on:
|
||||||
- llm-router
|
- llm-router
|
||||||
- kali-executor
|
- kali-executor
|
||||||
|
|||||||
@@ -32,10 +32,18 @@ app.add_middleware(
|
|||||||
LLM_ROUTER_URL = os.getenv("LLM_ROUTER_URL", "http://strikepackage-llm-router:8000")
|
LLM_ROUTER_URL = os.getenv("LLM_ROUTER_URL", "http://strikepackage-llm-router:8000")
|
||||||
KALI_EXECUTOR_URL = os.getenv("KALI_EXECUTOR_URL", "http://strikepackage-kali-executor:8002")
|
KALI_EXECUTOR_URL = os.getenv("KALI_EXECUTOR_URL", "http://strikepackage-kali-executor:8002")
|
||||||
|
|
||||||
|
# Default LLM configuration, read once at import time from the environment.
# These seed the runtime preferences, which can later be changed through the
# preferences API.
#
# A variable that is set but blank (e.g. ``DEFAULT_LLM_MODEL=`` in a .env file)
# is treated the same as an unset one: os.getenv() would otherwise return ""
# and that empty string would be forwarded to the LLM router as the
# provider/model.
DEFAULT_LLM_PROVIDER = (os.getenv("DEFAULT_LLM_PROVIDER") or "").strip() or "ollama"
DEFAULT_LLM_MODEL = (os.getenv("DEFAULT_LLM_MODEL") or "").strip() or "llama3.2"
|
||||||
|
|
||||||
# In-memory storage (use Redis in production)
|
# In-memory storage (use Redis in production)
|
||||||
tasks: Dict[str, Any] = {}
|
tasks: Dict[str, Any] = {}
|
||||||
sessions: Dict[str, Dict] = {}
|
sessions: Dict[str, Dict] = {}
|
||||||
scan_results: Dict[str, Any] = {}
|
scan_results: Dict[str, Any] = {}
|
||||||
|
# Process-wide, mutable LLM defaults, seeded from DEFAULT_LLM_PROVIDER and
# DEFAULT_LLM_MODEL. Request handlers fall back to these values when a request
# carries no provider/model, and POST /api/llm/preferences overwrites them.
# Like the other stores here this lives in memory only.
llm_preferences: Dict[str, Any] = dict(
    provider=DEFAULT_LLM_PROVIDER,
    model=DEFAULT_LLM_MODEL,
)
|
||||||
|
|
||||||
|
|
||||||
# ============== Models ==============
|
# ============== Models ==============
|
||||||
@@ -50,22 +58,27 @@ class ChatRequest(BaseModel):
|
|||||||
message: str
|
message: str
|
||||||
session_id: Optional[str] = None
|
session_id: Optional[str] = None
|
||||||
context: Optional[str] = None
|
context: Optional[str] = None
|
||||||
provider: str = "ollama"
|
provider: Optional[str] = None # None means use default
|
||||||
model: str = "llama3.2"
|
model: Optional[str] = None # None means use default
|
||||||
|
|
||||||
|
|
||||||
class PhaseChatRequest(BaseModel):
|
class PhaseChatRequest(BaseModel):
|
||||||
message: str
|
message: str
|
||||||
phase: str
|
phase: str
|
||||||
provider: str = "ollama"
|
provider: Optional[str] = None # None means use default
|
||||||
model: str = "llama3.2"
|
model: Optional[str] = None # None means use default
|
||||||
findings: List[Dict[str, Any]] = []
|
findings: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
|
||||||
class AttackChainRequest(BaseModel):
|
class AttackChainRequest(BaseModel):
|
||||||
findings: List[Dict[str, Any]]
|
findings: List[Dict[str, Any]]
|
||||||
provider: str = "ollama"
|
provider: Optional[str] = None # None means use default
|
||||||
model: str = "llama3.2"
|
model: Optional[str] = None # None means use default
|
||||||
|
|
||||||
|
|
||||||
|
class LLMPreferencesRequest(BaseModel):
    """Request body for ``POST /api/llm/preferences``.

    Both fields are required; together they become the new default
    provider/model pair.
    """

    # Provider identifier to store as the new default.
    provider: str
    # Model identifier to store as the new default.
    model: str
|
||||||
|
|
||||||
|
|
||||||
class CommandRequest(BaseModel):
|
class CommandRequest(BaseModel):
|
||||||
@@ -335,7 +348,7 @@ async def health_check():
|
|||||||
|
|
||||||
@app.post("/chat")
|
@app.post("/chat")
|
||||||
async def security_chat(request: ChatRequest):
|
async def security_chat(request: ChatRequest):
|
||||||
"""Chat with security-focused AI assistant"""
|
"""Chat with security-focused AI assistant - uses default LLM preferences if not specified"""
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
@@ -357,8 +370,8 @@ vulnerabilities and defenses."""
|
|||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{LLM_ROUTER_URL}/chat",
|
f"{LLM_ROUTER_URL}/chat",
|
||||||
json={
|
json={
|
||||||
"provider": request.provider,
|
"provider": request.provider or llm_preferences["provider"],
|
||||||
"model": request.model,
|
"model": request.model or llm_preferences["model"],
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"max_tokens": 2048
|
"max_tokens": 2048
|
||||||
@@ -376,7 +389,7 @@ vulnerabilities and defenses."""
|
|||||||
|
|
||||||
@app.post("/chat/phase")
|
@app.post("/chat/phase")
|
||||||
async def phase_aware_chat(request: PhaseChatRequest):
|
async def phase_aware_chat(request: PhaseChatRequest):
|
||||||
"""Phase-aware chat with context from current pentest phase"""
|
"""Phase-aware chat with context from current pentest phase - uses default LLM preferences if not specified"""
|
||||||
phase_prompt = PHASE_PROMPTS.get(request.phase, PHASE_PROMPTS["recon"])
|
phase_prompt = PHASE_PROMPTS.get(request.phase, PHASE_PROMPTS["recon"])
|
||||||
|
|
||||||
# Build context from findings if available
|
# Build context from findings if available
|
||||||
@@ -400,8 +413,8 @@ async def phase_aware_chat(request: PhaseChatRequest):
|
|||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{LLM_ROUTER_URL}/chat",
|
f"{LLM_ROUTER_URL}/chat",
|
||||||
json={
|
json={
|
||||||
"provider": request.provider,
|
"provider": request.provider or llm_preferences["provider"],
|
||||||
"model": request.model,
|
"model": request.model or llm_preferences["model"],
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"max_tokens": 2048
|
"max_tokens": 2048
|
||||||
@@ -515,8 +528,8 @@ Only return valid JSON."""
|
|||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{LLM_ROUTER_URL}/chat",
|
f"{LLM_ROUTER_URL}/chat",
|
||||||
json={
|
json={
|
||||||
"provider": request.provider,
|
"provider": request.provider or llm_preferences["provider"],
|
||||||
"model": request.model,
|
"model": request.model or llm_preferences["model"],
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"temperature": 0.3,
|
"temperature": 0.3,
|
||||||
"max_tokens": 2048
|
"max_tokens": 2048
|
||||||
@@ -919,7 +932,7 @@ def parse_gobuster_output(output: str) -> Dict[str, Any]:
|
|||||||
|
|
||||||
@app.post("/ai-scan")
|
@app.post("/ai-scan")
|
||||||
async def ai_assisted_scan(request: ChatRequest, background_tasks: BackgroundTasks):
|
async def ai_assisted_scan(request: ChatRequest, background_tasks: BackgroundTasks):
|
||||||
"""Use AI to determine and run appropriate scan."""
|
"""Use AI to determine and run appropriate scan - uses default LLM preferences if not specified"""
|
||||||
# Get AI suggestion
|
# Get AI suggestion
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": SECURITY_PROMPTS["command_assist"]},
|
{"role": "system", "content": SECURITY_PROMPTS["command_assist"]},
|
||||||
@@ -931,8 +944,8 @@ async def ai_assisted_scan(request: ChatRequest, background_tasks: BackgroundTas
|
|||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{LLM_ROUTER_URL}/chat",
|
f"{LLM_ROUTER_URL}/chat",
|
||||||
json={
|
json={
|
||||||
"provider": request.provider,
|
"provider": request.provider or llm_preferences["provider"],
|
||||||
"model": request.model,
|
"model": request.model or llm_preferences["model"],
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"temperature": 0.3,
|
"temperature": 0.3,
|
||||||
"max_tokens": 1024
|
"max_tokens": 1024
|
||||||
@@ -995,8 +1008,8 @@ async def run_analysis(task_id: str, request: SecurityAnalysisRequest):
|
|||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{LLM_ROUTER_URL}/chat",
|
f"{LLM_ROUTER_URL}/chat",
|
||||||
json={
|
json={
|
||||||
"provider": "ollama",
|
"provider": llm_preferences["provider"],
|
||||||
"model": "llama3.2",
|
"model": llm_preferences["model"],
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": prompt},
|
{"role": "system", "content": prompt},
|
||||||
{"role": "user", "content": f"Analyze target: {request.target}\nOptions: {request.options}"}
|
{"role": "user", "content": f"Analyze target: {request.target}\nOptions: {request.options}"}
|
||||||
@@ -1060,7 +1073,7 @@ async def list_tools():
|
|||||||
|
|
||||||
@app.post("/suggest-command")
|
@app.post("/suggest-command")
|
||||||
async def suggest_command(request: ChatRequest):
|
async def suggest_command(request: ChatRequest):
|
||||||
"""Get AI-suggested security commands based on context"""
|
"""Get AI-suggested security commands based on context - uses default LLM preferences if not specified"""
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
@@ -1083,8 +1096,8 @@ Only suggest commands for legitimate security testing purposes."""
|
|||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{LLM_ROUTER_URL}/chat",
|
f"{LLM_ROUTER_URL}/chat",
|
||||||
json={
|
json={
|
||||||
"provider": request.provider,
|
"provider": request.provider or llm_preferences["provider"],
|
||||||
"model": request.model,
|
"model": request.model or llm_preferences["model"],
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"temperature": 0.3,
|
"temperature": 0.3,
|
||||||
"max_tokens": 1024
|
"max_tokens": 1024
|
||||||
@@ -1231,18 +1244,18 @@ async def llm_chat_help(
|
|||||||
message: str,
|
message: str,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
provider: str = "ollama",
|
provider: Optional[str] = None,
|
||||||
model: str = "llama3.2"
|
model: Optional[str] = None
|
||||||
):
|
):
|
||||||
"""LLM-powered chat help"""
|
"""LLM-powered chat help - uses default preferences if provider/model not specified"""
|
||||||
try:
|
try:
|
||||||
from . import llm_help
|
from . import llm_help
|
||||||
result = await llm_help.chat_completion(
|
result = await llm_help.chat_completion(
|
||||||
message=message,
|
message=message,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
context=context,
|
context=context,
|
||||||
provider=provider,
|
provider=provider or llm_preferences["provider"],
|
||||||
model=model
|
model=model or llm_preferences["model"]
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1398,6 +1411,80 @@ async def send_push_notification(
|
|||||||
raise HTTPException(status_code=500, detail=f"Push notification error: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Push notification error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# ============== LLM Preferences ==============
|
||||||
|
|
||||||
|
@app.get("/api/llm/preferences")
async def get_llm_preferences():
    """
    Report the active default LLM provider/model and the router's providers.

    The provider list is requested from ``{LLM_ROUTER_URL}/providers``. A
    non-200 reply yields an empty list. Any exception raised while fetching or
    decoding the reply is reported in an ``error`` field rather than failing
    the request.

    Returns:
        On success, a dictionary with ``current``, ``available_providers`` and
        ``description``. If the lookup raised, a dictionary with ``current``,
        an empty ``available_providers`` and ``error``.
    """
    router_providers = []
    lookup_error = None

    try:
        async with httpx.AsyncClient() as http:
            reply = await http.get(f"{LLM_ROUTER_URL}/providers", timeout=10.0)
            if reply.status_code == 200:
                router_providers = reply.json()
    except Exception as exc:
        # Best effort: the current preferences are still returned when the
        # router cannot be queried.
        lookup_error = str(exc)

    # Read after the router call so the reply reflects the preferences in
    # effect at the time the response is built.
    active = {
        "provider": llm_preferences["provider"],
        "model": llm_preferences["model"],
    }

    if lookup_error is not None:
        return {
            "current": active,
            "available_providers": [],
            "error": lookup_error,
        }

    return {
        "current": active,
        "available_providers": router_providers,
        "description": "Current default LLM provider and model. These are used when no explicit provider/model is specified in API requests.",
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_names(payload: Any) -> Optional[List[str]]:
    """Extract provider names from the router's ``/providers`` payload.

    Accepts a list whose entries are dicts carrying a string ``"name"`` key or
    plain strings. Returns ``None`` when the payload is not a list, or is a
    non-empty list with no recognisable entry, so the caller can skip
    validation instead of failing on an unexpected shape.
    """
    if not isinstance(payload, list):
        return None

    names: List[str] = []
    for entry in payload:
        if isinstance(entry, dict) and isinstance(entry.get("name"), str):
            names.append(entry["name"])
        elif isinstance(entry, str):
            names.append(entry)

    if payload and not names:
        return None
    return names


@app.post("/api/llm/preferences")
async def set_llm_preferences(request: LLMPreferencesRequest):
    """
    Set default LLM provider and model preferences.

    The provider is checked against the list served by
    ``{LLM_ROUTER_URL}/providers`` on a best-effort basis. If the router cannot
    be reached, times out, answers with a non-200 status, or returns a body
    that cannot be interpreted, the check is skipped and the preferences are
    stored anyway.

    Args:
        request: LLMPreferencesRequest with provider and model

    Returns:
        Dictionary with ``status``, the stored ``provider`` and ``model``, and
        a human-readable ``message``.

    Raises:
        HTTPException: 400 if provider or model is blank, or if the router
            returned a provider list that does not contain the provider.
    """
    # A blank value would become the fallback used by every request that omits
    # provider/model, so reject it up front.
    if not request.provider.strip() or not request.model.strip():
        raise HTTPException(
            status_code=400,
            detail="Both 'provider' and 'model' must be non-empty strings"
        )

    # Validate provider is available (None means "could not determine")
    provider_names: Optional[List[str]] = None
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{LLM_ROUTER_URL}/providers", timeout=10.0)
        if response.status_code == 200:
            provider_names = _provider_names(response.json())
    except (httpx.RequestError, ValueError):
        # httpx.RequestError covers connection failures and timeouts alike;
        # ValueError covers a body that is not valid JSON. In both cases the
        # LLM router is not usable for validation, so proceed anyway.
        provider_names = None

    if provider_names is not None and request.provider not in provider_names:
        raise HTTPException(
            status_code=400,
            detail=f"Provider '{request.provider}' not available. Available: {provider_names}"
        )

    # Update preferences
    llm_preferences["provider"] = request.provider
    llm_preferences["model"] = request.model

    return {
        "status": "updated",
        "provider": llm_preferences["provider"],
        "model": llm_preferences["model"],
        "message": f"Default LLM set to {request.provider}/{request.model}"
    }
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||||
Reference in New Issue
Block a user