mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 14:00:20 -05:00
version 0.4.0
This commit is contained in:
322
backend/app/services/llm_analysis.py
Normal file
322
backend/app/services/llm_analysis.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""LLM-powered dataset analysis — replaces manual IOC enrichment.
|
||||
|
||||
Loads dataset rows server-side, builds a concise summary, and sends it
|
||||
to Wile (70B heavy) or Roadrunner (fast) for threat analysis.
|
||||
Supports both single-dataset and hunt-wide analysis.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.config import settings
|
||||
from app.agents.providers_v2 import OllamaProvider
|
||||
from app.agents.router import TaskType, task_router
|
||||
from app.services.sans_rag import sans_rag
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
|
||||
|
||||
class AnalysisRequest(BaseModel):
    """Request for LLM-powered analysis of a dataset."""

    # Target selector: a single dataset or a whole hunt. Neither is required
    # here — presumably the API layer validates that one is set (TODO confirm).
    dataset_id: Optional[str] = None
    hunt_id: Optional[str] = None
    # Free-form analyst question; defaults to a general threat-analysis ask.
    question: str = Field(
        default="Perform a comprehensive threat analysis of this dataset. "
        "Identify anomalies, suspicious patterns, potential IOCs, and recommend "
        "next steps for the analyst.",
        description="Specific question or general analysis request",
    )
    # "deep" routes to the heavy model, anything else to the fast one
    # (see run_llm_analysis task-type selection).
    mode: str = Field(default="deep", description="quick | deep")
    # Optional key into FOCUS_PROMPTS; unknown values fall back to no focus.
    focus: Optional[str] = Field(
        None,
        description="Focus area: threats, anomalies, lateral_movement, exfil, persistence, recon",
    )
|
||||
|
||||
|
||||
class AnalysisResult(BaseModel):
    """LLM analysis result."""

    # ── Model-produced fields (parsed from the LLM's JSON response) ──
    analysis: str = Field(..., description="Full analysis text (markdown)")
    confidence: float = Field(default=0.0, description="0-1 confidence")
    key_findings: list[str] = Field(default_factory=list)
    iocs_identified: list[dict] = Field(default_factory=list)
    recommended_actions: list[str] = Field(default_factory=list)
    mitre_techniques: list[str] = Field(default_factory=list)
    risk_score: int = Field(default=0, description="0-100 risk score")
    # ── Execution metadata (filled in by run_llm_analysis after the call) ──
    model_used: str = ""
    node_used: str = ""
    latency_ms: int = 0
    rows_analyzed: int = 0
    dataset_summary: str = ""
|
||||
|
||||
|
||||
# ── Analysis prompts ──────────────────────────────────────────────────
|
||||
|
||||
# System prompt: demands a strict JSON object whose keys mirror the
# AnalysisResult fields consumed by _parse_analysis.
ANALYSIS_SYSTEM = """You are an expert threat hunter and incident response analyst.
You are analyzing CSV log data from forensic tools (Velociraptor, Sysmon, etc.).

Your task: Perform deep threat analysis of the data provided and produce actionable findings.

RESPOND WITH VALID JSON ONLY:
{
"analysis": "Detailed markdown analysis with headers and bullet points",
"confidence": 0.85,
"key_findings": ["Finding 1", "Finding 2"],
"iocs_identified": [{"type": "ip", "value": "1.2.3.4", "context": "C2 traffic"}],
"recommended_actions": ["Action 1", "Action 2"],
"mitre_techniques": ["T1059.001 - PowerShell", "T1071 - Application Layer Protocol"],
"risk_score": 65
}
"""

