Files
ThreatHunt/backend/app/api/routes/agent_v2.py
mblanke 5a2ad8ec1c feat: Add Playbook Manager, Saved Searches, and Timeline View components
- Implemented PlaybookManager for creating and managing investigation playbooks with templates.
- Added SavedSearches component for managing bookmarked queries and recurring scans.
- Introduced TimelineView for visualizing forensic event timelines with zoomable charts.
- Enhanced backend processing with auto-queued jobs for dataset uploads and improved database concurrency.
- Updated frontend components for better user experience and performance optimizations.
- Documented changes in update log for future reference.
2026-02-23 14:23:07 -05:00

507 lines
16 KiB
Python

"""API routes for analyst-assist agent v2.
Supports quick, deep, and debate modes with streaming.
Conversations are persisted to the database.
"""
import json
import logging
import re
import time
from collections import Counter
from urllib.parse import urlparse
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.db import get_db
from app.db.models import Conversation, Message, Dataset, KeywordTheme
from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
from app.agents.providers_v2 import check_all_nodes
from app.agents.registry import registry
from app.services.sans_rag import sans_rag
from app.services.scanner import KeywordScanner
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router for the agent endpoints; every route below is registered under /api/agent.
router = APIRouter(prefix="/api/agent", tags=["agent"])
# Module-wide agent singleton; stays None until the first get_agent() call.
_agent: ThreatHuntAgent | None = None


def get_agent() -> ThreatHuntAgent:
    """Return the shared ThreatHuntAgent, constructing it on first use."""
    global _agent
    if _agent is not None:
        return _agent
    _agent = ThreatHuntAgent()
    return _agent
# Request / Response models
class AssistRequest(BaseModel):
    # Request body accepted by both /assist and /assist/stream.

    # Required analyst question, capped at 4000 characters.
    query: str = Field(..., max_length=4000, description="Analyst question")
    # Forwarded to the agent context; the policy scan also uses it as a
    # case-insensitive substring filter on dataset names.
    dataset_name: str | None = None
    # Optional context fields passed straight through to AgentContext.
    artifact_type: str | None = None
    host_identifier: str | None = None
    data_summary: str | None = None
    # Prior turns supplied by the client; None is treated as an empty list.
    conversation_history: list[dict] | None = None
    # The next three are forwarded by /assist only (the streaming route does
    # not pass them on).
    active_hypotheses: list[str] | None = None
    annotations_summary: str | None = None
    enrichment_summary: str | None = None
    # Forwarded by /assist; the streaming route always uses "quick".
    mode: str = Field(default="quick", description="quick | deep | debate")
    # Forwarded by /assist only.
    model_override: str | None = None
    # The exchange is persisted when this or hunt_id is set.
    conversation_id: str | None = Field(None, description="Persist messages to this conversation")
    # Restricts the policy scan to this hunt's datasets and is stored on
    # newly created conversations.
    hunt_id: str | None = None
    # "off" never runs the policy scan, "force" always runs it, anything else
    # falls back to the query heuristic.
    execution_preference: str = Field(default="auto", description="auto | force | off")
    # Forwarded by /assist only.
    learning_mode: bool = False
class AssistResponseModel(BaseModel):
    # Response body of POST /assist.

    # Core guidance fields, copied from the agent response or synthesized by
    # the policy-scan path.
    guidance: str
    confidence: float
    suggested_pivots: list[str]
    suggested_filters: list[str]
    caveats: str | None = None
    reasoning: str | None = None
    sans_references: list[str] = []
    # Provenance of the answer; the policy-scan path reports
    # "execution:keyword_scanner" / "local" here.
    model_used: str = ""
    node_used: str = ""
    latency_ms: int = 0
    # One dict per perspective when the agent response carries perspectives;
    # otherwise None.
    perspectives: list[dict] | None = None
    # Policy-scan payload; None on the normal agent path.
    execution: dict | None = None
    # Id of the conversation the exchange was persisted to, if any.
    conversation_id: str | None = None
