mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 14:00:20 -05:00
feat: interactive network map, IOC highlighting, AUP hunt selector, type filters
- NetworkMap: hunt-scoped force-directed graph with click-to-inspect popover - NetworkMap: zoom/pan (wheel, drag, buttons), viewport transform - NetworkMap: clickable IP/Host/Domain/URL legend chips to filter node types - NetworkMap: brighter colors, 20% smaller nodes - DatasetViewer: IOC columns highlighted with colored headers + cell tinting - AUPScanner: hunt dropdown replacing dataset checkboxes, auto-select all - Rename 'Social Media (Personal)' theme to 'Social Media' with DB migration - Fix /api/hunts timeout: Dataset.rows lazy='noload' (was selectin cascade) - Add OS column mapping to normalizer - Full backend services, DB models, alembic migrations, new routes - New components: Dashboard, HuntManager, FileUpload, NetworkMap, etc. - Docker Compose deployment with nginx reverse proxy
This commit is contained in:
408
backend/app/agents/core_v2.py
Normal file
408
backend/app/agents/core_v2.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""Core ThreatHunt analyst-assist agent — v2.
|
||||
|
||||
Uses TaskRouter to select the right model/node for each query,
|
||||
real LLM providers (Ollama/OpenWebUI), and structured response parsing.
|
||||
Integrates SANS RAG context from Open WebUI.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.config import settings
|
||||
from app.services.sans_rag import sans_rag
|
||||
from .router import TaskRouter, TaskType, RoutingDecision, task_router
|
||||
from .providers_v2 import OllamaProvider, OpenWebUIProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class AgentContext(BaseModel):
    """Context for agent guidance requests.

    Bundles the analyst's query with any optional investigation state
    (dataset, host, hypotheses, prior conversation) that the agent folds
    into the LLM prompt via ThreatHuntAgent._build_prompt.
    """

    query: str = Field(..., description="Analyst question or request for guidance")
    dataset_name: Optional[str] = Field(None, description="Name of CSV dataset")
    artifact_type: Optional[str] = Field(None, description="Artifact type")
    host_identifier: Optional[str] = Field(None, description="Host name, IP, or identifier")
    data_summary: Optional[str] = Field(None, description="Brief description of data")
    conversation_history: Optional[list[dict]] = Field(
        default_factory=list, description="Previous messages"
    )
    active_hypotheses: Optional[list[str]] = Field(
        default_factory=list, description="Active investigation hypotheses"
    )
    annotations_summary: Optional[str] = Field(
        None, description="Summary of analyst annotations"
    )
    enrichment_summary: Optional[str] = Field(
        None, description="Summary of enrichment results"
    )
    # Execution mode: "quick" (single routed model), "deep" (forces the
    # heavy DEEP_ANALYSIS task type) or "debate" (multi-perspective run).
    mode: str = Field(default="quick", description="quick | deep | debate")
    model_override: Optional[str] = Field(None, description="Force a specific model")
|
||||
|
||||
|
||||
class Perspective(BaseModel):
    """A single perspective from the debate agent."""

    # Debate role that produced this view (e.g. "Planner", "Critic", "Pragmatist").
    role: str
    # Raw analysis text returned by the model for this role.
    content: str
    # Model tag and node name that served this perspective (telemetry).
    model_used: str
    node_used: str
    # LLM round-trip time for this perspective, in milliseconds.
    latency_ms: int
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
    """Response from analyst-assist agent.

    Structured form of the LLM output (see SYSTEM_PROMPT's JSON schema),
    plus routing/latency metadata filled in by ThreatHuntAgent.
    """

    guidance: str = Field(..., description="Advisory guidance for analyst")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence (0-1)")
    suggested_pivots: list[str] = Field(default_factory=list)
    suggested_filters: list[str] = Field(default_factory=list)
    caveats: Optional[str] = None
    reasoning: Optional[str] = None
    sans_references: list[str] = Field(
        default_factory=list, description="SANS course references"
    )
    # Routing metadata — populated after parsing, not by the LLM itself.
    model_used: str = Field(default="", description="Model that generated the response")
    node_used: str = Field(default="", description="Node that processed the request")
    latency_ms: int = Field(default=0, description="Total latency in ms")
    perspectives: Optional[list[Perspective]] = Field(
        None, description="Debate perspectives (only in debate mode)"
    )