# Extra instruction appended to the user prompt when AnalysisRequest.focus
# names one of these areas; unknown/empty focus yields no extra text.
FOCUS_PROMPTS = {
    "threats": "Focus on identifying active threats, malware indicators, and attack patterns.",
    "anomalies": "Focus on statistical anomalies, outliers, and unusual behavior patterns.",
    "lateral_movement": "Focus on evidence of lateral movement: PsExec, WMI, RDP, SMB, pass-the-hash.",
    "exfil": "Focus on data exfiltration indicators: large transfers, DNS tunneling, unusual destinations.",
    "persistence": "Focus on persistence mechanisms: scheduled tasks, services, registry, startup items.",
    "recon": "Focus on reconnaissance activity: scanning, enumeration, discovery commands.",
}
|
||||
|
||||
|
||||
# ── Data summarizer ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def summarize_dataset_rows(
|
||||
rows: list[dict],
|
||||
columns: list[str] | None = None,
|
||||
max_sample: int = 20,
|
||||
max_chars: int = 6000,
|
||||
) -> str:
|
||||
"""Build a concise text summary of dataset rows for LLM consumption.
|
||||
|
||||
Includes:
|
||||
- Column headers and types
|
||||
- Statistical summary (unique values, top values per column)
|
||||
- Sample rows (first N)
|
||||
- Detected patterns of interest
|
||||
"""
|
||||
if not rows:
|
||||
return "Empty dataset — no rows to analyze."
|
||||
|
||||
cols = columns or list(rows[0].keys())
|
||||
n_rows = len(rows)
|
||||
|
||||
parts: list[str] = []
|
||||
parts.append(f"## Dataset Summary: {n_rows} rows, {len(cols)} columns")
|
||||
parts.append(f"Columns: {', '.join(cols)}")
|
||||
|
||||
# Per-column stats
|
||||
parts.append("\n### Column Statistics:")
|
||||
for col in cols[:30]: # limit to first 30 cols
|
||||
values = [str(r.get(col, "")) for r in rows if r.get(col) not in (None, "", "N/A")]
|
||||
if not values:
|
||||
continue
|
||||
unique = len(set(values))
|
||||
counter = Counter(values)
|
||||
top3 = counter.most_common(3)
|
||||
top_str = ", ".join(f"{v} ({c}x)" for v, c in top3)
|
||||
parts.append(f"- **{col}**: {len(values)} non-null, {unique} unique. Top: {top_str}")
|
||||
|
||||
# Sample rows
|
||||
sample = rows[:max_sample]
|
||||
parts.append(f"\n### Sample Rows (first {len(sample)}):")
|
||||
for i, row in enumerate(sample):
|
||||
row_str = " | ".join(f"{k}={v}" for k, v in row.items() if v not in (None, "", "N/A"))
|
||||
parts.append(f"{i+1}. {row_str}")
|
||||
|
||||
# Detect interesting patterns
|
||||
patterns: list[str] = []
|
||||
all_cmds = [str(r.get("command_line", "")).lower() for r in rows if r.get("command_line")]
|
||||
sus_cmds = [c for c in all_cmds if any(
|
||||
k in c for k in ["powershell -enc", "certutil", "bitsadmin", "mshta",
|
||||
"regsvr32", "invoke-", "mimikatz", "psexec"]
|
||||
)]
|
||||
if sus_cmds:
|
||||
patterns.append(f"⚠️ {len(sus_cmds)} suspicious command lines detected")
|
||||
|
||||
all_ips = [str(r.get("dst_ip", "")) for r in rows if r.get("dst_ip")]
|
||||
ext_ips = [ip for ip in all_ips if ip and not ip.startswith(("10.", "192.168.", "172.", "127."))]
|
||||
if ext_ips:
|
||||
unique_ext = len(set(ext_ips))
|
||||
patterns.append(f"🌐 {unique_ext} unique external destination IPs")
|
||||
|
||||
if patterns:
|
||||
parts.append("\n### Detected Patterns:")
|
||||
for p in patterns:
|
||||
parts.append(f"- {p}")
|
||||
|
||||
text = "\n".join(parts)
|
||||
if len(text) > max_chars:
|
||||
text = text[:max_chars] + "\n... (truncated)"
|
||||
return text
|
||||
|
||||
|
||||
# ── LLM analysis engine ──────────────────────────────────────────────
|
||||
|
||||
|
||||
async def run_llm_analysis(
    rows: list[dict],
    request: AnalysisRequest,
    dataset_name: str = "unknown",
) -> AnalysisResult:
    """Run LLM analysis on dataset rows.

    Summarizes the rows, routes to a model via ``task_router`` (heavy for
    "deep" mode, fast otherwise), optionally enriches the prompt with SANS
    RAG context, then calls the provider and parses the response.

    Args:
        rows: Dataset rows already loaded server-side.
        request: Analysis parameters (question, mode, focus).
        dataset_name: Human-readable name woven into the prompt.

    Returns:
        Parsed ``AnalysisResult``; on timeout or provider failure, a result
        whose ``analysis`` field describes the error (never raises).
    """
    start = time.monotonic()

    # Build summary of the data for the prompt.
    summary = summarize_dataset_rows(rows)

    # Route to appropriate model/node.
    task_type = TaskType.DEEP_ANALYSIS if request.mode == "deep" else TaskType.QUICK_CHAT
    decision = task_router.route(task_type)

    # Build prompt (unknown/empty focus contributes no extra text).
    focus_text = FOCUS_PROMPTS.get(request.focus or "", "")
    prompt = f"""Analyze the following forensic dataset from '{dataset_name}'.

{focus_text}

Analyst question: {request.question}

{summary}
"""

    # Enrich with SANS RAG — best-effort; analysis proceeds without it.
    try:
        rag_context = await sans_rag.enrich_prompt(
            request.question,
            investigation_context=f"Analyzing {len(rows)} rows from {dataset_name}",
        )
        if rag_context:
            prompt = f"{prompt}\n\n{rag_context}"
    except Exception as e:
        logger.warning("SANS RAG enrichment failed: %s", e)

    def _error_result(message: str) -> AnalysisResult:
        # Shared shape for timeout/failure responses (same metadata either way).
        return AnalysisResult(
            analysis=message,
            model_used=decision.model,
            node_used=decision.node,
            latency_ms=int((time.monotonic() - start) * 1000),
            rows_analyzed=len(rows),
            dataset_summary=summary,
        )

    # Call LLM.
    provider = task_router.get_provider(decision)
    try:
        raw = await asyncio.wait_for(
            provider.generate(
                prompt=prompt,
                system=ANALYSIS_SYSTEM,
                max_tokens=settings.AGENT_MAX_TOKENS * 2,  # longer for analysis
                temperature=0.3,
            ),
            timeout=300,  # 5 min hard limit
        )
    except asyncio.TimeoutError:
        logger.error("LLM analysis timed out after 300s")
        return _error_result(
            "Analysis timed out after 5 minutes. Try a smaller dataset or 'quick' mode."
        )
    except Exception as e:
        logger.error("LLM analysis failed: %s", e)
        return _error_result(f"Analysis failed: {str(e)}")

    elapsed = int((time.monotonic() - start) * 1000)

    # Parse JSON response and attach execution metadata.
    result = _parse_analysis(raw)
    result.model_used = decision.model
    result.node_used = decision.node
    result.latency_ms = elapsed
    result.rows_analyzed = len(rows)
    result.dataset_summary = summary

    return result