POLICY_THEME_NAMES = {"Adult Content", "Gambling", "Downloads / Piracy"}
POLICY_QUERY_TERMS = {
"policy", "violating", "violation", "browser history", "web history",
"domain", "domains", "adult", "gambling", "piracy", "aup",
}
WEB_DATASET_HINTS = {
"web", "history", "browser", "url", "visited_url", "domain", "title",
}
def _is_policy_domain_query(query: str) -> bool:
q = (query or "").lower()
if not q:
return False
score = sum(1 for t in POLICY_QUERY_TERMS if t in q)
return score >= 2 and ("domain" in q or "history" in q or "policy" in q)
def _should_execute_policy_scan(request: AssistRequest) -> bool:
    """Decide whether the deterministic policy scan should handle *request*.

    An explicit "off" or "force" preference wins; any other value (including
    a missing one, treated as "auto") defers to the query heuristic.
    """
    preference = (request.execution_preference or "auto").strip().lower()
    if preference in ("off", "force"):
        return preference == "force"
    return _is_policy_domain_query(request.query)
def _extract_domain(value: str | None) -> str | None:
if not value:
return None
text = value.strip()
if not text:
return None
try:
parsed = urlparse(text)
if parsed.netloc:
return parsed.netloc.lower()
except Exception:
pass
m = re.search(r"([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}", text)
return m.group(0).lower() if m else None
def _dataset_score(ds: Dataset) -> int:
    """Score how likely *ds* is to contain web/browser history.

    Each hint in WEB_DATASET_HINTS adds 2 when it occurs in the dataset name,
    3 when it is a column name and 3 when it is a normalized-column value.
    Bonuses: +8 for a URL column, +2 for a user column, +2 for a host
    identifier column, +1 for a non-empty dataset.
    """
    lowered_name = (ds.name or "").lower()
    columns = {col.lower() for col in (ds.column_schema or {}).keys()}
    normalized = {str(val).lower() for val in (ds.normalized_columns or {}).values()}

    total = sum(
        2 * (hint in lowered_name) + 3 * (hint in columns) + 3 * (hint in normalized)
        for hint in WEB_DATASET_HINTS
    )
    if columns & {"visited_url", "url"}:
        total += 8
    if columns & {"user", "username"}:
        total += 2
    if columns & {"clientid", "fqdn"}:
        total += 2
    if (ds.row_count or 0) > 0:
        total += 1
    return total