|
||||
|
||||
|
||||
# ── System prompt ─────────────────────────────────────────────────────
|
||||
|
||||
SYSTEM_PROMPT = """You are an analyst-assist agent for ThreatHunt, a threat hunting platform.
|
||||
You have access to 300GB of SANS cybersecurity course material for reference.
|
||||
|
||||
Your role:
|
||||
- Interpret and explain CSV artifact data from Velociraptor and other forensic tools
|
||||
- Suggest analytical pivots, filters, and hypotheses
|
||||
- Highlight anomalies, patterns, or points of interest
|
||||
- Reference relevant SANS methodologies and techniques when applicable
|
||||
- Guide analysts without replacing their judgment
|
||||
|
||||
Your constraints:
|
||||
- You ONLY provide guidance and suggestions
|
||||
- You do NOT execute actions or tools
|
||||
- You do NOT modify data or escalate alerts
|
||||
- You explain your reasoning transparently
|
||||
|
||||
RESPONSE FORMAT — you MUST respond with valid JSON:
|
||||
{
|
||||
"guidance": "Your main guidance text here",
|
||||
"confidence": 0.85,
|
||||
"suggested_pivots": ["Pivot 1", "Pivot 2"],
|
||||
"suggested_filters": ["filter expression 1", "filter expression 2"],
|
||||
"caveats": "Any assumptions or limitations",
|
||||
"reasoning": "How you arrived at this guidance",
|
||||
"sans_references": ["SANS SEC504: ...", "SANS FOR508: ..."]
|
||||
}
|
||||
|
||||
Respond ONLY with the JSON object. No markdown, no code fences, no extra text."""
|
||||
|
||||
|
||||
# ── Agent ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ThreatHuntAgent:
    """Analyst-assist agent backed by Wile + Roadrunner LLM cluster.

    Advisory only: routes each query to an appropriate model/node via
    TaskRouter, optionally enriches the prompt with SANS RAG context,
    calls the selected provider, and parses the output into a structured
    AgentResponse. It never executes actions or modifies data.
    """

    def __init__(self, router: TaskRouter | None = None):
        # Fall back to the module-level singleton router.
        self.router = router or task_router
        self.system_prompt = SYSTEM_PROMPT

    async def assist(self, context: AgentContext) -> AgentResponse:
        """Provide guidance on artifact data and analysis.

        Dispatches to multi-model debate when ``context.mode == "debate"``;
        otherwise performs a single routed LLM call ("deep" forces the
        DEEP_ANALYSIS task type).
        """
        start = time.monotonic()

        if context.mode == "debate":
            return await self._debate_assist(context)

        # Classify task and route
        task_type = self.router.classify_task(context.query)
        if context.mode == "deep":
            task_type = TaskType.DEEP_ANALYSIS

        decision = self.router.route(task_type, model_override=context.model_override)
        logger.info(f"Routing: {decision.reason}")

        # Enrich prompt with SANS RAG context. Enrichment is best-effort:
        # any failure is logged and the un-enriched prompt is used.
        prompt = self._build_prompt(context)
        try:
            rag_context = await sans_rag.enrich_prompt(
                context.query,
                investigation_context=context.data_summary or "",
            )
            if rag_context:
                prompt = f"{prompt}\n\n{rag_context}"
        except Exception as e:
            logger.warning(f"SANS RAG enrichment failed: {e}")

        # Call LLM. The cluster provider speaks the chat-message API;
        # direct Ollama providers take a prompt + system string.
        provider = self.router.get_provider(decision)
        if isinstance(provider, OpenWebUIProvider):
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ]
            result = await provider.chat(
                messages,
                max_tokens=settings.AGENT_MAX_TOKENS,
                temperature=settings.AGENT_TEMPERATURE,
            )
        else:
            result = await provider.generate(
                prompt,
                system=self.system_prompt,
                max_tokens=settings.AGENT_MAX_TOKENS,
                temperature=settings.AGENT_TEMPERATURE,
            )

        raw_text = result.get("response", "")
        latency_ms = result.get("_latency_ms", 0)

        # Parse structured response and attach routing/latency metadata.
        response = self._parse_response(raw_text, context)
        response.model_used = decision.model
        response.node_used = decision.node.value
        response.latency_ms = latency_ms

        total_ms = int((time.monotonic() - start) * 1000)
        logger.info(
            f"Agent assist: {context.query[:60]}... → "
            f"{decision.model} on {decision.node.value} "
            f"({total_ms}ms total, {latency_ms}ms LLM)"
        )

        return response

    async def assist_stream(
        self,
        context: AgentContext,
    ) -> AsyncIterator[str]:
        """Stream agent response tokens.

        NOTE(review): if the routed provider is neither an OllamaProvider
        nor an OpenWebUIProvider this yields nothing — currently those are
        the only provider types the router returns.
        """
        task_type = self.router.classify_task(context.query)
        decision = self.router.route(task_type, model_override=context.model_override)
        prompt = self._build_prompt(context)

        provider = self.router.get_provider(decision)
        if isinstance(provider, OllamaProvider):
            async for token in provider.generate_stream(
                prompt,
                system=self.system_prompt,
                max_tokens=settings.AGENT_MAX_TOKENS,
                temperature=settings.AGENT_TEMPERATURE,
            ):
                yield token
        elif isinstance(provider, OpenWebUIProvider):
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ]
            async for token in provider.chat_stream(
                messages,
                max_tokens=settings.AGENT_MAX_TOKENS,
                temperature=settings.AGENT_TEMPERATURE,
            ):
                yield token

    async def _debate_assist(self, context: AgentContext) -> AgentResponse:
        """Multi-perspective analysis using diverse models on Wile.

        Runs Planner/Critic/Pragmatist perspectives in parallel, then has
        a Judge model merge them into one structured advisory answer.
        """
        import asyncio  # local import: only needed in debate mode

        start = time.monotonic()
        prompt = self._build_prompt(context)

        # Route each perspective to a different heavy model
        roles = {
            TaskType.DEBATE_PLANNER: (
                "Planner",
                "You are the Planner for a threat hunting investigation.\n"
                "Provide a structured investigation strategy. Reference SANS methodologies.\n"
                "Focus on: investigation steps, data sources to examine, MITRE ATT&CK mapping.\n"
                "Be specific to the data context provided.\n\n",
            ),
            TaskType.DEBATE_CRITIC: (
                "Critic",
                "You are the Critic for a threat hunting investigation.\n"
                "Identify risks, false positive scenarios, missing evidence, and assumptions.\n"
                "Reference SANS training on common analyst mistakes.\n"
                "Challenge the obvious interpretation.\n\n",
            ),
            TaskType.DEBATE_PRAGMATIST: (
                "Pragmatist",
                "You are the Pragmatist for a threat hunting investigation.\n"
                "Suggest the most actionable, efficient next steps.\n"
                "Reference SANS incident response playbooks.\n"
                "Focus on: quick wins, triage priorities, what to escalate.\n\n",
            ),
        }

        async def _call_perspective(task_type: TaskType, role_name: str, prefix: str):
            """Route one debate role to a model and wrap its output."""
            decision = self.router.route(task_type)
            provider = self.router.get_provider(decision)
            full_prompt = prefix + prompt

            # Both provider types expose the same generate() signature
            # (OpenWebUIProvider.generate wraps its chat endpoint), so no
            # isinstance branching is needed — the previous version had
            # two byte-identical branches here.
            result = await provider.generate(
                full_prompt,
                system=f"You are the {role_name}. Provide analysis only. No execution.",
                max_tokens=settings.AGENT_MAX_TOKENS,
                temperature=0.4,
            )

            return Perspective(
                role=role_name,
                content=result.get("response", ""),
                model_used=decision.model,
                node_used=decision.node.value,
                latency_ms=result.get("_latency_ms", 0),
            )

        # Run perspectives in parallel
        perspective_tasks = [
            _call_perspective(tt, name, prefix)
            for tt, (name, prefix) in roles.items()
        ]
        perspectives = await asyncio.gather(*perspective_tasks)

        # Judge merges the perspectives
        judge_prompt = (
            "You are the Judge. Merge these three threat hunting perspectives into "
            "ONE final advisory answer.\n\n"
            "Rules:\n"
            "- Advisory only — no execution\n"
            "- Clearly list risks and assumptions\n"
            "- Highlight where perspectives agree and disagree\n"
            "- Provide a unified recommendation\n"
            "- Reference SANS methodologies where relevant\n\n"
        )
        for p in perspectives:
            judge_prompt += f"=== {p.role} (via {p.model_used}) ===\n{p.content}\n\n"

        judge_prompt += (
            f"\nOriginal analyst query:\n{context.query}\n\n"
            "Respond with the merged analysis in this JSON format:\n"
            '{"guidance": "...", "confidence": 0.85, "suggested_pivots": [...], '
            '"suggested_filters": [...], "caveats": "...", "reasoning": "...", '
            '"sans_references": [...]}'
        )

        judge_decision = self.router.route(TaskType.DEBATE_JUDGE)
        judge_provider = self.router.get_provider(judge_decision)

        # Same collapse as in _call_perspective: generate() is provider-
        # agnostic, so the former identical isinstance branches are gone.
        judge_result = await judge_provider.generate(
            judge_prompt,
            system="You are the Judge. Merge perspectives into a final advisory answer. Respond with JSON only.",
            max_tokens=settings.AGENT_MAX_TOKENS,
            temperature=0.2,
        )

        raw_text = judge_result.get("response", "")
        response = self._parse_response(raw_text, context)
        response.model_used = judge_decision.model
        response.node_used = judge_decision.node.value
        response.latency_ms = int((time.monotonic() - start) * 1000)
        response.perspectives = list(perspectives)

        return response

    def _build_prompt(self, context: AgentContext) -> str:
        """Build the prompt with all available context.

        Only non-empty context fields are included; conversation history
        is truncated to the configured window and 500 chars per message.
        """
        parts = [f"Analyst query: {context.query}"]

        if context.dataset_name:
            parts.append(f"Dataset: {context.dataset_name}")
        if context.artifact_type:
            parts.append(f"Artifact type: {context.artifact_type}")
        if context.host_identifier:
            parts.append(f"Host: {context.host_identifier}")
        if context.data_summary:
            parts.append(f"Data summary: {context.data_summary}")
        if context.active_hypotheses:
            parts.append(f"Active hypotheses: {'; '.join(context.active_hypotheses)}")
        if context.annotations_summary:
            parts.append(f"Analyst annotations: {context.annotations_summary}")
        if context.enrichment_summary:
            parts.append(f"Enrichment data: {context.enrichment_summary}")
        if context.conversation_history:
            parts.append("\nRecent conversation:")
            for msg in context.conversation_history[-settings.AGENT_HISTORY_LENGTH:]:
                parts.append(f"  {msg.get('role', 'unknown')}: {msg.get('content', '')[:500]}")

        return "\n".join(parts)

    def _parse_response(self, raw: str, context: AgentContext) -> AgentResponse:
        """Parse LLM output into structured AgentResponse.

        Tries JSON extraction first, falls back to raw text with defaults.
        """
        parsed = self._try_parse_json(raw)
        if parsed:
            try:
                confidence = min(max(float(parsed.get("confidence", 0.7)), 0.0), 1.0)
            except (TypeError, ValueError):
                # LLMs occasionally emit non-numeric confidence ("high");
                # use a neutral default instead of crashing the request.
                confidence = 0.7
            return AgentResponse(
                guidance=parsed.get("guidance", raw),
                confidence=confidence,
                suggested_pivots=parsed.get("suggested_pivots", [])[:6],
                suggested_filters=parsed.get("suggested_filters", [])[:6],
                caveats=parsed.get("caveats"),
                reasoning=parsed.get("reasoning"),
                sans_references=parsed.get("sans_references", []),
            )

        # Fallback: use raw text as guidance
        return AgentResponse(
            guidance=raw.strip() or "No guidance generated. Please try rephrasing your question.",
            confidence=0.5,
            suggested_pivots=[],
            suggested_filters=[],
            caveats="Response was not in structured format. Pivots and filters may be embedded in the guidance text.",
            reasoning=None,
            sans_references=[],
        )

    def _try_parse_json(self, text: str) -> dict | None:
        """Try to extract JSON from LLM output.

        Order: direct parse, then fenced ```json / ``` blocks, then the
        first brace-balanced object anywhere in the text (one nesting
        level deep).
        """
        # Direct parse
        try:
            return json.loads(text.strip())
        except json.JSONDecodeError:
            pass

        # Extract from code fences
        patterns = [
            r"```json\s*(.*?)\s*```",
            r"```\s*(.*?)\s*```",
            r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}",
        ]
        for pattern in patterns:
            match = re.search(pattern, text, re.DOTALL)
            if match:
                try:
                    # The bare-object pattern has no capture group, so
                    # fall back to group(0) when lastindex is None.
                    return json.loads(match.group(1) if match.lastindex else match.group(0))
                except (json.JSONDecodeError, IndexError):
                    continue

        return None