|
||||
|
||||
|
||||
def _parse_analysis(raw) -> AnalysisResult:
    """Try to parse LLM output as JSON, fall back to plain text.

    raw may be:
    - A dict from OllamaProvider.generate() with key "response" containing LLM text
    - A plain string from other providers

    Returns:
        AnalysisResult built from the first candidate that both parses as a
        JSON object and converts cleanly; otherwise a plain-text fallback
        with confidence 0.5.
    """
    # Ollama provider returns {"response": "<llm text>", "model": ..., ...}
    if isinstance(raw, dict):
        text = raw.get("response") or raw.get("analysis") or str(raw)
        logger.info("_parse_analysis: extracted text from dict, len=%d, first 200 chars: %s",
                    len(text), text[:200])
    else:
        text = str(raw)
        logger.info("_parse_analysis: raw is str, len=%d, first 200 chars: %s",
                    len(text), text[:200])

    text = text.strip()

    # Strip markdown code fences the model may have wrapped the JSON in.
    if text.startswith("```"):
        lines = text.split("\n")
        lines = [line for line in lines if not line.strip().startswith("```")]
        text = "\n".join(lines).strip()

    # Try each extraction candidate in turn. ValueError is caught alongside
    # JSONDecodeError so bad numeric fields (confidence / risk_score) fall
    # through to the next candidate instead of raising.
    for candidate in _extract_json_candidates(text):
        try:
            data = json.loads(candidate)
            if isinstance(data, dict):
                logger.info("_parse_analysis: parsed JSON OK, keys=%s", list(data.keys()))
                return AnalysisResult(
                    analysis=data.get("analysis", text),
                    confidence=float(data.get("confidence", 0.5)),
                    key_findings=data.get("key_findings", []),
                    iocs_identified=data.get("iocs_identified", []),
                    recommended_actions=data.get("recommended_actions", []),
                    mitre_techniques=data.get("mitre_techniques", []),
                    risk_score=int(data.get("risk_score", 0)),
                )
        except (json.JSONDecodeError, ValueError) as e:
            logger.warning("_parse_analysis: JSON parse failed: %s, candidate len=%d, first 100: %s",
                           e, len(candidate), candidate[:100])
            continue

    # Fallback: plain text (confidence is a fixed guess, not model-reported).
    logger.warning("_parse_analysis: all JSON parse attempts failed, falling back to plain text")
    return AnalysisResult(
        analysis=text,
        confidence=0.5,
    )
|
||||
|
||||
|
||||
def _extract_json_candidates(text: str):
|
||||
"""Yield JSON candidate strings from text, trying progressively more aggressive extraction."""
|
||||
import re
|
||||
|
||||
# 1. The whole text as-is
|
||||
yield text
|
||||
|
||||
# 2. Find outermost { ... } block
|
||||
start = text.find("{")
|
||||
end = text.rfind("}")
|
||||
if start != -1 and end > start:
|
||||
block = text[start:end + 1]
|
||||
yield block
|
||||
|
||||
# 3. Try to fix common LLM JSON issues:
|
||||
# - trailing commas before ] or }
|
||||
fixed = re.sub(r',\s*([}\]])', r'\1', block)
|
||||
if fixed != block:
|
||||
yield fixed
|
||||
Reference in New Issue
Block a user