async def _run_policy_domain_execution(request: AssistRequest, db: AsyncSession) -> dict:
    """Run the keyword scanner over likely web-history datasets and summarize hits.

    Steps:
      1. Load the enabled keyword themes named in POLICY_THEME_NAMES.
      2. Load datasets whose processing status is completed/ready/processing,
         optionally limited to ``request.hunt_id`` and to names containing
         ``request.dataset_name``.
      3. Keep the (at most) eight best-scoring datasets with a positive
         ``_dataset_score``.
      4. Scan them and aggregate hits per user/host pair and per domain.

    Returns:
        A dict with the keys ``mode``, ``themes``, ``datasets_scanned``,
        ``dataset_names``, ``total_hits``, ``policy_hits``, ``rows_scanned``,
        ``top_user_hosts``, ``top_domains`` and ``sample_hits``. When no
        dataset qualifies, the counts are zero and a ``note`` explains why.
    """
    scanner = KeywordScanner(db)

    theme_result = await db.execute(
        select(KeywordTheme).where(
            KeywordTheme.enabled == True,  # noqa: E712
            KeywordTheme.name.in_(list(POLICY_THEME_NAMES)),
        )
    )
    themes = list(theme_result.scalars().all())
    theme_ids = [t.id for t in themes]
    # Report the canonical policy theme names even when none are enabled.
    theme_names = [t.name for t in themes] or sorted(POLICY_THEME_NAMES)

    ds_query = select(Dataset).where(
        Dataset.processing_status.in_(["completed", "ready", "processing"])
    )
    if request.hunt_id:
        ds_query = ds_query.where(Dataset.hunt_id == request.hunt_id)
    ds_result = await db.execute(ds_query)
    candidates = list(ds_result.scalars().all())

    if request.dataset_name:
        needle = request.dataset_name.lower().strip()
        candidates = [d for d in candidates if needle in (d.name or "").lower()]

    scored = sorted(
        ((d, _dataset_score(d)) for d in candidates),
        key=lambda pair: pair[1],
        reverse=True,
    )
    # Cap the scan at the eight most web-history-like datasets.
    selected = [d for d, s in scored if s > 0][:8]
    dataset_ids = [d.id for d in selected]

    if not dataset_ids:
        return {
            "mode": "policy_scan",
            "themes": theme_names,
            "datasets_scanned": 0,
            "dataset_names": [],
            "total_hits": 0,
            "policy_hits": 0,
            # Present so the payload has the same keys as a real scan result.
            "rows_scanned": 0,
            "top_user_hosts": [],
            "top_domains": [],
            "sample_hits": [],
            "note": "No suitable browser/web-history datasets found in current scope.",
        }

    # NOTE(review): with no enabled policy theme, theme_ids is passed as None;
    # presumably the scanner then uses its default theme set - confirm, since
    # "themes" in the payload would still list the policy theme names.
    result = await scanner.scan(
        dataset_ids=dataset_ids,
        theme_ids=theme_ids or None,
        scan_hunts=False,
        scan_annotations=False,
        scan_messages=False,
    )
    hits = result.get("hits", [])

    user_host_counter = Counter(
        f"{h.get('username') or '(unknown-user)'}|{h.get('hostname') or '(unknown-host)'}"
        for h in hits
    )
    domain_counter = Counter(
        dom
        for dom in (_extract_domain(h.get("matched_value")) for h in hits)
        if dom
    )

    top_user_hosts = [
        {"user_host": key, "count": count}
        for key, count in user_host_counter.most_common(10)
    ]
    top_domains = [
        {"domain": key, "count": count}
        for key, count in domain_counter.most_common(10)
    ]
    total_hits = int(result.get("total_hits", 0))

    return {
        "mode": "policy_scan",
        "themes": theme_names,
        "datasets_scanned": len(dataset_ids),
        "dataset_names": [d.name for d in selected],
        "total_hits": total_hits,
        "policy_hits": total_hits,
        "rows_scanned": int(result.get("rows_scanned", 0)),
        "top_user_hosts": top_user_hosts,
        "top_domains": top_domains,
        "sample_hits": hits[:20],
    }