|
||||
362
backend/app/agents/providers_v2.py
Normal file
362
backend/app/agents/providers_v2.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""LLM providers — real implementations for Ollama nodes and Open WebUI cluster.
|
||||
|
||||
Three providers:
|
||||
- OllamaProvider: Direct calls to Ollama on Wile/Roadrunner via Tailscale
|
||||
- OpenWebUIProvider: Calls to the Open WebUI cluster (OpenAI-compatible)
|
||||
- EmbeddingProvider: Embedding generation via Ollama /api/embeddings
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import AsyncIterator
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import settings
|
||||
from .registry import ModelEntry, Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Shared HTTP client with reasonable timeouts
|
||||
_client: httpx.AsyncClient | None = None
|
||||
|
||||
|
||||
def _get_client() -> httpx.AsyncClient:
    """Return the shared AsyncClient, creating it lazily when absent or closed."""
    global _client
    needs_new = _client is None or _client.is_closed
    if needs_new:
        # Long read timeout: heavy models can take minutes to respond.
        _client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=10, read=300, write=30, pool=10),
            limits=httpx.Limits(max_connections=20, max_keepalive_connections=10),
        )
    return _client
|
||||
|
||||
|
||||
async def cleanup_client():
    """Close and forget the shared HTTP client, if one is currently open."""
    global _client
    if _client is not None and not _client.is_closed:
        await _client.aclose()
        _client = None