# Routes
@router.post(
    "/assist",
    response_model=AssistResponseModel,
    summary="Get analyst-assist guidance",
    description="Request guidance with auto-routed model selection. "
    "Supports quick (fast), deep (70B), and debate (multi-model) modes.",
)
async def agent_assist(
    request: AssistRequest,
    db: AsyncSession = Depends(get_db),
) -> AssistResponseModel:
    """Answer an analyst question, via the policy scan or the LLM agent.

    When ``_should_execute_policy_scan`` approves, the request is served by
    the local keyword-scan pipeline and the scan payload is returned in
    ``execution``. Otherwise the request is forwarded to the shared agent.
    In both cases the exchange is persisted when ``conversation_id`` or
    ``hunt_id`` is set.

    Raises:
        HTTPException: 500 for any unexpected failure. An HTTPException
            raised further down is propagated unchanged.
    """
    try:
        # Deterministic execution mode for policy-domain investigations.
        if _should_execute_policy_scan(request):
            t0 = time.monotonic()
            exec_payload = await _run_policy_domain_execution(request, db)
            latency_ms = int((time.monotonic() - t0) * 1000)

            policy_hits = exec_payload.get("policy_hits", 0)
            datasets_scanned = exec_payload.get("datasets_scanned", 0)
            if policy_hits > 0:
                guidance = (
                    f"Policy-violation scan complete: {policy_hits} hits across "
                    f"{datasets_scanned} dataset(s). Top user/host pairs and domains are included "
                    f"in execution results for triage."
                )
                confidence = 0.95
                caveats = "Keyword-based matching can include false positives; validate with full URL context."
            else:
                guidance = (
                    f"No policy-violation hits found in current scope "
                    f"({datasets_scanned} dataset(s) scanned)."
                )
                confidence = 0.9
                caveats = exec_payload.get("note") or "Try expanding scope to additional hunts/datasets."

            response = AssistResponseModel(
                guidance=guidance,
                confidence=confidence,
                suggested_pivots=["username", "hostname", "domain", "dataset_name"],
                suggested_filters=[
                    "theme_name in ['Adult Content','Gambling','Downloads / Piracy']",
                    "username != null",
                    "hostname != null",
                ],
                caveats=caveats,
                reasoning=(
                    "Intent matched policy-domain investigation; executed local keyword scan pipeline."
                    if _is_policy_domain_query(request.query)
                    else "Execution mode was forced by user preference; ran policy-domain scan pipeline."
                ),
                sans_references=["SANS FOR508", "SANS SEC504"],
                model_used="execution:keyword_scanner",
                node_used="local",
                latency_ms=latency_ms,
                execution=exec_payload,
            )

            conv_id = request.conversation_id
            if conv_id or request.hunt_id:
                # The persistence helper takes an AgentResponse, so mirror the
                # API model into one.
                conv_id = await _persist_conversation(
                    db,
                    conv_id,
                    request,
                    AgentResponse(
                        guidance=response.guidance,
                        confidence=response.confidence,
                        suggested_pivots=response.suggested_pivots,
                        suggested_filters=response.suggested_filters,
                        caveats=response.caveats,
                        reasoning=response.reasoning,
                        sans_references=response.sans_references,
                        model_used=response.model_used,
                        node_used=response.node_used,
                        latency_ms=response.latency_ms,
                    ),
                )
                response.conversation_id = conv_id
            return response

        agent = get_agent()
        context = AgentContext(
            query=request.query,
            dataset_name=request.dataset_name,
            artifact_type=request.artifact_type,
            host_identifier=request.host_identifier,
            data_summary=request.data_summary,
            conversation_history=request.conversation_history or [],
            active_hypotheses=request.active_hypotheses or [],
            annotations_summary=request.annotations_summary,
            enrichment_summary=request.enrichment_summary,
            mode=request.mode,
            model_override=request.model_override,
            learning_mode=request.learning_mode,
        )
        response = await agent.assist(context)

        # Persist conversation
        conv_id = request.conversation_id
        if conv_id or request.hunt_id:
            conv_id = await _persist_conversation(
                db, conv_id, request, response
            )

        return AssistResponseModel(
            guidance=response.guidance,
            confidence=response.confidence,
            suggested_pivots=response.suggested_pivots,
            suggested_filters=response.suggested_filters,
            caveats=response.caveats,
            reasoning=response.reasoning,
            sans_references=response.sans_references,
            model_used=response.model_used,
            node_used=response.node_used,
            latency_ms=response.latency_ms,
            perspectives=[
                {
                    "role": p.role,
                    "content": p.content,
                    "model_used": p.model_used,
                    "node_used": p.node_used,
                    "latency_ms": p.latency_ms,
                }
                for p in response.perspectives
            ] if response.perspectives else None,
            execution=None,
            conversation_id=conv_id,
        )
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500s.
        raise
    except Exception as e:
        logger.exception("Agent error: %s", e)
        raise HTTPException(status_code=500, detail=f"Agent error: {str(e)}") from e
@router.post(
    "/assist/stream",
    summary="Stream agent response",
    description="Stream tokens via SSE for real-time display.",
)
async def agent_assist_stream(request: AssistRequest):
    """Stream the agent's answer token by token as server-sent events.

    Each token is emitted as a ``data: {"token": ...}`` event and the stream
    ends with a ``data: [DONE]`` event. Only a subset of the request fields is
    forwarded, and the mode is always "quick".
    """
    agent = get_agent()
    context = AgentContext(
        query=request.query,
        dataset_name=request.dataset_name,
        artifact_type=request.artifact_type,
        host_identifier=request.host_identifier,
        data_summary=request.data_summary,
        conversation_history=request.conversation_history or [],
        mode="quick",  # streaming only supports quick mode
    )
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }

    async def _sse_events():
        async for token in agent.assist_stream(context):
            payload = json.dumps({'token': token})
            yield f"data: {payload}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        _sse_events(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
@router.get(
    "/health",
    summary="Check agent and node health",
    description="Returns availability of all LLM nodes and the cluster.",
)
async def agent_health() -> dict:
    """Report node and RAG probe results together with configured defaults.

    The top-level ``status`` is always "healthy"; the per-node and RAG probe
    results are returned as-is under ``nodes`` and ``rag``.
    """
    nodes = await check_all_nodes()
    rag_health = await sans_rag.health_check()

    default_models = {
        "fast": settings.DEFAULT_FAST_MODEL,
        "heavy": settings.DEFAULT_HEAVY_MODEL,
        "code": settings.DEFAULT_CODE_MODEL,
        "vision": settings.DEFAULT_VISION_MODEL,
        "embedding": settings.DEFAULT_EMBEDDING_MODEL,
    }
    agent_config = {
        "max_tokens": settings.AGENT_MAX_TOKENS,
        "temperature": settings.AGENT_TEMPERATURE,
    }
    return {
        "status": "healthy",
        "nodes": nodes,
        "rag": rag_health,
        "default_models": default_models,
        "config": agent_config,
    }
@router.get(
    "/models",
    summary="List all available models",
    description="Returns the full model registry with capabilities and node assignments.",
)
async def list_models():
    """Return the serialized model registry and the number of registered models."""
    catalog = registry.to_dict()
    model_count = len(registry.models)
    return {"models": catalog, "total": model_count}
# Conversation persistence
async def _persist_conversation(
    db: AsyncSession,
    conversation_id: str | None,
    request: AssistRequest,
    response: AgentResponse,
) -> str:
    """Save the user message and the agent response to the database.

    An existing conversation is reused when *conversation_id* matches one.
    Otherwise a new conversation is created, with that id if one was given,
    titled with the first 100 characters of the query. Changes are flushed
    but not committed here.

    Returns:
        The id of the conversation the two messages were attached to.
    """
    conv = None
    if conversation_id:
        # Find existing conversation
        result = await db.execute(
            select(Conversation).where(Conversation.id == conversation_id)
        )
        conv = result.scalar_one_or_none()

    if conv is None:
        new_fields = {"title": request.query[:100], "hunt_id": request.hunt_id}
        if conversation_id:
            # Honour the client-supplied id for a conversation that does not exist yet.
            new_fields["id"] = conversation_id
        conv = Conversation(**new_fields)
        db.add(conv)
        # Flush so conv.id is populated before the messages reference it.
        await db.flush()

    # User message
    user_msg = Message(
        conversation_id=conv.id,
        role="user",
        content=request.query,
    )
    db.add(user_msg)

    # Agent message
    agent_msg = Message(
        conversation_id=conv.id,
        role="agent",
        content=response.guidance,
        model_used=response.model_used,
        node_used=response.node_used,
        latency_ms=response.latency_ms,
        response_meta={
            "confidence": response.confidence,
            "pivots": response.suggested_pivots,
            "filters": response.suggested_filters,
            "sans_refs": response.sans_references,
        },
    )
    db.add(agent_msg)
    await db.flush()
    return conv.id