|
||||
|
||||
|
||||
def _ollama_url(node: Node) -> str:
    """Get the Ollama base URL for a node.

    Raises ValueError for nodes without a direct Ollama endpoint
    (e.g. the Open WebUI cluster).
    """
    if node == Node.WILE:
        return settings.wile_url
    if node == Node.ROADRUNNER:
        return settings.roadrunner_url
    raise ValueError(f"No direct Ollama URL for node: {node}")
|
||||
|
||||
|
||||
# ── Ollama Provider ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class OllamaProvider:
    """Direct Ollama API calls to Wile or Roadrunner.

    Thin async wrapper over a single node's native Ollama HTTP API
    (/api/generate, /api/chat). Non-stream results are the raw Ollama
    response dict augmented with sidecar keys: "_latency_ms" (measured
    round-trip) and "_node" (which node served the call).
    """

    def __init__(self, model: str, node: Node):
        # model: Ollama tag (e.g. "llama3.1:latest"); may be "" when the
        # instance is only used for availability probes.
        self.model = model
        self.node = node
        # Resolved once; raises ValueError for nodes without a direct URL.
        self.base_url = _ollama_url(node)

    async def generate(
        self,
        prompt: str,
        system: str = "",
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Generate a completion. Returns dict with 'response', 'model', 'total_duration', etc."""
        client = _get_client()
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            # Ollama option names: num_predict caps the output token count.
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        }
        if system:
            payload["system"] = system

        start = time.monotonic()
        try:
            resp = await client.post(
                f"{self.base_url}/api/generate",
                json=payload,
            )
            resp.raise_for_status()
            data = resp.json()
            latency_ms = int((time.monotonic() - start) * 1000)
            # Sidecar telemetry consumed by the agent layer.
            data["_latency_ms"] = latency_ms
            data["_node"] = self.node.value
            logger.info(
                f"Ollama [{self.node.value}] {self.model}: "
                f"{latency_ms}ms, {data.get('eval_count', '?')} tokens"
            )
            return data
        except httpx.HTTPStatusError as e:
            # Log a truncated response body for diagnosis, then propagate.
            logger.error(f"Ollama HTTP error [{self.node.value}]: {e.response.status_code} {e.response.text[:200]}")
            raise
        except httpx.ConnectError as e:
            logger.error(f"Cannot reach Ollama on {self.node.value} ({self.base_url}): {e}")
            raise

    async def chat(
        self,
        messages: list[dict],
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Chat completion via Ollama /api/chat.

        messages: OpenAI-style [{"role": ..., "content": ...}, ...].
        Returns the raw Ollama response plus the same "_latency_ms"/"_node"
        sidecar keys as generate().
        """
        client = _get_client()
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": False,
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        }

        start = time.monotonic()
        resp = await client.post(f"{self.base_url}/api/chat", json=payload)
        resp.raise_for_status()
        data = resp.json()
        data["_latency_ms"] = int((time.monotonic() - start) * 1000)
        data["_node"] = self.node.value
        return data

    async def generate_stream(
        self,
        prompt: str,
        system: str = "",
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> AsyncIterator[str]:
        """Stream tokens from Ollama.

        Yields each non-empty "response" fragment from the newline-delimited
        JSON stream; stops when a chunk carries "done": true.
        """
        client = _get_client()
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": True,
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        }
        if system:
            payload["system"] = system

        async with client.stream(
            "POST", f"{self.base_url}/api/generate", json=payload
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if line.strip():
                    try:
                        chunk = json.loads(line)
                        token = chunk.get("response", "")
                        if token:
                            yield token
                        if chunk.get("done"):
                            break
                    except json.JSONDecodeError:
                        # Skip malformed stream lines rather than aborting.
                        continue

    async def is_available(self) -> bool:
        """Ping the Ollama node.

        Uses the cheap /api/tags listing with a short timeout; any error
        (network, HTTP, etc.) is treated as "unavailable".
        """
        try:
            client = _get_client()
            resp = await client.get(f"{self.base_url}/api/tags", timeout=5)
            return resp.status_code == 200
        except Exception:
            return False
|
||||
|
||||
|
||||
# ── Open WebUI Provider (OpenAI-compatible) ───────────────────────────
|
||||
|
||||
|
||||
class OpenWebUIProvider:
    """Calls to Open WebUI cluster at ai.guapo613.beer.

    Uses the OpenAI-compatible /v1/chat/completions endpoint.
    Responses are normalized to the same dict shape OllamaProvider
    returns ("response", "_latency_ms", "_node", plus "_usage").
    """

    def __init__(self, model: str = ""):
        # Empty model falls back to the configured fast default.
        self.model = model or settings.DEFAULT_FAST_MODEL
        self.base_url = settings.OPENWEBUI_URL.rstrip("/")
        self.api_key = settings.OPENWEBUI_API_KEY

    def _headers(self) -> dict:
        """Request headers; bearer auth only when an API key is configured."""
        h = {"Content-Type": "application/json"}
        if self.api_key:
            h["Authorization"] = f"Bearer {self.api_key}"
        return h

    async def chat(
        self,
        messages: list[dict],
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Chat completion via OpenAI-compatible endpoint."""
        client = _get_client()
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": False,
        }

        start = time.monotonic()
        resp = await client.post(
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            headers=self._headers(),
        )
        resp.raise_for_status()
        data = resp.json()
        latency_ms = int((time.monotonic() - start) * 1000)

        # Normalize to our format (first choice's message content; empty
        # string when the API returns no choices).
        content = ""
        if data.get("choices"):
            content = data["choices"][0].get("message", {}).get("content", "")

        result = {
            "response": content,
            "model": data.get("model", self.model),
            "_latency_ms": latency_ms,
            "_node": "cluster",
            "_usage": data.get("usage", {}),
        }
        logger.info(
            f"OpenWebUI cluster {self.model}: {latency_ms}ms"
        )
        return result

    async def generate(
        self,
        prompt: str,
        system: str = "",
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Convert prompt-style call to chat format.

        Gives this provider the same generate() signature as
        OllamaProvider, so callers can use either interchangeably.
        """
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        return await self.chat(messages, max_tokens, temperature)

    async def chat_stream(
        self,
        messages: list[dict],
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> AsyncIterator[str]:
        """Stream tokens from OpenWebUI.

        Parses the OpenAI-style SSE stream: "data: {...}" lines with
        incremental delta content, terminated by "data: [DONE]".
        """
        client = _get_client()
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": True,
        }

        async with client.stream(
            "POST",
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            headers=self._headers(),
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    data_str = line[6:].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data_str)
                        delta = chunk.get("choices", [{}])[0].get("delta", {})
                        token = delta.get("content", "")
                        if token:
                            yield token
                    except json.JSONDecodeError:
                        # Skip malformed SSE payloads rather than aborting.
                        continue

    async def is_available(self) -> bool:
        """Check if Open WebUI is reachable (cheap /v1/models probe)."""
        try:
            client = _get_client()
            resp = await client.get(
                f"{self.base_url}/v1/models",
                headers=self._headers(),
                timeout=5,
            )
            return resp.status_code == 200
        except Exception:
            return False
|
||||
|
||||
|
||||
# ── Embedding Provider ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class EmbeddingProvider:
    """Generate embeddings via Ollama /api/embeddings.

    Defaults to the configured embedding model on Roadrunner.
    """

    def __init__(self, model: str = "", node: Node = Node.ROADRUNNER):
        self.model = model if model else settings.DEFAULT_EMBEDDING_MODEL
        self.node = node
        self.base_url = _ollama_url(node)

    async def embed(self, text: str) -> list[float]:
        """Get embedding vector for a single text."""
        payload = {"model": self.model, "prompt": text}
        response = await _get_client().post(
            f"{self.base_url}/api/embeddings",
            json=payload,
        )
        response.raise_for_status()
        return response.json().get("embedding", [])

    async def embed_batch(self, texts: list[str], concurrency: int = 5) -> list[list[float]]:
        """Embed multiple texts with controlled concurrency."""
        # Semaphore caps in-flight requests so a big batch doesn't
        # overwhelm the node.
        gate = asyncio.Semaphore(concurrency)

        async def _bounded(item: str) -> list[float]:
            async with gate:
                return await self.embed(item)

        jobs = (_bounded(item) for item in texts)
        return await asyncio.gather(*jobs)
|
||||
|
||||
|
||||
# ── Health check for all nodes ────────────────────────────────────────
|
||||
|
||||
|
||||
async def check_all_nodes() -> dict:
    """Check availability of all LLM nodes.

    Returns {name: {"available": bool, "url": str}} for wile, roadrunner
    and the Open WebUI cluster.
    """
    names = ("wile", "roadrunner", "cluster")
    providers = (
        OllamaProvider("", Node.WILE),
        OllamaProvider("", Node.ROADRUNNER),
        OpenWebUIProvider(),
    )
    urls = (settings.wile_url, settings.roadrunner_url, settings.OPENWEBUI_URL)

    # return_exceptions=True: a probe that raises simply reports as
    # unavailable instead of failing the whole health check.
    statuses = await asyncio.gather(
        *(provider.is_available() for provider in providers),
        return_exceptions=True,
    )

    return {
        name: {"available": status is True, "url": url}
        for name, status, url in zip(names, statuses, urls)
    }
|
||||
161
backend/app/agents/registry.py
Normal file
161
backend/app/agents/registry.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Model registry — inventory of all Ollama models across Wile and Roadrunner.
|
||||
|
||||
Each model is tagged with capabilities (chat, code, vision, embedding) and
|
||||
performance tier (fast, medium, heavy) for the TaskRouter.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Capability(str, Enum):
    """What a model can do; used to match models to tasks when routing."""

    CHAT = "chat"
    CODE = "code"
    VISION = "vision"
    EMBEDDING = "embedding"
|
||||
|
||||
|
||||
class Tier(str, Enum):
    """Performance/size class used to trade latency against depth."""

    FAST = "fast"  # < 15B params — quick responses
    MEDIUM = "medium"  # 15–40B params — balanced
    HEAVY = "heavy"  # 40B+ params — deep analysis
|
||||
|
||||
|
||||
class Node(str, Enum):
    """Where a model runs: a specific Ollama host, or the shared cluster."""

    WILE = "wile"
    ROADRUNNER = "roadrunner"
    CLUSTER = "cluster"  # Open WebUI balances across both
|
||||
|
||||
|
||||
@dataclass
class ModelEntry:
    """One installed model on one node, tagged for capability-based routing."""

    name: str  # Ollama model tag, e.g. "llama3.1:70b-instruct-q4_K_M"
    node: Node  # host the model is installed on
    capabilities: list[Capability]  # what the model can be routed for
    tier: Tier  # speed/size class (fast/medium/heavy)
    param_size: str = ""  # e.g. "7b", "70b"
    notes: str = ""  # free-form operator notes
|
||||
|
||||
|
||||
# ── Roadrunner (100.110.190.11) ──────────────────────────────────────
|
||||
|
||||
# Positional args: ModelEntry(name, node, capabilities, tier, param_size).
ROADRUNNER_MODELS: list[ModelEntry] = [
    # General / chat
    ModelEntry("llama3.1:latest", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "8b"),
    ModelEntry("qwen2.5:14b-instruct", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "14b"),
    ModelEntry("mistral:7b-instruct", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "7b"),
    ModelEntry("mistral:7b", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "7b"),
    ModelEntry("qwen2.5:7b", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "7b"),
    ModelEntry("phi3:medium", Node.ROADRUNNER, [Capability.CHAT], Tier.MEDIUM, "14b"),
    # Code
    ModelEntry("qwen2.5-coder:7b", Node.ROADRUNNER, [Capability.CODE], Tier.FAST, "7b"),
    ModelEntry("qwen2.5-coder:latest", Node.ROADRUNNER, [Capability.CODE], Tier.FAST, "7b"),
    ModelEntry("codestral:latest", Node.ROADRUNNER, [Capability.CODE], Tier.MEDIUM, "22b"),
    ModelEntry("codellama:13b", Node.ROADRUNNER, [Capability.CODE], Tier.FAST, "13b"),
    # Vision
    ModelEntry("llama3.2-vision:11b", Node.ROADRUNNER, [Capability.VISION], Tier.FAST, "11b"),
    ModelEntry("minicpm-v:latest", Node.ROADRUNNER, [Capability.VISION], Tier.FAST, "8b"),
    ModelEntry("llava:13b", Node.ROADRUNNER, [Capability.VISION], Tier.FAST, "13b"),
    # Embeddings
    ModelEntry("bge-m3:latest", Node.ROADRUNNER, [Capability.EMBEDDING], Tier.FAST, "0.6b"),
    ModelEntry("nomic-embed-text:latest", Node.ROADRUNNER, [Capability.EMBEDDING], Tier.FAST, "0.1b"),
    # Heavy
    ModelEntry("llama3.1:70b-instruct-q4_K_M", Node.ROADRUNNER, [Capability.CHAT], Tier.HEAVY, "70b"),
]
|
||||
|
||||
# ── Wile (100.110.190.12) ────────────────────────────────────────────
|
||||
|
||||
# Positional args: ModelEntry(name, node, capabilities, tier, param_size).
WILE_MODELS: list[ModelEntry] = [
    # General / chat
    ModelEntry("llama3.1:latest", Node.WILE, [Capability.CHAT], Tier.FAST, "8b"),
    ModelEntry("llama3:latest", Node.WILE, [Capability.CHAT], Tier.FAST, "8b"),
    ModelEntry("gemma2:27b", Node.WILE, [Capability.CHAT], Tier.MEDIUM, "27b"),
    # Code
    ModelEntry("qwen2.5-coder:7b", Node.WILE, [Capability.CODE], Tier.FAST, "7b"),
    ModelEntry("qwen2.5-coder:latest", Node.WILE, [Capability.CODE], Tier.FAST, "7b"),
    ModelEntry("qwen2.5-coder:32b", Node.WILE, [Capability.CODE], Tier.MEDIUM, "32b"),
    ModelEntry("deepseek-coder:33b", Node.WILE, [Capability.CODE], Tier.MEDIUM, "33b"),
    ModelEntry("codestral:latest", Node.WILE, [Capability.CODE], Tier.MEDIUM, "22b"),
    # Vision
    ModelEntry("llava:13b", Node.WILE, [Capability.VISION], Tier.FAST, "13b"),
    # Embeddings
    ModelEntry("bge-m3:latest", Node.WILE, [Capability.EMBEDDING], Tier.FAST, "0.6b"),
    # Heavy
    ModelEntry("llama3.1:70b", Node.WILE, [Capability.CHAT], Tier.HEAVY, "70b"),
    ModelEntry("llama3.1:70b-instruct-q4_K_M", Node.WILE, [Capability.CHAT], Tier.HEAVY, "70b"),
    ModelEntry("llama3.1:70b-instruct-q5_K_M", Node.WILE, [Capability.CHAT], Tier.HEAVY, "70b"),
    ModelEntry("mixtral:8x22b-instruct", Node.WILE, [Capability.CHAT], Tier.HEAVY, "141b"),
    ModelEntry("qwen2:72b-instruct", Node.WILE, [Capability.CHAT], Tier.HEAVY, "72b"),
]
|
||||
|
||||
# Combined inventory across both nodes; ModelRegistry indexes this by default.
ALL_MODELS = ROADRUNNER_MODELS + WILE_MODELS
|
||||
|
||||
|
||||
class ModelRegistry:
    """Registry of all available models and their capabilities.

    Indexes the model catalog by name, capability, and node for fast
    lookup, and offers preference-aware selection via :meth:`get_best`.
    """

    def __init__(self, models: list[ModelEntry] | None = None):
        """Build the registry.

        Args:
            models: Explicit catalog; defaults to the combined
                Roadrunner + Wile ``ALL_MODELS`` list when falsy.
        """
        self.models = models or ALL_MODELS
        self._by_name: dict[str, list[ModelEntry]] = {}
        self._by_capability: dict[Capability, list[ModelEntry]] = {}
        self._by_node: dict[Node, list[ModelEntry]] = {}
        self._index()

    def _index(self) -> None:
        """Populate the secondary lookup indexes from ``self.models``."""
        for m in self.models:
            # Same model name may appear on multiple nodes, hence lists.
            self._by_name.setdefault(m.name, []).append(m)
            for cap in m.capabilities:
                self._by_capability.setdefault(cap, []).append(m)
            self._by_node.setdefault(m.node, []).append(m)

    def find(
        self,
        capability: Capability | None = None,
        tier: Tier | None = None,
        node: Node | None = None,
    ) -> list[ModelEntry]:
        """Find models matching all given criteria (``None`` = wildcard).

        Returns entries in catalog order; with no criteria, a copy of the
        whole catalog.
        """
        results = list(self.models)
        if capability:
            results = [m for m in results if capability in m.capabilities]
        if tier:
            results = [m for m in results if m.tier == tier]
        if node:
            results = [m for m in results if m.node == node]
        return results

    def get_best(
        self,
        capability: Capability,
        prefer_tier: Tier | None = None,
        prefer_node: Node | None = None,
    ) -> ModelEntry | None:
        """Get the best model for a capability, relaxing preferences in order.

        Relaxation order: tier+node → tier only → node only → capability
        only.  The node-only step ensures the node preference is still
        honored when the preferred tier cannot be satisfied at all
        (previously it was dropped together with the tier).  Returns the
        first match in catalog order, or ``None`` when the capability has
        no models at all.
        """
        candidates = self.find(capability=capability, tier=prefer_tier, node=prefer_node)
        if not candidates:
            candidates = self.find(capability=capability, tier=prefer_tier)
        if not candidates and prefer_node:
            # New fallback: keep the node preference even without the tier.
            candidates = self.find(capability=capability, node=prefer_node)
        if not candidates:
            candidates = self.find(capability=capability)
        return candidates[0] if candidates else None

    def list_nodes(self) -> list[Node]:
        """Return every node that has at least one registered model."""
        return list(self._by_node.keys())

    def list_models_on_node(self, node: Node) -> list[ModelEntry]:
        """Return all models hosted on *node* (empty list if unknown)."""
        return self._by_node.get(node, [])

    def to_dict(self) -> list[dict]:
        """Serialize the catalog to plain dicts (e.g. for API responses)."""
        return [
            {
                "name": m.name,
                "node": m.node.value,
                "capabilities": [c.value for c in m.capabilities],
                "tier": m.tier.value,
                "param_size": m.param_size,
            }
            for m in self.models
        ]
|
||||
|
||||
|
||||
# Singleton
# Module-level shared instance; import and reuse this rather than
# constructing a new ModelRegistry per caller.
registry = ModelRegistry()
|
||||
183
backend/app/agents/router.py
Normal file
183
backend/app/agents/router.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Task router — auto-selects the right model + node for each task type.
|
||||
|
||||
Routes based on task characteristics:
|
||||
- Quick chat → fast models via cluster
|
||||
- Deep analysis → 70B+ models on Wile
|
||||
- Code/script analysis → code models (32b on Wile, 7b for quick)
|
||||
- Vision/image → vision models on Roadrunner
|
||||
- Embedding → embedding models on either node
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from app.config import settings
|
||||
from .registry import Capability, Tier, Node, ModelEntry, registry
|
||||
from .providers_v2 import OllamaProvider, OpenWebUIProvider, EmbeddingProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskType(str, Enum):
    """Categories of work the router knows how to place on a model/node."""

    QUICK_CHAT = "quick_chat"        # short interactive Q&A → fast chat models
    DEEP_ANALYSIS = "deep_analysis"  # thorough investigation → heavy 70B+ models
    CODE_ANALYSIS = "code_analysis"  # script/deobfuscation work → code models
    VISION = "vision"                # image input → vision-capable models
    EMBEDDING = "embedding"          # vector embedding generation
    # Multi-agent debate roles; each is pinned to a distinct model
    # (see TaskRouter.DEBATE_MODEL_OVERRIDES) for diversity of thought.
    DEBATE_PLANNER = "debate_planner"
    DEBATE_CRITIC = "debate_critic"
    DEBATE_PRAGMATIST = "debate_pragmatist"
    DEBATE_JUDGE = "debate_judge"
|
||||
|
||||
|
||||
@dataclass
class RoutingDecision:
    """Result of the routing decision."""

    model: str           # concrete model name, e.g. "llama3.1:70b-instruct-q4_K_M"
    node: Node           # node the request should be dispatched to
    task_type: TaskType  # task type that produced this decision
    provider_type: str  # "ollama" or "openwebui"
    reason: str          # human-readable explanation of why this route was chosen
|
||||
|
||||
|
||||
class TaskRouter:
    """Routes tasks to the appropriate model and node.

    Resolution order in :meth:`route`:
      1. explicit ``model_override`` (registry entry → its node, else cluster)
      2. per-role debate model overrides (``DEBATE_MODEL_OVERRIDES``)
      3. capability/tier/node rules (``ROUTING_RULES``)
      4. cluster fallback using ``settings.DEFAULT_FAST_MODEL``
    """

    # Default routing rules: task_type → (capability, preferred_tier, preferred_node)
    ROUTING_RULES: dict[TaskType, tuple[Capability, Tier | None, Node | None]] = {
        TaskType.QUICK_CHAT: (Capability.CHAT, Tier.FAST, None),
        TaskType.DEEP_ANALYSIS: (Capability.CHAT, Tier.HEAVY, Node.WILE),
        TaskType.CODE_ANALYSIS: (Capability.CODE, Tier.MEDIUM, Node.WILE),
        TaskType.VISION: (Capability.VISION, None, Node.ROADRUNNER),
        TaskType.EMBEDDING: (Capability.EMBEDDING, Tier.FAST, None),
        TaskType.DEBATE_PLANNER: (Capability.CHAT, Tier.HEAVY, Node.WILE),
        TaskType.DEBATE_CRITIC: (Capability.CHAT, Tier.HEAVY, Node.WILE),
        TaskType.DEBATE_PRAGMATIST: (Capability.CHAT, Tier.HEAVY, Node.WILE),
        TaskType.DEBATE_JUDGE: (Capability.CHAT, Tier.MEDIUM, Node.WILE),
    }

    # Specific model overrides for debate roles (use diverse models for diversity of thought)
    DEBATE_MODEL_OVERRIDES: dict[TaskType, str] = {
        TaskType.DEBATE_PLANNER: "llama3.1:70b-instruct-q4_K_M",
        TaskType.DEBATE_CRITIC: "qwen2:72b-instruct",
        TaskType.DEBATE_PRAGMATIST: "mixtral:8x22b-instruct",
        TaskType.DEBATE_JUDGE: "gemma2:27b",
    }

    def __init__(self):
        # Shared module-level model registry (see registry.py).
        self.registry = registry

    def _find_entry(self, name: str) -> ModelEntry | None:
        """Return the first registry entry whose model name matches, else None.

        Factored out of route(): the explicit-override and debate-override
        paths previously duplicated this linear scan.
        """
        for entry in self.registry.find():
            if entry.name == name:
                return entry
        return None

    def route(self, task_type: TaskType, model_override: str | None = None) -> RoutingDecision:
        """Decide which model and node to use for a task.

        Args:
            task_type: The kind of work being routed.
            model_override: Optional exact model name requested by the
                caller; it takes precedence over all routing rules.

        Returns:
            A RoutingDecision naming the model, node, provider type, and a
            human-readable reason for the choice.
        """
        # 1. Explicit model override takes precedence over everything.
        if model_override:
            entry = self._find_entry(model_override)
            if entry is not None:
                return RoutingDecision(
                    model=model_override,
                    node=entry.node,
                    task_type=task_type,
                    provider_type="ollama",
                    reason=f"Explicit model override: {model_override}",
                )
            # Model not in registry — try via cluster
            return RoutingDecision(
                model=model_override,
                node=Node.CLUSTER,
                task_type=task_type,
                provider_type="openwebui",
                reason=f"Override model {model_override} not in registry, routing to cluster",
            )

        # 2. Debate roles pin specific models; fall through to standard
        #    routing if the pinned model is not in the registry.
        if task_type in self.DEBATE_MODEL_OVERRIDES:
            model_name = self.DEBATE_MODEL_OVERRIDES[task_type]
            entry = self._find_entry(model_name)
            if entry is not None:
                return RoutingDecision(
                    model=model_name,
                    node=entry.node,
                    task_type=task_type,
                    provider_type="ollama",
                    reason=f"Debate role {task_type.value} → {model_name} on {entry.node.value}",
                )

        # 3. Standard capability/tier/node routing.
        cap, tier, node = self.ROUTING_RULES.get(
            task_type,
            (Capability.CHAT, Tier.FAST, None),
        )

        entry = self.registry.get_best(cap, prefer_tier=tier, prefer_node=node)
        if entry:
            return RoutingDecision(
                model=entry.name,
                node=entry.node,
                task_type=task_type,
                provider_type="ollama",
                reason=f"Auto-routed {task_type.value}: {cap.value}/{tier.value if tier else 'any'} → {entry.name} on {entry.node.value}",
            )

        # 4. Fallback to cluster
        default_model = settings.DEFAULT_FAST_MODEL
        return RoutingDecision(
            model=default_model,
            node=Node.CLUSTER,
            task_type=task_type,
            provider_type="openwebui",
            reason=f"No registry match, falling back to cluster with {default_model}",
        )

    def get_provider(self, decision: RoutingDecision):
        """Create the appropriate provider for a routing decision."""
        if decision.provider_type == "openwebui":
            return OpenWebUIProvider(model=decision.model)
        else:
            return OllamaProvider(model=decision.model, node=decision.node)

    def get_embedding_provider(self, model: str | None = None, node: Node | None = None) -> EmbeddingProvider:
        """Get an embedding provider (defaults: configured model, Roadrunner node)."""
        return EmbeddingProvider(
            model=model or settings.DEFAULT_EMBEDDING_MODEL,
            node=node or Node.ROADRUNNER,
        )

    def classify_task(self, query: str, has_image: bool = False) -> TaskType:
        """Heuristic classification of query into task type.

        In practice this could be enhanced by a classifier model, but
        keyword heuristics work well for routing.  Image presence wins
        outright; otherwise keyword substrings decide code vs. deep vs.
        quick chat.
        """
        if has_image:
            return TaskType.VISION

        q = query.lower()

        # Code/script indicators
        code_indicators = [
            "deobfuscate", "decode", "powershell", "script", "base64",
            "command line", "cmdline", "commandline", "obfuscated",
            "malware", "shellcode", "vbs", "vbscript", "batch",
            "python script", "code review", "reverse engineer",
        ]
        if any(ind in q for ind in code_indicators):
            return TaskType.CODE_ANALYSIS

        # Deep analysis indicators
        deep_indicators = [
            "deep analysis", "detailed", "comprehensive", "thorough",
            "investigate", "root cause", "advanced", "explain in detail",
            "full analysis", "forensic",
        ]
        if any(ind in q for ind in deep_indicators):
            return TaskType.DEEP_ANALYSIS

        return TaskType.QUICK_CHAT
|
||||
|
||||
|
||||
# Singleton
# Module-level shared instance; import and reuse this rather than
# constructing a new TaskRouter per caller.
task_router = TaskRouter()
|
||||
Reference in New Issue
Block a user