chore: checkpoint all local changes

This commit is contained in:
2026-02-23 14:36:33 -05:00
76 changed files with 34486 additions and 738 deletions

View File

@@ -0,0 +1,404 @@
"""API routes for alerts — CRUD, analyze triggers, and alert rules."""
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Alert, AlertRule, _new_id, _utcnow
from app.db.repositories.datasets import DatasetRepository
from app.services.analyzers import (
get_available_analyzers,
get_analyzer,
run_all_analyzers,
AlertCandidate,
)
from app.services.process_tree import _fetch_rows
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/alerts", tags=["alerts"])
# ── Pydantic models ──────────────────────────────────────────────────
class AlertUpdate(BaseModel):
    """Partial-update payload for PUT /api/alerts/{alert_id}; None fields are left unchanged."""
    status: Optional[str] = None  # e.g. "acknowledged", "resolved", "false-positive" (drives timestamping)
    severity: Optional[str] = None
    assignee: Optional[str] = None
    case_id: Optional[str] = None  # link this alert to a case
    tags: Optional[list[str]] = None  # replaces the full tag list when set
class RuleCreate(BaseModel):
    """Payload for creating an alert rule; `analyzer` must name a registered analyzer."""
    name: str
    description: Optional[str] = None
    analyzer: str  # validated against get_analyzer() on create
    config: Optional[dict] = None  # analyzer-specific settings
    severity_override: Optional[str] = None  # forces severity on alerts from this rule
    enabled: bool = True
    hunt_id: Optional[str] = None  # optionally scope the rule to one hunt
class RuleUpdate(BaseModel):
    """Partial-update payload for PUT /api/alerts/rules/{rule_id}; None fields are left unchanged."""
    name: Optional[str] = None
    description: Optional[str] = None
    config: Optional[dict] = None
    severity_override: Optional[str] = None
    enabled: Optional[bool] = None
class AnalyzeRequest(BaseModel):
    """Request body for POST /api/alerts/analyze; one of dataset_id/hunt_id is required."""
    dataset_id: Optional[str] = None  # analyze a single dataset...
    hunt_id: Optional[str] = None  # ...or all datasets in a hunt
    analyzers: Optional[list[str]] = None  # None = run all
    config: Optional[dict] = None  # passed through to run_all_analyzers
    auto_create: bool = True  # automatically persist alerts
# ── Helpers ───────────────────────────────────────────────────────────
def _alert_to_dict(a: Alert) -> dict:
    """Serialize an Alert ORM row into a JSON-safe dictionary."""

    def _iso(ts):
        # Datetimes become ISO-8601 strings; unset timestamps become None.
        return ts.isoformat() if ts else None

    payload = {
        "id": a.id,
        "title": a.title,
        "description": a.description,
        "severity": a.severity,
        "status": a.status,
        "analyzer": a.analyzer,
        "score": a.score,
        # Nullable JSON columns are normalized to empty lists for clients.
        "evidence": a.evidence or [],
        "mitre_technique": a.mitre_technique,
        "tags": a.tags or [],
        "hunt_id": a.hunt_id,
        "dataset_id": a.dataset_id,
        "case_id": a.case_id,
        "assignee": a.assignee,
        "acknowledged_at": _iso(a.acknowledged_at),
        "resolved_at": _iso(a.resolved_at),
        "created_at": _iso(a.created_at),
        "updated_at": _iso(a.updated_at),
    }
    return payload
def _rule_to_dict(r: AlertRule) -> dict:
    """Serialize an AlertRule ORM row into a JSON-safe dictionary."""

    def _iso(ts):
        # Datetimes become ISO-8601 strings; unset timestamps become None.
        return ts.isoformat() if ts else None

    return {
        "id": r.id,
        "name": r.name,
        "description": r.description,
        "analyzer": r.analyzer,
        "config": r.config,
        "severity_override": r.severity_override,
        "enabled": r.enabled,
        "hunt_id": r.hunt_id,
        "created_at": _iso(r.created_at),
        "updated_at": _iso(r.updated_at),
    }
# ── Alert CRUD ────────────────────────────────────────────────────────
@router.get("", summary="List alerts")
async def list_alerts(
status: str | None = Query(None),
severity: str | None = Query(None),
analyzer: str | None = Query(None),
hunt_id: str | None = Query(None),
dataset_id: str | None = Query(None),
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
db: AsyncSession = Depends(get_db),
):
stmt = select(Alert)
count_stmt = select(func.count(Alert.id))
if status:
stmt = stmt.where(Alert.status == status)
count_stmt = count_stmt.where(Alert.status == status)
if severity:
stmt = stmt.where(Alert.severity == severity)
count_stmt = count_stmt.where(Alert.severity == severity)
if analyzer:
stmt = stmt.where(Alert.analyzer == analyzer)
count_stmt = count_stmt.where(Alert.analyzer == analyzer)
if hunt_id:
stmt = stmt.where(Alert.hunt_id == hunt_id)
count_stmt = count_stmt.where(Alert.hunt_id == hunt_id)
if dataset_id:
stmt = stmt.where(Alert.dataset_id == dataset_id)
count_stmt = count_stmt.where(Alert.dataset_id == dataset_id)
total = (await db.execute(count_stmt)).scalar() or 0
results = (await db.execute(
stmt.order_by(desc(Alert.score), desc(Alert.created_at)).offset(offset).limit(limit)
)).scalars().all()
return {"alerts": [_alert_to_dict(a) for a in results], "total": total}
@router.get("/stats", summary="Alert statistics dashboard")
async def alert_stats(
hunt_id: str | None = Query(None),
db: AsyncSession = Depends(get_db),
):
"""Return aggregated alert statistics."""
base = select(Alert)
if hunt_id:
base = base.where(Alert.hunt_id == hunt_id)
# Severity breakdown
sev_stmt = select(Alert.severity, func.count(Alert.id)).group_by(Alert.severity)
if hunt_id:
sev_stmt = sev_stmt.where(Alert.hunt_id == hunt_id)
sev_rows = (await db.execute(sev_stmt)).all()
severity_counts = {s: c for s, c in sev_rows}
# Status breakdown
status_stmt = select(Alert.status, func.count(Alert.id)).group_by(Alert.status)
if hunt_id:
status_stmt = status_stmt.where(Alert.hunt_id == hunt_id)
status_rows = (await db.execute(status_stmt)).all()
status_counts = {s: c for s, c in status_rows}
# Analyzer breakdown
analyzer_stmt = select(Alert.analyzer, func.count(Alert.id)).group_by(Alert.analyzer)
if hunt_id:
analyzer_stmt = analyzer_stmt.where(Alert.hunt_id == hunt_id)
analyzer_rows = (await db.execute(analyzer_stmt)).all()
analyzer_counts = {a: c for a, c in analyzer_rows}
# Top MITRE techniques
mitre_stmt = (
select(Alert.mitre_technique, func.count(Alert.id))
.where(Alert.mitre_technique.isnot(None))
.group_by(Alert.mitre_technique)
.order_by(desc(func.count(Alert.id)))
.limit(10)
)
if hunt_id:
mitre_stmt = mitre_stmt.where(Alert.hunt_id == hunt_id)
mitre_rows = (await db.execute(mitre_stmt)).all()
top_mitre = [{"technique": t, "count": c} for t, c in mitre_rows]
total = sum(severity_counts.values())
return {
"total": total,
"severity_counts": severity_counts,
"status_counts": status_counts,
"analyzer_counts": analyzer_counts,
"top_mitre": top_mitre,
}
@router.get("/{alert_id}", summary="Get alert detail")
async def get_alert(alert_id: str, db: AsyncSession = Depends(get_db)):
result = await db.get(Alert, alert_id)
if not result:
raise HTTPException(status_code=404, detail="Alert not found")
return _alert_to_dict(result)
@router.put("/{alert_id}", summary="Update alert (status, assignee, etc.)")
async def update_alert(
alert_id: str, body: AlertUpdate, db: AsyncSession = Depends(get_db)
):
alert = await db.get(Alert, alert_id)
if not alert:
raise HTTPException(status_code=404, detail="Alert not found")
if body.status is not None:
alert.status = body.status
if body.status == "acknowledged" and not alert.acknowledged_at:
alert.acknowledged_at = _utcnow()
if body.status in ("resolved", "false-positive") and not alert.resolved_at:
alert.resolved_at = _utcnow()
if body.severity is not None:
alert.severity = body.severity
if body.assignee is not None:
alert.assignee = body.assignee
if body.case_id is not None:
alert.case_id = body.case_id
if body.tags is not None:
alert.tags = body.tags
await db.commit()
await db.refresh(alert)
return _alert_to_dict(alert)
@router.delete("/{alert_id}", summary="Delete alert")
async def delete_alert(alert_id: str, db: AsyncSession = Depends(get_db)):
alert = await db.get(Alert, alert_id)
if not alert:
raise HTTPException(status_code=404, detail="Alert not found")
await db.delete(alert)
await db.commit()
return {"ok": True}
# ── Bulk operations ──────────────────────────────────────────────────
@router.post("/bulk-update", summary="Bulk update alert statuses")
async def bulk_update_alerts(
alert_ids: list[str],
status: str = Query(...),
db: AsyncSession = Depends(get_db),
):
updated = 0
for aid in alert_ids:
alert = await db.get(Alert, aid)
if alert:
alert.status = status
if status == "acknowledged" and not alert.acknowledged_at:
alert.acknowledged_at = _utcnow()
if status in ("resolved", "false-positive") and not alert.resolved_at:
alert.resolved_at = _utcnow()
updated += 1
await db.commit()
return {"updated": updated}
# ── Run Analyzers ────────────────────────────────────────────────────
@router.get("/analyzers/list", summary="List available analyzers")
async def list_analyzers():
return {"analyzers": get_available_analyzers()}
@router.post("/analyze", summary="Run analyzers on a dataset/hunt and optionally create alerts")
async def run_analysis(
request: AnalyzeRequest, db: AsyncSession = Depends(get_db)
):
if not request.dataset_id and not request.hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
# Load rows
rows_objs = await _fetch_rows(
db, dataset_id=request.dataset_id, hunt_id=request.hunt_id, limit=10000,
)
if not rows_objs:
raise HTTPException(status_code=404, detail="No rows found")
rows = [r.normalized_data or r.data for r in rows_objs]
# Run analyzers
candidates = await run_all_analyzers(rows, enabled=request.analyzers, config=request.config)
created_alerts: list[dict] = []
if request.auto_create and candidates:
for c in candidates:
alert = Alert(
id=_new_id(),
title=c.title,
description=c.description,
severity=c.severity,
analyzer=c.analyzer,
score=c.score,
evidence=c.evidence,
mitre_technique=c.mitre_technique,
tags=c.tags,
hunt_id=request.hunt_id,
dataset_id=request.dataset_id,
)
db.add(alert)
created_alerts.append(_alert_to_dict(alert))
await db.commit()
return {
"candidates_found": len(candidates),
"alerts_created": len(created_alerts),
"alerts": created_alerts,
"summary": {
"by_severity": _count_by(candidates, "severity"),
"by_analyzer": _count_by(candidates, "analyzer"),
"rows_analyzed": len(rows),
},
}
def _count_by(items: list[AlertCandidate], attr: str) -> dict[str, int]:
counts: dict[str, int] = {}
for item in items:
key = getattr(item, attr, "unknown")
counts[key] = counts.get(key, 0) + 1
return counts
# ── Alert Rules CRUD ─────────────────────────────────────────────────
@router.get("/rules/list", summary="List alert rules")
async def list_rules(
enabled: bool | None = Query(None),
db: AsyncSession = Depends(get_db),
):
stmt = select(AlertRule)
if enabled is not None:
stmt = stmt.where(AlertRule.enabled == enabled)
results = (await db.execute(stmt.order_by(AlertRule.created_at))).scalars().all()
return {"rules": [_rule_to_dict(r) for r in results]}
@router.post("/rules", summary="Create alert rule")
async def create_rule(body: RuleCreate, db: AsyncSession = Depends(get_db)):
# Validate analyzer exists
if not get_analyzer(body.analyzer):
raise HTTPException(status_code=400, detail=f"Unknown analyzer: {body.analyzer}")
rule = AlertRule(
id=_new_id(),
name=body.name,
description=body.description,
analyzer=body.analyzer,
config=body.config,
severity_override=body.severity_override,
enabled=body.enabled,
hunt_id=body.hunt_id,
)
db.add(rule)
await db.commit()
await db.refresh(rule)
return _rule_to_dict(rule)
@router.put("/rules/{rule_id}", summary="Update alert rule")
async def update_rule(
rule_id: str, body: RuleUpdate, db: AsyncSession = Depends(get_db)
):
rule = await db.get(AlertRule, rule_id)
if not rule:
raise HTTPException(status_code=404, detail="Rule not found")
if body.name is not None:
rule.name = body.name
if body.description is not None:
rule.description = body.description
if body.config is not None:
rule.config = body.config
if body.severity_override is not None:
rule.severity_override = body.severity_override
if body.enabled is not None:
rule.enabled = body.enabled
await db.commit()
await db.refresh(rule)
return _rule_to_dict(rule)
@router.delete("/rules/{rule_id}", summary="Delete alert rule")
async def delete_rule(rule_id: str, db: AsyncSession = Depends(get_db)):
rule = await db.get(AlertRule, rule_id)
if not rule:
raise HTTPException(status_code=404, detail="Rule not found")
await db.delete(rule)
await db.commit()
return {"ok": True}

View File

@@ -1,371 +1,296 @@
"""Analysis API routes - triage, host profiles, reports, IOC extraction,
host grouping, anomaly detection, data query (SSE), and job management."""
from __future__ import annotations
"""API routes for process trees, storyline graphs, risk scoring, LLM analysis, timeline, and field stats."""
import logging
from typing import Optional
from typing import Any, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy import select
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import HostProfile, HuntReport, TriageResult
from app.db.repositories.datasets import DatasetRepository
from app.services.process_tree import (
build_process_tree,
build_storyline,
compute_risk_scores,
_fetch_rows,
)
from app.services.llm_analysis import (
AnalysisRequest,
AnalysisResult,
run_llm_analysis,
)
from app.services.timeline import (
build_timeline_bins,
compute_field_stats,
search_rows,
)
from app.services.mitre import (
map_to_attack,
build_knowledge_graph,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/analysis", tags=["analysis"])
# --- Response models ---
class TriageResultResponse(BaseModel):
id: str
dataset_id: str
row_start: int
row_end: int
risk_score: float
verdict: str
findings: list | None = None
suspicious_indicators: list | None = None
mitre_techniques: list | None = None
model_used: str | None = None
node_used: str | None = None
class Config:
from_attributes = True
# ── Response models ───────────────────────────────────────────────────
class HostProfileResponse(BaseModel):
id: str
hunt_id: str
class ProcessTreeResponse(BaseModel):
trees: list[dict] = Field(default_factory=list)
total_processes: int = 0
class StorylineResponse(BaseModel):
nodes: list[dict] = Field(default_factory=list)
edges: list[dict] = Field(default_factory=list)
summary: dict = Field(default_factory=dict)
class RiskHostEntry(BaseModel):
hostname: str
fqdn: str | None = None
risk_score: float
risk_level: str
artifact_summary: dict | None = None
timeline_summary: str | None = None
suspicious_findings: list | None = None
mitre_techniques: list | None = None
llm_analysis: str | None = None
model_used: str | None = None
class Config:
from_attributes = True
score: int = 0
signals: list[str] = Field(default_factory=list)
event_count: int = 0
process_count: int = 0
network_count: int = 0
file_count: int = 0
class HuntReportResponse(BaseModel):
id: str
hunt_id: str
status: str
exec_summary: str | None = None
full_report: str | None = None
findings: list | None = None
recommendations: list | None = None
mitre_mapping: dict | None = None
ioc_table: list | None = None
host_risk_summary: list | None = None
models_used: list | None = None
generation_time_ms: int | None = None
class Config:
from_attributes = True
class RiskSummaryResponse(BaseModel):
hosts: list[RiskHostEntry] = Field(default_factory=list)
overall_score: int = 0
total_events: int = 0
severity_breakdown: dict[str, int] = Field(default_factory=dict)
class QueryRequest(BaseModel):
question: str
mode: str = "quick" # quick or deep
# ── Routes ────────────────────────────────────────────────────────────
# --- Triage endpoints ---
@router.get("/triage/{dataset_id}", response_model=list[TriageResultResponse])
async def get_triage_results(
dataset_id: str,
min_risk: float = Query(0.0, ge=0.0, le=10.0),
@router.get(
"/process-tree",
response_model=ProcessTreeResponse,
summary="Build process tree from dataset rows",
description=(
"Extracts parent→child process relationships from dataset rows "
"and returns a hierarchical forest of process nodes."
),
)
async def get_process_tree(
dataset_id: str | None = Query(None, description="Dataset ID"),
hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
hostname: str | None = Query(None, description="Filter by hostname"),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(TriageResult)
.where(TriageResult.dataset_id == dataset_id)
.where(TriageResult.risk_score >= min_risk)
.order_by(TriageResult.risk_score.desc())
"""Return process tree(s) for a dataset or hunt."""
if not dataset_id and not hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
trees = await build_process_tree(
db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
)
return result.scalars().all()
# Count total processes recursively
def _count(node: dict) -> int:
return 1 + sum(_count(c) for c in node.get("children", []))
total = sum(_count(t) for t in trees)
return ProcessTreeResponse(trees=trees, total_processes=total)
@router.post("/triage/{dataset_id}")
async def trigger_triage(
dataset_id: str,
background_tasks: BackgroundTasks,
):
async def _run():
from app.services.triage import triage_dataset
await triage_dataset(dataset_id)
background_tasks.add_task(_run)
return {"status": "triage_started", "dataset_id": dataset_id}
# --- Host profile endpoints ---
@router.get("/profiles/{hunt_id}", response_model=list[HostProfileResponse])
async def get_host_profiles(
hunt_id: str,
min_risk: float = Query(0.0, ge=0.0, le=10.0),
@router.get(
"/storyline",
response_model=StorylineResponse,
summary="Build CrowdStrike-style storyline attack graph",
description=(
"Creates a Cytoscape-compatible graph of events connected by "
"process lineage (spawned) and temporal sequence within each host."
),
)
async def get_storyline(
dataset_id: str | None = Query(None, description="Dataset ID"),
hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
hostname: str | None = Query(None, description="Filter by hostname"),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(HostProfile)
.where(HostProfile.hunt_id == hunt_id)
.where(HostProfile.risk_score >= min_risk)
.order_by(HostProfile.risk_score.desc())
"""Return a storyline graph for a dataset or hunt."""
if not dataset_id and not hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
result = await build_storyline(
db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
)
return result.scalars().all()
return StorylineResponse(**result)
@router.post("/profiles/{hunt_id}")
async def trigger_all_profiles(
hunt_id: str,
background_tasks: BackgroundTasks,
):
async def _run():
from app.services.host_profiler import profile_all_hosts
await profile_all_hosts(hunt_id)
background_tasks.add_task(_run)
return {"status": "profiling_started", "hunt_id": hunt_id}
@router.post("/profiles/{hunt_id}/{hostname}")
async def trigger_single_profile(
hunt_id: str,
hostname: str,
background_tasks: BackgroundTasks,
):
async def _run():
from app.services.host_profiler import profile_host
await profile_host(hunt_id, hostname)
background_tasks.add_task(_run)
return {"status": "profiling_started", "hunt_id": hunt_id, "hostname": hostname}
# --- Report endpoints ---
@router.get("/reports/{hunt_id}", response_model=list[HuntReportResponse])
async def list_reports(
hunt_id: str,
@router.get(
"/risk-summary",
response_model=RiskSummaryResponse,
summary="Compute risk scores per host",
description=(
"Analyzes dataset rows for suspicious patterns (encoded PowerShell, "
"credential dumping, lateral movement) and produces per-host risk scores."
),
)
async def get_risk_summary(
hunt_id: str | None = Query(None, description="Hunt ID"),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(HuntReport)
.where(HuntReport.hunt_id == hunt_id)
.order_by(HuntReport.created_at.desc())
"""Return risk scores for all hosts in a hunt."""
result = await compute_risk_scores(db, hunt_id=hunt_id)
return RiskSummaryResponse(**result)
# ── LLM Analysis ─────────────────────────────────────────────────────
@router.post(
"/llm-analyze",
response_model=AnalysisResult,
summary="Run LLM-powered threat analysis on dataset",
description=(
"Loads dataset rows server-side, builds a summary, and sends to "
"Wile (deep analysis) or Roadrunner (quick) for comprehensive "
"threat analysis. Returns structured findings, IOCs, MITRE techniques."
),
)
async def llm_analyze(
request: AnalysisRequest,
db: AsyncSession = Depends(get_db),
):
"""Run LLM analysis on a dataset or hunt."""
if not request.dataset_id and not request.hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
# Load rows
rows_objs = await _fetch_rows(
db,
dataset_id=request.dataset_id,
hunt_id=request.hunt_id,
limit=2000,
)
return result.scalars().all()
if not rows_objs:
raise HTTPException(status_code=404, detail="No rows found for analysis")
# Extract data dicts
rows = [r.normalized_data or r.data for r in rows_objs]
# Get dataset name
ds_name = "hunt datasets"
if request.dataset_id:
repo = DatasetRepository(db)
ds = await repo.get_dataset(request.dataset_id)
if ds:
ds_name = ds.name
result = await run_llm_analysis(rows, request, dataset_name=ds_name)
return result
@router.get("/reports/{hunt_id}/{report_id}", response_model=HuntReportResponse)
async def get_report(
hunt_id: str,
report_id: str,
# ── Timeline ──────────────────────────────────────────────────────────
@router.get(
"/timeline",
summary="Get event timeline histogram bins",
)
async def get_timeline(
dataset_id: str | None = Query(None),
hunt_id: str | None = Query(None),
bins: int = Query(60, ge=10, le=200),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(HuntReport)
.where(HuntReport.id == report_id)
.where(HuntReport.hunt_id == hunt_id)
)
report = result.scalar_one_or_none()
if not report:
raise HTTPException(status_code=404, detail="Report not found")
return report
if not dataset_id and not hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
return await build_timeline_bins(db, dataset_id=dataset_id, hunt_id=hunt_id, bins=bins)
@router.post("/reports/{hunt_id}/generate")
async def trigger_report(
hunt_id: str,
background_tasks: BackgroundTasks,
):
async def _run():
from app.services.report_generator import generate_report
await generate_report(hunt_id)
background_tasks.add_task(_run)
return {"status": "report_generation_started", "hunt_id": hunt_id}
# --- IOC extraction endpoints ---
@router.get("/iocs/{dataset_id}")
async def extract_iocs(
dataset_id: str,
max_rows: int = Query(5000, ge=1, le=50000),
@router.get(
"/field-stats",
summary="Get per-field value distributions",
)
async def get_field_stats(
dataset_id: str | None = Query(None),
hunt_id: str | None = Query(None),
fields: str | None = Query(None, description="Comma-separated field names"),
top_n: int = Query(20, ge=5, le=100),
db: AsyncSession = Depends(get_db),
):
"""Extract IOCs (IPs, domains, hashes, etc.) from dataset rows."""
from app.services.ioc_extractor import extract_iocs_from_dataset
iocs = await extract_iocs_from_dataset(dataset_id, db, max_rows=max_rows)
total = sum(len(v) for v in iocs.values())
return {"dataset_id": dataset_id, "iocs": iocs, "total": total}
# --- Host grouping endpoints ---
@router.get("/hosts/{hunt_id}")
async def get_host_groups(
hunt_id: str,
db: AsyncSession = Depends(get_db),
):
"""Group data by hostname across all datasets in a hunt."""
from app.services.ioc_extractor import extract_host_groups
groups = await extract_host_groups(hunt_id, db)
return {"hunt_id": hunt_id, "hosts": groups}
# --- Anomaly detection endpoints ---
@router.get("/anomalies/{dataset_id}")
async def get_anomalies(
dataset_id: str,
outliers_only: bool = Query(False),
db: AsyncSession = Depends(get_db),
):
"""Get anomaly detection results for a dataset."""
from app.db.models import AnomalyResult
stmt = select(AnomalyResult).where(AnomalyResult.dataset_id == dataset_id)
if outliers_only:
stmt = stmt.where(AnomalyResult.is_outlier == True)
stmt = stmt.order_by(AnomalyResult.anomaly_score.desc())
result = await db.execute(stmt)
rows = result.scalars().all()
return [
{
"id": r.id,
"dataset_id": r.dataset_id,
"row_id": r.row_id,
"anomaly_score": r.anomaly_score,
"distance_from_centroid": r.distance_from_centroid,
"cluster_id": r.cluster_id,
"is_outlier": r.is_outlier,
"explanation": r.explanation,
}
for r in rows
]
@router.post("/anomalies/{dataset_id}")
async def trigger_anomaly_detection(
dataset_id: str,
k: int = Query(3, ge=2, le=20),
threshold: float = Query(0.35, ge=0.1, le=0.9),
background_tasks: BackgroundTasks = None,
):
"""Trigger embedding-based anomaly detection on a dataset."""
async def _run():
from app.services.anomaly_detector import detect_anomalies
await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
if background_tasks:
background_tasks.add_task(_run)
return {"status": "anomaly_detection_started", "dataset_id": dataset_id}
else:
from app.services.anomaly_detector import detect_anomalies
results = await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
return {"status": "complete", "dataset_id": dataset_id, "count": len(results)}
# --- Natural language data query (SSE streaming) ---
@router.post("/query/{dataset_id}")
async def query_dataset_endpoint(
dataset_id: str,
body: QueryRequest,
):
"""Ask a natural language question about a dataset.
Returns an SSE stream with token-by-token LLM response.
Event types: status, metadata, token, error, done
"""
from app.services.data_query import query_dataset_stream
return StreamingResponse(
query_dataset_stream(dataset_id, body.question, body.mode),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
if not dataset_id and not hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
field_list = [f.strip() for f in fields.split(",")] if fields else None
return await compute_field_stats(
db, dataset_id=dataset_id, hunt_id=hunt_id,
fields=field_list, top_n=top_n,
)
@router.post("/query/{dataset_id}/sync")
async def query_dataset_sync(
dataset_id: str,
body: QueryRequest,
class SearchRequest(BaseModel):
dataset_id: Optional[str] = None
hunt_id: Optional[str] = None
query: str = ""
filters: dict[str, str] = Field(default_factory=dict)
time_start: Optional[str] = None
time_end: Optional[str] = None
limit: int = 500
offset: int = 0
@router.post(
"/search",
summary="Search and filter dataset rows",
)
async def search_dataset_rows(
request: SearchRequest,
db: AsyncSession = Depends(get_db),
):
"""Non-streaming version of data query."""
from app.services.data_query import query_dataset
try:
answer = await query_dataset(dataset_id, body.question, body.mode)
return {"dataset_id": dataset_id, "question": body.question, "answer": answer, "mode": body.mode}
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f"Query failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
if not request.dataset_id and not request.hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
return await search_rows(
db,
dataset_id=request.dataset_id,
hunt_id=request.hunt_id,
query=request.query,
filters=request.filters,
time_start=request.time_start,
time_end=request.time_end,
limit=request.limit,
offset=request.offset,
)
# --- Job queue endpoints ---
# ── MITRE ATT&CK ─────────────────────────────────────────────────────
@router.get("/jobs")
async def list_jobs(
status: str | None = Query(None),
job_type: str | None = Query(None),
limit: int = Query(50, ge=1, le=200),
@router.get(
"/mitre-map",
summary="Map dataset events to MITRE ATT&CK techniques",
)
async def get_mitre_map(
dataset_id: str | None = Query(None),
hunt_id: str | None = Query(None),
db: AsyncSession = Depends(get_db),
):
"""List all tracked jobs."""
from app.services.job_queue import job_queue, JobStatus, JobType
s = JobStatus(status) if status else None
t = JobType(job_type) if job_type else None
jobs = job_queue.list_jobs(status=s, job_type=t, limit=limit)
stats = job_queue.get_stats()
return {"jobs": jobs, "stats": stats}
if not dataset_id and not hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
return await map_to_attack(db, dataset_id=dataset_id, hunt_id=hunt_id)
@router.get("/jobs/{job_id}")
async def get_job(job_id: str):
"""Get status of a specific job."""
from app.services.job_queue import job_queue
job = job_queue.get_job(job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
return job.to_dict()
@router.delete("/jobs/{job_id}")
async def cancel_job(job_id: str):
"""Cancel a running or queued job."""
from app.services.job_queue import job_queue
if job_queue.cancel_job(job_id):
return {"status": "cancelled", "job_id": job_id}
raise HTTPException(status_code=400, detail="Job cannot be cancelled (already complete or not found)")
@router.post("/jobs/submit/{job_type}")
async def submit_job(
job_type: str,
params: dict = {},
@router.get(
"/knowledge-graph",
summary="Build entity-technique knowledge graph",
)
async def get_knowledge_graph(
dataset_id: str | None = Query(None),
hunt_id: str | None = Query(None),
db: AsyncSession = Depends(get_db),
):
<<<<<<< HEAD
"""Submit a new job to the queue.
Job types: triage, host_profile, report, anomaly, query
@@ -403,4 +328,9 @@ async def lb_health_check():
"""Force a health check of both nodes."""
from app.services.load_balancer import lb
await lb.check_health()
return lb.get_status()
return lb.get_status()
=======
if not dataset_id and not hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
return await build_knowledge_graph(db, dataset_id=dataset_id, hunt_id=hunt_id)
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2

View File

@@ -0,0 +1,296 @@
"""API routes for case management — CRUD for cases, tasks, and activity logs."""
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Case, CaseTask, ActivityLog, _new_id, _utcnow
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/cases", tags=["cases"])
# ── Pydantic models ──────────────────────────────────────────────────
class CaseCreate(BaseModel):
    """Payload for creating a case."""
    title: str
    description: Optional[str] = None
    severity: str = "medium"
    tlp: str = "amber"  # Traffic Light Protocol sharing marking
    pap: str = "amber"  # Permissible Actions Protocol marking
    priority: int = 2
    assignee: Optional[str] = None
    tags: Optional[list[str]] = None
    hunt_id: Optional[str] = None  # optionally link the case to a hunt
    mitre_techniques: Optional[list[str]] = None
    iocs: Optional[list[dict]] = None  # indicator-of-compromise records
class CaseUpdate(BaseModel):
    """Partial-update payload for PUT /api/cases/{case_id}; None fields are left unchanged."""
    title: Optional[str] = None
    description: Optional[str] = None
    severity: Optional[str] = None
    tlp: Optional[str] = None
    pap: Optional[str] = None
    status: Optional[str] = None  # e.g. "in-progress", "resolved", "closed" (drives timestamping)
    priority: Optional[int] = None
    assignee: Optional[str] = None
    tags: Optional[list[str]] = None
    mitre_techniques: Optional[list[str]] = None
    iocs: Optional[list[dict]] = None
class TaskCreate(BaseModel):
    """Payload for creating a task inside a case."""
    title: str
    description: Optional[str] = None
    assignee: Optional[str] = None
class TaskUpdate(BaseModel):
    """Partial-update payload for a case task; None fields are left unchanged."""
    title: Optional[str] = None
    description: Optional[str] = None
    status: Optional[str] = None
    assignee: Optional[str] = None
    order: Optional[int] = None  # position within the case's task list
# ── Helper: log activity ─────────────────────────────────────────────
async def _log_activity(
    db: AsyncSession,
    entity_type: str,
    entity_id: str,
    action: str,
    details: dict | None = None,
):
    """Stage an ActivityLog row on the session; the caller is responsible for commit.

    NOTE(review): unlike Case creation, no explicit id is passed here —
    presumably the ActivityLog model supplies a default; confirm against
    the model definition.
    """
    entry = ActivityLog(
        entity_type=entity_type,
        entity_id=entity_id,
        action=action,
        details=details,
        created_at=_utcnow(),
    )
    db.add(entry)
# ── Case CRUD ─────────────────────────────────────────────────────────
@router.post("", summary="Create a case")
async def create_case(body: CaseCreate, db: AsyncSession = Depends(get_db)):
now = _utcnow()
case = Case(
id=_new_id(),
title=body.title,
description=body.description,
severity=body.severity,
tlp=body.tlp,
pap=body.pap,
priority=body.priority,
assignee=body.assignee,
tags=body.tags,
hunt_id=body.hunt_id,
mitre_techniques=body.mitre_techniques,
iocs=body.iocs,
created_at=now,
updated_at=now,
)
db.add(case)
await _log_activity(db, "case", case.id, "created", {"title": body.title})
await db.commit()
await db.refresh(case)
return _case_to_dict(case)
@router.get("", summary="List cases")
async def list_cases(
    status: Optional[str] = Query(None),
    hunt_id: Optional[str] = Query(None),
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """Page through cases (newest-updated first) with optional status/hunt filters."""
    # Build the filter list once and share it between the page and count queries.
    filters = []
    if status:
        filters.append(Case.status == status)
    if hunt_id:
        filters.append(Case.hunt_id == hunt_id)
    page_q = (
        select(Case)
        .where(*filters)
        .order_by(desc(Case.updated_at))
        .offset(offset)
        .limit(limit)
    )
    rows = (await db.execute(page_q)).scalars().all()
    total = (
        await db.execute(select(func.count(Case.id)).where(*filters))
    ).scalar() or 0
    return {"cases": [_case_to_dict(c) for c in rows], "total": total}
@router.get("/{case_id}", summary="Get case detail")
async def get_case(case_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single case by primary key, or 404."""
    case = await db.get(Case, case_id)
    if case is None:
        raise HTTPException(status_code=404, detail="Case not found")
    return _case_to_dict(case)
@router.put("/{case_id}", summary="Update a case")
async def update_case(case_id: str, body: CaseUpdate, db: AsyncSession = Depends(get_db)):
    """Apply a partial update to a case and log only the real changes.

    A field is applied and recorded in the activity log only when the new
    value differs from the stored one. The original recorded no-op updates
    too, producing noisy ``{"old": x, "new": x}`` activity entries and
    re-stamping ``resolved_at`` on every save of an already-resolved case.
    """
    case = await db.get(Case, case_id)
    if not case:
        raise HTTPException(status_code=404, detail="Case not found")
    changes: dict = {}
    for field in ["title", "description", "severity", "tlp", "pap", "status",
                  "priority", "assignee", "tags", "mitre_techniques", "iocs"]:
        val = getattr(body, field)
        if val is None:
            continue  # field omitted from the partial update
        old = getattr(case, field)
        if old == val:
            continue  # no-op: don't touch the row or pollute the activity log
        setattr(case, field, val)
        changes[field] = {"old": old, "new": val}
    # Stamp lifecycle timestamps on real status transitions only.
    new_status = changes.get("status", {}).get("new")
    if new_status == "in-progress" and not case.started_at:
        case.started_at = _utcnow()
    if new_status in ("resolved", "closed"):
        case.resolved_at = _utcnow()
    case.updated_at = _utcnow()
    await _log_activity(db, "case", case.id, "updated", changes)
    await db.commit()
    await db.refresh(case)
    return _case_to_dict(case)
@router.delete("/{case_id}", summary="Delete a case")
async def delete_case(case_id: str, db: AsyncSession = Depends(get_db)):
    """Hard-delete a case by id, or 404 if it does not exist."""
    case = await db.get(Case, case_id)
    if case is None:
        raise HTTPException(status_code=404, detail="Case not found")
    await db.delete(case)
    await db.commit()
    return {"deleted": True}
# ── Task CRUD ─────────────────────────────────────────────────────────
@router.post("/{case_id}/tasks", summary="Add task to case")
async def create_task(case_id: str, body: TaskCreate, db: AsyncSession = Depends(get_db)):
    """Create a task under an existing case and log a 'task_created' activity."""
    parent = await db.get(Case, case_id)
    if parent is None:
        raise HTTPException(status_code=404, detail="Case not found")
    ts = _utcnow()
    task = CaseTask(
        id=_new_id(),
        case_id=case_id,
        title=body.title,
        description=body.description,
        assignee=body.assignee,
        created_at=ts,
        updated_at=ts,
    )
    db.add(task)
    await _log_activity(db, "case", case_id, "task_created", {"title": body.title})
    await db.commit()
    await db.refresh(task)
    return _task_to_dict(task)
@router.put("/{case_id}/tasks/{task_id}", summary="Update a task")
async def update_task(case_id: str, task_id: str, body: TaskUpdate, db: AsyncSession = Depends(get_db)):
    """Apply the non-None fields of the body to a task belonging to the case."""
    task = await db.get(CaseTask, task_id)
    if task is None or task.case_id != case_id:
        raise HTTPException(status_code=404, detail="Task not found")
    for name in ("title", "description", "status", "assignee", "order"):
        new_val = getattr(body, name)
        if new_val is not None:
            setattr(task, name, new_val)
    task.updated_at = _utcnow()
    await _log_activity(db, "case", case_id, "task_updated", {"task_id": task_id})
    await db.commit()
    await db.refresh(task)
    return _task_to_dict(task)
@router.delete("/{case_id}/tasks/{task_id}", summary="Delete a task")
async def delete_task(case_id: str, task_id: str, db: AsyncSession = Depends(get_db)):
    """Hard-delete a task, verifying it belongs to the given case."""
    task = await db.get(CaseTask, task_id)
    if task is None or task.case_id != case_id:
        raise HTTPException(status_code=404, detail="Task not found")
    await db.delete(task)
    await db.commit()
    return {"deleted": True}
# ── Activity Log ──────────────────────────────────────────────────────
@router.get("/{case_id}/activity", summary="Get case activity log")
async def get_activity(
    case_id: str,
    limit: int = Query(50, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
):
    """Return the most recent activity entries for a case, newest first."""
    stmt = (
        select(ActivityLog)
        .where(ActivityLog.entity_type == "case", ActivityLog.entity_id == case_id)
        .order_by(desc(ActivityLog.created_at))
        .limit(limit)
    )
    entries = (await db.execute(stmt)).scalars().all()
    payload = [
        {
            "id": entry.id,
            "action": entry.action,
            "details": entry.details,
            "user_id": entry.user_id,
            "created_at": entry.created_at.isoformat() if entry.created_at else None,
        }
        for entry in entries
    ]
    return {"logs": payload}
# ── Helpers ───────────────────────────────────────────────────────────
def _case_to_dict(c: Case) -> dict:
    """Serialize a Case (including its tasks) to a JSON-safe dict."""
    def iso(dt):
        # Datetimes become ISO-8601 strings; missing values become None.
        return dt.isoformat() if dt else None

    return {
        "id": c.id,
        "title": c.title,
        "description": c.description,
        "severity": c.severity,
        "tlp": c.tlp,
        "pap": c.pap,
        "status": c.status,
        "priority": c.priority,
        "assignee": c.assignee,
        "tags": c.tags or [],
        "hunt_id": c.hunt_id,
        "owner_id": c.owner_id,
        "mitre_techniques": c.mitre_techniques or [],
        "iocs": c.iocs or [],
        "started_at": iso(c.started_at),
        "resolved_at": iso(c.resolved_at),
        "created_at": iso(c.created_at),
        "updated_at": iso(c.updated_at),
        "tasks": [_task_to_dict(t) for t in (c.tasks or [])],
    }
def _task_to_dict(t: CaseTask) -> dict:
    """Serialize a CaseTask row to a JSON-safe dict."""
    def iso(dt):
        return dt.isoformat() if dt else None

    return {
        "id": t.id,
        "case_id": t.case_id,
        "title": t.title,
        "description": t.description,
        "status": t.status,
        "assignee": t.assignee,
        "order": t.order,
        "created_at": iso(t.created_at),
        "updated_at": iso(t.updated_at),
    }

View File

@@ -398,3 +398,30 @@ async def delete_dataset(
raise HTTPException(status_code=404, detail="Dataset not found")
keyword_scan_cache.invalidate_dataset(dataset_id)
return {"message": "Dataset deleted", "id": dataset_id}
@router.post(
    "/rescan-ioc",
    summary="Re-scan IOC columns for all datasets",
)
async def rescan_ioc_columns(
    db: AsyncSession = Depends(get_db),
):
    """Re-run detect_ioc_columns on every dataset using current detection logic."""
    repo = DatasetRepository(db)
    # Effectively "all datasets" — assumes fewer than 10k rows; TODO confirm.
    all_ds = await repo.list_datasets(limit=10000)
    updated = 0
    for ds in all_ds:
        columns = list((ds.column_schema or {}).keys())
        if not columns:
            # Nothing to scan for datasets without a recorded schema.
            continue
        new_ioc = detect_ioc_columns(
            columns,
            ds.column_schema or {},
            ds.normalized_columns or {},
        )
        # Only write (and count) datasets whose detection result changed.
        if new_ioc != (ds.ioc_columns or {}):
            ds.ioc_columns = new_ioc
            updated += 1
    await db.commit()
    return {"message": f"Rescanned {len(all_ds)} datasets, updated {updated}"}

View File

@@ -1,20 +1,34 @@
<<<<<<< HEAD
"""Network topology API - host inventory endpoint with background caching."""
=======
"""API routes for Network Picture — deduplicated host inventory."""
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
<<<<<<< HEAD
from fastapi.responses import JSONResponse
=======
from pydantic import BaseModel, Field
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.db import get_db
<<<<<<< HEAD
from app.services.host_inventory import build_host_inventory, inventory_cache
from app.services.job_queue import job_queue, JobType
=======
from app.services.network_inventory import build_network_picture
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/network", tags=["network"])
<<<<<<< HEAD
@router.get("/host-inventory")
async def get_host_inventory(
hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
@@ -173,3 +187,58 @@ async def trigger_rebuild(
inventory_cache.invalidate(hunt_id)
job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return {"job_id": job.id, "status": "queued"}
=======
# ── Response models ───────────────────────────────────────────────────
class HostEntry(BaseModel):
hostname: str
ips: list[str] = Field(default_factory=list)
users: list[str] = Field(default_factory=list)
os: list[str] = Field(default_factory=list)
mac_addresses: list[str] = Field(default_factory=list)
protocols: list[str] = Field(default_factory=list)
open_ports: list[str] = Field(default_factory=list)
remote_targets: list[str] = Field(default_factory=list)
datasets: list[str] = Field(default_factory=list)
connection_count: int = 0
first_seen: str | None = None
last_seen: str | None = None
class PictureSummary(BaseModel):
total_hosts: int = 0
total_connections: int = 0
total_unique_ips: int = 0
datasets_scanned: int = 0
class NetworkPictureResponse(BaseModel):
hosts: list[HostEntry]
summary: PictureSummary
# ── Routes ────────────────────────────────────────────────────────────
@router.get(
"/picture",
response_model=NetworkPictureResponse,
summary="Build deduplicated host inventory for a hunt",
description=(
"Scans all datasets in the specified hunt, extracts host-identifying "
"fields (hostname, IP, username, OS, MAC, ports), deduplicates by "
"hostname, and returns a clean one-row-per-host network picture."
),
)
async def get_network_picture(
hunt_id: str = Query(..., description="Hunt ID to scan"),
db: AsyncSession = Depends(get_db),
):
"""Return a deduplicated network picture for a hunt."""
if not hunt_id:
raise HTTPException(status_code=400, detail="hunt_id is required")
result = await build_network_picture(db, hunt_id)
return result
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2

View File

@@ -0,0 +1,360 @@
"""API routes for investigation notebooks and playbooks."""
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Notebook, PlaybookRun, _new_id, _utcnow
from app.services.playbook import (
get_builtin_playbooks,
get_playbook_template,
validate_notebook_cells,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/notebooks", tags=["notebooks"])
# ── Pydantic models ──────────────────────────────────────────────────
class NotebookCreate(BaseModel):
    """Request body for creating an investigation notebook."""

    title: str
    description: Optional[str] = None
    cells: Optional[list[dict]] = None  # sanitized via validate_notebook_cells()
    hunt_id: Optional[str] = None
    case_id: Optional[str] = None
    tags: Optional[list[str]] = None
class NotebookUpdate(BaseModel):
    """Partial notebook update — only non-None fields are applied."""

    title: Optional[str] = None
    description: Optional[str] = None
    cells: Optional[list[dict]] = None  # full replacement of the cell list
    tags: Optional[list[str]] = None
class CellUpdate(BaseModel):
    """Update a single cell or add a new one."""

    cell_id: str  # target cell id; a new cell is appended when not found
    cell_type: Optional[str] = None  # defaults to "markdown" for new cells
    source: Optional[str] = None
    output: Optional[str] = None
    metadata: Optional[dict] = None
class PlaybookStart(BaseModel):
    """Request body for starting a run of a built-in playbook template."""

    playbook_name: str  # must match a template from get_builtin_playbooks()
    hunt_id: Optional[str] = None
    case_id: Optional[str] = None
    started_by: Optional[str] = None
class StepComplete(BaseModel):
    """Outcome of the current playbook step."""

    notes: Optional[str] = None
    status: str = "completed"  # completed | skipped
# ── Helpers ───────────────────────────────────────────────────────────
def _notebook_to_dict(nb: Notebook) -> dict:
    """Serialize a Notebook row to a JSON-safe dict (includes cell_count)."""
    def iso(dt):
        return dt.isoformat() if dt else None

    cells = nb.cells or []
    return {
        "id": nb.id,
        "title": nb.title,
        "description": nb.description,
        "cells": cells,
        "hunt_id": nb.hunt_id,
        "case_id": nb.case_id,
        "owner_id": nb.owner_id,
        "tags": nb.tags or [],
        "cell_count": len(cells),
        "created_at": iso(nb.created_at),
        "updated_at": iso(nb.updated_at),
    }
def _run_to_dict(run: PlaybookRun) -> dict:
    """Serialize a PlaybookRun row to a JSON-safe dict."""
    def iso(dt):
        return dt.isoformat() if dt else None

    return {
        "id": run.id,
        "playbook_name": run.playbook_name,
        "status": run.status,
        "current_step": run.current_step,
        "total_steps": run.total_steps,
        "step_results": run.step_results or [],
        "hunt_id": run.hunt_id,
        "case_id": run.case_id,
        "started_by": run.started_by,
        "created_at": iso(run.created_at),
        "updated_at": iso(run.updated_at),
        "completed_at": iso(run.completed_at),
    }
# ── Notebook CRUD ─────────────────────────────────────────────────────
@router.get("", summary="List notebooks")
async def list_notebooks(
    hunt_id: str | None = Query(None),
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List notebooks (optionally scoped to a hunt), newest-updated first."""
    base = select(Notebook)
    counter = select(func.count(Notebook.id))
    if hunt_id:
        base = base.where(Notebook.hunt_id == hunt_id)
        counter = counter.where(Notebook.hunt_id == hunt_id)
    total = (await db.execute(counter)).scalar() or 0
    page = await db.execute(
        base.order_by(desc(Notebook.updated_at)).offset(offset).limit(limit)
    )
    notebooks = page.scalars().all()
    return {"notebooks": [_notebook_to_dict(n) for n in notebooks], "total": total}
@router.get("/{notebook_id}", summary="Get notebook")
async def get_notebook(notebook_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single notebook by primary key, or 404."""
    nb = await db.get(Notebook, notebook_id)
    if nb is None:
        raise HTTPException(status_code=404, detail="Notebook not found")
    return _notebook_to_dict(nb)
@router.post("", summary="Create notebook")
async def create_notebook(body: NotebookCreate, db: AsyncSession = Depends(get_db)):
    """Create a notebook; seeds one starter markdown cell when none are given."""
    cells = validate_notebook_cells(body.cells or [])
    if not cells:
        cells = [{
            "id": "cell-0",
            "cell_type": "markdown",
            "source": "# Investigation Notes\n\nStart documenting your findings here.",
            "output": None,
            "metadata": {},
        }]
    nb = Notebook(
        id=_new_id(),
        title=body.title,
        description=body.description,
        cells=cells,
        hunt_id=body.hunt_id,
        case_id=body.case_id,
        tags=body.tags,
    )
    db.add(nb)
    await db.commit()
    await db.refresh(nb)
    return _notebook_to_dict(nb)
@router.put("/{notebook_id}", summary="Update notebook")
async def update_notebook(
    notebook_id: str, body: NotebookUpdate, db: AsyncSession = Depends(get_db)
):
    """Apply non-None fields to a notebook; cells are re-validated on replace."""
    nb = await db.get(Notebook, notebook_id)
    if nb is None:
        raise HTTPException(status_code=404, detail="Notebook not found")
    for attr, value in (
        ("title", body.title),
        ("description", body.description),
        ("tags", body.tags),
    ):
        if value is not None:
            setattr(nb, attr, value)
    if body.cells is not None:
        nb.cells = validate_notebook_cells(body.cells)
    await db.commit()
    await db.refresh(nb)
    return _notebook_to_dict(nb)
@router.post("/{notebook_id}/cells", summary="Add or update a cell")
async def upsert_cell(
    notebook_id: str, body: CellUpdate, db: AsyncSession = Depends(get_db)
):
    """Patch an existing cell in place, or append a new one if the id is unknown."""
    nb = await db.get(Notebook, notebook_id)
    if nb is None:
        raise HTTPException(status_code=404, detail="Notebook not found")
    cells = list(nb.cells or [])
    target = next((c for c in cells if c.get("id") == body.cell_id), None)
    if target is not None:
        # Apply only the fields the caller actually supplied.
        patch = {
            key: value
            for key, value in (
                ("cell_type", body.cell_type),
                ("source", body.source),
                ("output", body.output),
                ("metadata", body.metadata),
            )
            if value is not None
        }
        target.update(patch)
    else:
        cells.append({
            "id": body.cell_id,
            "cell_type": body.cell_type or "markdown",
            "source": body.source or "",
            "output": body.output,
            "metadata": body.metadata or {},
        })
    nb.cells = cells
    await db.commit()
    await db.refresh(nb)
    return _notebook_to_dict(nb)
@router.delete("/{notebook_id}/cells/{cell_id}", summary="Delete a cell")
async def delete_cell(
    notebook_id: str, cell_id: str, db: AsyncSession = Depends(get_db)
):
    """Remove one cell from a notebook by id (no error if the cell is absent)."""
    nb = await db.get(Notebook, notebook_id)
    if nb is None:
        raise HTTPException(status_code=404, detail="Notebook not found")
    remaining = [cell for cell in (nb.cells or []) if cell.get("id") != cell_id]
    nb.cells = remaining
    await db.commit()
    return {"ok": True, "remaining_cells": len(remaining)}
@router.delete("/{notebook_id}", summary="Delete notebook")
async def delete_notebook(notebook_id: str, db: AsyncSession = Depends(get_db)):
    """Hard-delete a notebook by id, or 404."""
    nb = await db.get(Notebook, notebook_id)
    if nb is None:
        raise HTTPException(status_code=404, detail="Notebook not found")
    await db.delete(nb)
    await db.commit()
    return {"ok": True}
# ── Playbooks ─────────────────────────────────────────────────────────
@router.get("/playbooks/templates", summary="List built-in playbook templates")
async def list_playbook_templates():
    """Summarize the built-in playbook templates (metadata only, no step bodies)."""
    summaries = []
    for tpl in get_builtin_playbooks():
        summaries.append({
            "name": tpl["name"],
            "description": tpl["description"],
            "category": tpl["category"],
            "tags": tpl["tags"],
            "step_count": len(tpl["steps"]),
        })
    return {"templates": summaries}
@router.get("/playbooks/templates/{name}", summary="Get playbook template detail")
async def get_playbook_template_detail(name: str):
    """Return the full template (including steps) for a named playbook, or 404."""
    template = get_playbook_template(name)
    if not template:
        raise HTTPException(status_code=404, detail="Playbook template not found")
    return template
@router.post("/playbooks/start", summary="Start a playbook run")
async def start_playbook(body: PlaybookStart, db: AsyncSession = Depends(get_db)):
    """Instantiate a run of a built-in playbook template, positioned at step 1."""
    template = get_playbook_template(body.playbook_name)
    if not template:
        raise HTTPException(status_code=404, detail="Playbook template not found")
    run = PlaybookRun(
        id=_new_id(),
        playbook_name=body.playbook_name,
        status="in-progress",
        current_step=1,
        total_steps=len(template["steps"]),
        step_results=[],
        hunt_id=body.hunt_id,
        case_id=body.case_id,
        started_by=body.started_by,
    )
    db.add(run)
    await db.commit()
    await db.refresh(run)
    return _run_to_dict(run)
@router.get("/playbooks/runs", summary="List playbook runs")
async def list_playbook_runs(
    status: str | None = Query(None),
    hunt_id: str | None = Query(None),
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List playbook runs, newest first, with optional status/hunt filters.

    Adds pagination (``offset``, default 0 — backward compatible) and a
    ``total`` count in the response, matching the shape of the notebook
    list endpoint; the original had neither, so clients could not page.
    """
    stmt = select(PlaybookRun)
    count_stmt = select(func.count(PlaybookRun.id))
    if status:
        stmt = stmt.where(PlaybookRun.status == status)
        count_stmt = count_stmt.where(PlaybookRun.status == status)
    if hunt_id:
        stmt = stmt.where(PlaybookRun.hunt_id == hunt_id)
        count_stmt = count_stmt.where(PlaybookRun.hunt_id == hunt_id)
    total = (await db.execute(count_stmt)).scalar() or 0
    results = (await db.execute(
        stmt.order_by(desc(PlaybookRun.created_at)).offset(offset).limit(limit)
    )).scalars().all()
    return {"runs": [_run_to_dict(r) for r in results], "total": total}
@router.get("/playbooks/runs/{run_id}", summary="Get playbook run detail")
async def get_playbook_run(run_id: str, db: AsyncSession = Depends(get_db)):
    """Return a run together with the step definitions from its template."""
    run = await db.get(PlaybookRun, run_id)
    if run is None:
        raise HTTPException(status_code=404, detail="Run not found")
    payload = _run_to_dict(run)
    template = get_playbook_template(run.playbook_name)
    # Template may have been removed/renamed since the run started.
    payload["steps"] = template["steps"] if template else []
    return payload
@router.post("/playbooks/runs/{run_id}/complete-step", summary="Complete current playbook step")
async def complete_step(
    run_id: str, body: StepComplete, db: AsyncSession = Depends(get_db)
):
    """Record the outcome of the current step and advance (or finish) the run.

    Raises 404 for an unknown run and 400 when the run is not in progress.
    """
    run = await db.get(PlaybookRun, run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")
    if run.status != "in-progress":
        raise HTTPException(status_code=400, detail="Run is not in progress")
    # Copy, append, then re-assign the whole list rather than mutating in place.
    step_results = list(run.step_results or [])
    step_results.append({
        "step": run.current_step,
        "status": body.status,  # "completed" | "skipped"
        "notes": body.notes,
        "completed_at": _utcnow().isoformat(),
    })
    run.step_results = step_results
    if run.current_step >= run.total_steps:
        # That was the final step — close out the run.
        run.status = "completed"
        run.completed_at = _utcnow()
    else:
        run.current_step += 1
    await db.commit()
    await db.refresh(run)
    return _run_to_dict(run)
@router.post("/playbooks/runs/{run_id}/abort", summary="Abort a playbook run")
async def abort_run(run_id: str, db: AsyncSession = Depends(get_db)):
    """Mark an in-progress run as aborted.

    Rejects runs already in a terminal state with 400 — the original would
    happily "abort" a completed run, silently overwriting its status and
    completed_at timestamp.
    """
    run = await db.get(PlaybookRun, run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")
    if run.status in ("completed", "aborted"):
        raise HTTPException(status_code=400, detail="Run is not in progress")
    run.status = "aborted"
    run.completed_at = _utcnow()
    await db.commit()
    return _run_to_dict(run)

View File

@@ -15,7 +15,7 @@ class AppConfig(BaseSettings):
# -- General --------------------------------------------------------
APP_NAME: str = "ThreatHunt"
APP_VERSION: str = "0.3.0"
APP_VERSION: str = "0.4.0"
DEBUG: bool = Field(default=False, description="Enable debug mode")
# -- Database -------------------------------------------------------

View File

@@ -75,8 +75,20 @@ async def get_db() -> AsyncSession: # type: ignore[misc]
async def init_db() -> None:
"""Create all tables (for dev / first-run). In production use Alembic."""
from sqlalchemy import inspect as sa_inspect
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Only create tables that don't already exist (safe alongside Alembic)
def _create_missing(sync_conn):
inspector = sa_inspect(sync_conn)
existing = set(inspector.get_table_names())
tables_to_create = [
t for t in Base.metadata.sorted_tables
if t.name not in existing
]
Base.metadata.create_all(sync_conn, tables=tables_to_create)
await conn.run_sync(_create_missing)
async def dispose_db() -> None:

View File

@@ -1,7 +1,7 @@
"""SQLAlchemy ORM models for ThreatHunt.
All persistent entities: datasets, hunts, conversations, annotations,
hypotheses, enrichment results, users, and AI analysis tables.
hypotheses, enrichment results, and users.
"""
import uuid
@@ -32,7 +32,8 @@ def _new_id() -> str:
return uuid.uuid4().hex
# -- Users ---
# ── Users ──────────────────────────────────────────────────────────────
class User(Base):
__tablename__ = "users"
@@ -41,16 +42,18 @@ class User(Base):
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
email: Mapped[str] = mapped_column(String(256), unique=True, nullable=False)
hashed_password: Mapped[str] = mapped_column(String(256), nullable=False)
role: Mapped[str] = mapped_column(String(16), default="analyst")
role: Mapped[str] = mapped_column(String(16), default="analyst") # analyst | admin | viewer
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
display_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
# relationships
hunts: Mapped[list["Hunt"]] = relationship(back_populates="owner", lazy="selectin")
annotations: Mapped[list["Annotation"]] = relationship(back_populates="author", lazy="selectin")
# -- Hunts ---
# ── Hunts ──────────────────────────────────────────────────────────────
class Hunt(Base):
__tablename__ = "hunts"
@@ -58,7 +61,7 @@ class Hunt(Base):
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
name: Mapped[str] = mapped_column(String(256), nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
status: Mapped[str] = mapped_column(String(32), default="active")
status: Mapped[str] = mapped_column(String(32), default="active") # active | closed | archived
owner_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("users.id"), nullable=True
)
@@ -67,15 +70,15 @@ class Hunt(Base):
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
# relationships
owner: Mapped[Optional["User"]] = relationship(back_populates="hunts", lazy="selectin")
datasets: Mapped[list["Dataset"]] = relationship(back_populates="hunt", lazy="selectin")
conversations: Mapped[list["Conversation"]] = relationship(back_populates="hunt", lazy="selectin")
hypotheses: Mapped[list["Hypothesis"]] = relationship(back_populates="hunt", lazy="selectin")
host_profiles: Mapped[list["HostProfile"]] = relationship(back_populates="hunt", lazy="noload")
reports: Mapped[list["HuntReport"]] = relationship(back_populates="hunt", lazy="noload")
# -- Datasets ---
# ── Datasets ───────────────────────────────────────────────────────────
class Dataset(Base):
__tablename__ = "datasets"
@@ -83,44 +86,36 @@ class Dataset(Base):
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
filename: Mapped[str] = mapped_column(String(512), nullable=False)
source_tool: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
source_tool: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # velociraptor, etc.
row_count: Mapped[int] = mapped_column(Integer, default=0)
column_schema: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
normalized_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
ioc_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
ioc_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) # auto-detected IOC columns
file_size_bytes: Mapped[int] = mapped_column(Integer, default=0)
encoding: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
delimiter: Mapped[Optional[str]] = mapped_column(String(4), nullable=True)
time_range_start: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
time_range_end: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
# New Phase 1-2 columns
processing_status: Mapped[str] = mapped_column(String(20), default="ready")
artifact_type: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
file_path: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
hunt_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("hunts.id"), nullable=True
)
uploaded_by: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
# relationships
hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="datasets", lazy="selectin")
rows: Mapped[list["DatasetRow"]] = relationship(
back_populates="dataset", lazy="noload", cascade="all, delete-orphan"
)
triage_results: Mapped[list["TriageResult"]] = relationship(
back_populates="dataset", lazy="noload", cascade="all, delete-orphan"
)
__table_args__ = (
Index("ix_datasets_hunt", "hunt_id"),
Index("ix_datasets_status", "processing_status"),
)
class DatasetRow(Base):
"""Individual row from a CSV dataset, stored as JSON blob."""
__tablename__ = "dataset_rows"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
@@ -131,6 +126,7 @@ class DatasetRow(Base):
data: Mapped[dict] = mapped_column(JSON, nullable=False)
normalized_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
# relationships
dataset: Mapped["Dataset"] = relationship(back_populates="rows")
annotations: Mapped[list["Annotation"]] = relationship(
back_populates="row", lazy="noload"
@@ -142,7 +138,8 @@ class DatasetRow(Base):
)
# -- Conversations ---
# ── Conversations ─────────────────────────────────────────────────────
class Conversation(Base):
__tablename__ = "conversations"
@@ -160,6 +157,7 @@ class Conversation(Base):
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
# relationships
hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="conversations", lazy="selectin")
messages: Mapped[list["Message"]] = relationship(
back_populates="conversation", lazy="selectin", cascade="all, delete-orphan",
@@ -174,15 +172,16 @@ class Message(Base):
conversation_id: Mapped[str] = mapped_column(
String(32), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False
)
role: Mapped[str] = mapped_column(String(16), nullable=False)
role: Mapped[str] = mapped_column(String(16), nullable=False) # user | agent | system
content: Mapped[str] = mapped_column(Text, nullable=False)
model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # wile | roadrunner | cluster
token_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
latency_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
response_meta: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
# relationships
conversation: Mapped["Conversation"] = relationship(back_populates="messages")
__table_args__ = (
@@ -190,7 +189,8 @@ class Message(Base):
)
# -- Annotations ---
# ── Annotations ───────────────────────────────────────────────────────
class Annotation(Base):
__tablename__ = "annotations"
@@ -206,14 +206,19 @@ class Annotation(Base):
String(32), ForeignKey("users.id"), nullable=True
)
text: Mapped[str] = mapped_column(Text, nullable=False)
severity: Mapped[str] = mapped_column(String(16), default="info")
tag: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
severity: Mapped[str] = mapped_column(
String(16), default="info"
) # info | low | medium | high | critical
tag: Mapped[Optional[str]] = mapped_column(
String(32), nullable=True
) # suspicious | benign | needs-review
highlight_color: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
# relationships
row: Mapped[Optional["DatasetRow"]] = relationship(back_populates="annotations")
author: Mapped[Optional["User"]] = relationship(back_populates="annotations")
@@ -223,7 +228,8 @@ class Annotation(Base):
)
# -- Hypotheses ---
# ── Hypotheses ────────────────────────────────────────────────────────
class Hypothesis(Base):
__tablename__ = "hypotheses"
@@ -235,7 +241,9 @@ class Hypothesis(Base):
title: Mapped[str] = mapped_column(String(256), nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
mitre_technique: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
status: Mapped[str] = mapped_column(String(16), default="draft")
status: Mapped[str] = mapped_column(
String(16), default="draft"
) # draft | active | confirmed | rejected
evidence_row_ids: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
evidence_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
@@ -243,6 +251,7 @@ class Hypothesis(Base):
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
# relationships
hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="hypotheses", lazy="selectin")
__table_args__ = (
@@ -250,16 +259,21 @@ class Hypothesis(Base):
)
# -- Enrichment Results ---
# ── Enrichment Results ────────────────────────────────────────────────
class EnrichmentResult(Base):
__tablename__ = "enrichment_results"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
ioc_value: Mapped[str] = mapped_column(String(512), nullable=False, index=True)
ioc_type: Mapped[str] = mapped_column(String(32), nullable=False)
source: Mapped[str] = mapped_column(String(32), nullable=False)
verdict: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
ioc_type: Mapped[str] = mapped_column(
String(32), nullable=False
) # ip | hash_md5 | hash_sha1 | hash_sha256 | domain | url
source: Mapped[str] = mapped_column(String(32), nullable=False) # virustotal | abuseipdb | shodan | ai
verdict: Mapped[Optional[str]] = mapped_column(
String(16), nullable=True
) # clean | suspicious | malicious | unknown
confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
raw_result: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
@@ -274,24 +288,28 @@ class EnrichmentResult(Base):
)
# -- AUP Keyword Themes & Keywords ---
# ── AUP Keyword Themes & Keywords ────────────────────────────────────
class KeywordTheme(Base):
"""A named category of keywords for AUP scanning (e.g. gambling, gaming)."""
__tablename__ = "keyword_themes"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
name: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
color: Mapped[str] = mapped_column(String(16), default="#9e9e9e")
color: Mapped[str] = mapped_column(String(16), default="#9e9e9e") # hex chip color
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
is_builtin: Mapped[bool] = mapped_column(Boolean, default=False)
is_builtin: Mapped[bool] = mapped_column(Boolean, default=False) # seed-provided
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
# relationships
keywords: Mapped[list["Keyword"]] = relationship(
back_populates="theme", lazy="selectin", cascade="all, delete-orphan"
)
class Keyword(Base):
"""Individual keyword / pattern belonging to a theme."""
__tablename__ = "keywords"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
@@ -302,6 +320,7 @@ class Keyword(Base):
is_regex: Mapped[bool] = mapped_column(Boolean, default=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
# relationships
theme: Mapped["KeywordTheme"] = relationship(back_populates="keywords")
__table_args__ = (
@@ -310,91 +329,223 @@ class Keyword(Base):
)
# -- AI Analysis Tables (Phase 2) ---
# ── Cases ─────────────────────────────────────────────────────────────
class TriageResult(Base):
__tablename__ = "triage_results"
class Case(Base):
"""Incident / investigation case, inspired by TheHive."""
__tablename__ = "cases"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
dataset_id: Mapped[str] = mapped_column(
String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True
title: Mapped[str] = mapped_column(String(512), nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
severity: Mapped[str] = mapped_column(String(16), default="medium") # info|low|medium|high|critical
tlp: Mapped[str] = mapped_column(String(16), default="amber") # white|green|amber|red
pap: Mapped[str] = mapped_column(String(16), default="amber") # white|green|amber|red
status: Mapped[str] = mapped_column(String(24), default="open") # open|in-progress|resolved|closed
priority: Mapped[int] = mapped_column(Integer, default=2) # 1(urgent)..4(low)
assignee: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
tags: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
hunt_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("hunts.id"), nullable=True
)
row_start: Mapped[int] = mapped_column(Integer, nullable=False)
row_end: Mapped[int] = mapped_column(Integer, nullable=False)
risk_score: Mapped[float] = mapped_column(Float, default=0.0)
verdict: Mapped[str] = mapped_column(String(20), default="pending")
findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
suspicious_indicators: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
mitre_techniques: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
dataset: Mapped["Dataset"] = relationship(back_populates="triage_results")
class HostProfile(Base):
__tablename__ = "host_profiles"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
hunt_id: Mapped[str] = mapped_column(
String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True
owner_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("users.id"), nullable=True
)
hostname: Mapped[str] = mapped_column(String(256), nullable=False)
fqdn: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
client_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
risk_score: Mapped[float] = mapped_column(Float, default=0.0)
risk_level: Mapped[str] = mapped_column(String(20), default="unknown")
artifact_summary: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
timeline_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
suspicious_findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
mitre_techniques: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
llm_analysis: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
iocs: Mapped[Optional[list]] = mapped_column(JSON, nullable=True) # [{type, value, description}]
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
resolved_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
hunt: Mapped["Hunt"] = relationship(back_populates="host_profiles")
# relationships
tasks: Mapped[list["CaseTask"]] = relationship(
back_populates="case", lazy="selectin", cascade="all, delete-orphan"
)
__table_args__ = (
Index("ix_cases_hunt", "hunt_id"),
Index("ix_cases_status", "status"),
)
class HuntReport(Base):
__tablename__ = "hunt_reports"
class CaseTask(Base):
"""Task within a case (Kanban board item)."""
__tablename__ = "case_tasks"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
hunt_id: Mapped[str] = mapped_column(
String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True
case_id: Mapped[str] = mapped_column(
String(32), ForeignKey("cases.id", ondelete="CASCADE"), nullable=False
)
status: Mapped[str] = mapped_column(String(20), default="pending")
exec_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
full_report: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
recommendations: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
mitre_mapping: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
ioc_table: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
host_risk_summary: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
models_used: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
generation_time_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
title: Mapped[str] = mapped_column(String(512), nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
status: Mapped[str] = mapped_column(String(24), default="todo") # todo|in-progress|done
assignee: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
order: Mapped[int] = mapped_column(Integer, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
hunt: Mapped["Hunt"] = relationship(back_populates="reports")
# relationships
case: Mapped["Case"] = relationship(back_populates="tasks")
__table_args__ = (
Index("ix_case_tasks_case", "case_id"),
)
class AnomalyResult(Base):
__tablename__ = "anomaly_results"
# ── Activity Log ──────────────────────────────────────────────────────
class ActivityLog(Base):
    """Audit trail / activity log for cases and hunts."""
    __tablename__ = "activity_logs"
    # Auto-increment integer PK — this is an append-only, high-volume table,
    # so a sequential int is preferred over the random string ids used elsewhere.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    entity_type: Mapped[str] = mapped_column(String(32), nullable=False)  # case|hunt|annotation
    entity_id: Mapped[str] = mapped_column(String(32), nullable=False)
    action: Mapped[str] = mapped_column(String(64), nullable=False)  # created|updated|status_changed|etc
    # Free-form JSON payload describing the action (diffs, old/new values, ...).
    details: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    # Actor who performed the action, when known; not a FK so logs survive user deletion.
    user_id: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    # Composite index: lookups are always "all activity for one entity".
    __table_args__ = (
        Index("ix_activity_entity", "entity_type", "entity_id"),
    )
# ── Alerts ────────────────────────────────────────────────────────────
class Alert(Base):
"""Security alert generated by analyzers or rules."""
__tablename__ = "alerts"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
dataset_id: Mapped[str] = mapped_column(
String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True
title: Mapped[str] = mapped_column(String(512), nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
severity: Mapped[str] = mapped_column(String(16), default="medium") # critical|high|medium|low|info
status: Mapped[str] = mapped_column(String(24), default="new") # new|acknowledged|in-progress|resolved|false-positive
analyzer: Mapped[str] = mapped_column(String(64), nullable=False) # which analyzer produced it
score: Mapped[float] = mapped_column(Float, default=0.0)
evidence: Mapped[Optional[list]] = mapped_column(JSON, nullable=True) # [{row_index, field, value, ...}]
mitre_technique: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
tags: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
hunt_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("hunts.id"), nullable=True
)
row_id: Mapped[Optional[int]] = mapped_column(
Integer, ForeignKey("dataset_rows.id", ondelete="CASCADE"), nullable=True
dataset_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("datasets.id"), nullable=True
)
case_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("cases.id"), nullable=True
)
assignee: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
acknowledged_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
resolved_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
__table_args__ = (
Index("ix_alerts_severity", "severity"),
Index("ix_alerts_status", "status"),
Index("ix_alerts_hunt", "hunt_id"),
Index("ix_alerts_dataset", "dataset_id"),
)
class AlertRule(Base):
    """User-defined alert rule (triggers analyzers automatically on upload)."""
    __tablename__ = "alert_rules"
    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    name: Mapped[str] = mapped_column(String(256), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    analyzer: Mapped[str] = mapped_column(String(64), nullable=False)  # analyzer name
    config: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # analyzer config overrides
    # When set, replaces the severity the analyzer itself assigns.
    severity_override: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
    # Disabled rules are kept but skipped when triggering analyzers.
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )  # None = global
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )
    __table_args__ = (
        Index("ix_alert_rules_analyzer", "analyzer"),
    )
# ── Notebooks ────────────────────────────────────────────────────────
class Notebook(Base):
    """Investigation notebook — cell-based document for analyst notes and queries."""
    __tablename__ = "notebooks"
    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    title: Mapped[str] = mapped_column(String(512), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    cells: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)  # [{id, cell_type, source, output, metadata}]
    # Optional links to a hunt / case / owning user — all nullable so a
    # notebook can also exist standalone.
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    case_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("cases.id"), nullable=True
    )
    owner_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("users.id"), nullable=True
    )
    tags: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )
    __table_args__ = (
        Index("ix_notebooks_hunt", "hunt_id"),
    )
# ── Playbook Runs ────────────────────────────────────────────────────
class PlaybookRun(Base):
    """Record of a playbook execution (links a template to a hunt/case)."""
    __tablename__ = "playbook_runs"
    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    # Playbooks are referenced by name (template lives outside the DB).
    playbook_name: Mapped[str] = mapped_column(String(256), nullable=False)
    status: Mapped[str] = mapped_column(String(24), default="in-progress")  # in-progress | completed | aborted
    # 1-based pointer into the playbook's step list.
    current_step: Mapped[int] = mapped_column(Integer, default=1)
    total_steps: Mapped[int] = mapped_column(Integer, default=0)
    step_results: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)  # [{step, status, notes, completed_at}]
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    case_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("cases.id"), nullable=True
    )
    # Display name of whoever started the run; free text, not a user FK.
    started_by: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )
    completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    __table_args__ = (
        Index("ix_playbook_runs_hunt", "hunt_id"),
        Index("ix_playbook_runs_status", "status"),
    )
<<<<<<< HEAD
anomaly_score: Mapped[float] = mapped_column(Float, default=0.0)
distance_from_centroid: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
cluster_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
@@ -505,3 +656,5 @@ class SavedSearch(Base):
__table_args__ = (
Index("ix_saved_searches_type", "search_type"),
)
=======
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2

View File

@@ -1,12 +1,10 @@
"""ThreatHunt backend application.
Wires together: database, CORS, agent routes, dataset routes, hunt routes,
annotation/hypothesis routes, analysis routes, network routes, job queue,
load balancer. DB tables are auto-created on startup.
annotation/hypothesis routes. DB tables are auto-created on startup.
"""
import logging
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
@@ -23,13 +21,19 @@ from app.api.routes.correlation import router as correlation_router
from app.api.routes.reports import router as reports_router
from app.api.routes.auth import router as auth_router
from app.api.routes.keywords import router as keywords_router
from app.api.routes.analysis import router as analysis_router
from app.api.routes.network import router as network_router
<<<<<<< HEAD
from app.api.routes.mitre import router as mitre_router
from app.api.routes.timeline import router as timeline_router
from app.api.routes.playbooks import router as playbooks_router
from app.api.routes.saved_searches import router as searches_router
from app.api.routes.stix_export import router as stix_router
=======
from app.api.routes.analysis import router as analysis_router
from app.api.routes.cases import router as cases_router
from app.api.routes.alerts import router as alerts_router
from app.api.routes.notebooks import router as notebooks_router
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
logger = logging.getLogger(__name__)
@@ -37,20 +41,16 @@ logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Startup / shutdown lifecycle."""
logger.info("Starting ThreatHunt API ...")
logger.info("Starting ThreatHunt API ")
await init_db()
logger.info("Database initialised")
# Ensure uploads directory exists
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
logger.info("Upload dir: %s", os.path.abspath(settings.UPLOAD_DIR))
# Seed default AUP keyword themes
from app.db import async_session_factory
from app.services.keyword_defaults import seed_defaults
async with async_session_factory() as seed_db:
await seed_defaults(seed_db)
logger.info("AUP keyword defaults checked")
<<<<<<< HEAD
# Start job queue
from app.services.job_queue import (
@@ -141,6 +141,10 @@ async def lifespan(app: FastAPI):
await _lb.stop_health_loop()
logger.info("Load balancer stopped")
=======
yield
logger.info("Shutting down …")
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
from app.agents.providers_v2 import cleanup_client
from app.services.enrichment import enrichment_engine
await cleanup_client()
@@ -148,13 +152,15 @@ async def lifespan(app: FastAPI):
await dispose_db()
# Create FastAPI application
app = FastAPI(
title="ThreatHunt API",
description="Analyst-assist threat hunting platform powered by Wile & Roadrunner LLM cluster",
version=settings.APP_VERSION,
version="0.3.0",
lifespan=lifespan,
)
# Configure CORS
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
@@ -174,17 +180,24 @@ app.include_router(enrichment_router)
app.include_router(correlation_router)
app.include_router(reports_router)
app.include_router(keywords_router)
app.include_router(analysis_router)
app.include_router(network_router)
<<<<<<< HEAD
app.include_router(mitre_router)
app.include_router(timeline_router)
app.include_router(playbooks_router)
app.include_router(searches_router)
app.include_router(stix_router)
=======
app.include_router(analysis_router)
app.include_router(cases_router)
app.include_router(alerts_router)
app.include_router(notebooks_router)
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
@app.get("/", tags=["health"])
async def root():
"""API health check."""
return {
"service": "ThreatHunt API",
"version": settings.APP_VERSION,
@@ -196,6 +209,7 @@ async def root():
"openwebui": settings.OPENWEBUI_URL,
},
}
<<<<<<< HEAD
@app.get("/health", tags=["health"])
@@ -205,3 +219,5 @@ async def health():
"version": settings.APP_VERSION,
"status": "ok",
}
=======
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2

View File

@@ -0,0 +1,464 @@
"""Pluggable Analyzer Framework for ThreatHunt.
Each analyzer implements a simple protocol:
- name / description properties
- async analyze(rows, config) -> list[AlertCandidate]
The AnalyzerRegistry discovers and runs all enabled analyzers against
a dataset, producing alert candidates that the alert system can persist.
"""
from __future__ import annotations
import logging
import math
import re
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from typing import Any, Optional, Sequence
logger = logging.getLogger(__name__)
# ── Alert Candidate DTO ──────────────────────────────────────────────
@dataclass
class AlertCandidate:
    """A single finding from an analyzer, before it becomes a persisted Alert.

    Returned by ``BaseAnalyzer.analyze``; the alerts API decides whether to
    persist each candidate as an ``Alert`` row.
    """
    analyzer: str  # name of the analyzer that produced this finding
    title: str  # short human-readable headline
    severity: str  # critical | high | medium | low | info
    description: str  # longer explanation including row/field context
    evidence: list[dict] = field(default_factory=list)  # [{row_index, field, value, ...}]
    mitre_technique: Optional[str] = None  # ATT&CK technique id, e.g. "T1027"
    tags: list[str] = field(default_factory=list)  # free-form labels for filtering
    score: float = 0.0  # 0-100
# ── Base Analyzer ────────────────────────────────────────────────────
class BaseAnalyzer(ABC):
    """Interface every analyzer must implement.

    Concrete analyzers in this module satisfy ``name``/``description`` with
    plain class-level string attributes, which is sufficient to make the
    class instantiable.
    """
    @property
    @abstractmethod
    def name(self) -> str: ...  # unique registry key, e.g. "entropy"
    @property
    @abstractmethod
    def description(self) -> str: ...  # one-line human-readable summary
    @abstractmethod
    async def analyze(
        self, rows: list[dict[str, Any]], config: dict[str, Any] | None = None
    ) -> list[AlertCandidate]: ...  # inspect rows, return zero or more findings
# ── Built-in Analyzers ──────────────────────────────────────────────
class EntropyAnalyzer(BaseAnalyzer):
    """Detects high-entropy strings (encoded payloads, obfuscated commands)."""

    name = "entropy"
    description = "Flags fields with high Shannon entropy (possible encoding/obfuscation)"

    # Fields commonly carrying command lines, paths, or queries worth scoring.
    ENTROPY_FIELDS = [
        "command_line", "commandline", "process_command_line", "cmdline",
        "powershell_command", "script_block", "url", "uri", "path",
        "file_path", "target_filename", "query", "dns_query",
    ]
    DEFAULT_THRESHOLD = 4.5

    @staticmethod
    def _shannon(s: str) -> float:
        # Shannon entropy in bits/char; strings under 8 chars are too short
        # to score meaningfully, so they get 0.
        if not s or len(s) < 8:
            return 0.0
        total = len(s)
        acc = 0.0
        for count in Counter(s).values():
            p = count / total
            acc -= p * math.log2(p)
        return acc

    async def analyze(self, rows, config=None):
        """Score each candidate field per row; emit an alert when the entropy
        meets the configured threshold (``entropy_threshold``/``min_length``)."""
        cfg = config or {}
        threshold = cfg.get("entropy_threshold", self.DEFAULT_THRESHOLD)
        min_length = cfg.get("min_length", 20)
        findings: list[AlertCandidate] = []
        for idx, row in enumerate(rows):
            for field_name in self.ENTROPY_FIELDS:
                val = str(row.get(field_name, ""))
                if len(val) < min_length:
                    continue
                ent = self._shannon(val)
                if ent < threshold:
                    continue
                # Severity scales with how far past the threshold we are.
                if ent > 5.5:
                    sev = "critical"
                elif ent > 5.0:
                    sev = "high"
                else:
                    sev = "medium"
                findings.append(AlertCandidate(
                    analyzer=self.name,
                    title=f"High-entropy string in {field_name}",
                    severity=sev,
                    description=f"Shannon entropy {ent:.2f} (threshold {threshold}) in row {idx}, field '{field_name}'",
                    evidence=[{"row_index": idx, "field": field_name, "value": val[:200], "entropy": round(ent, 3)}],
                    mitre_technique="T1027",  # Obfuscated Files or Information
                    tags=["obfuscation", "entropy"],
                    score=min(100, ent * 18),
                ))
        return findings
class SuspiciousCommandAnalyzer(BaseAnalyzer):
    """Detects known-bad command patterns (credential dumping, lateral movement, persistence)."""

    name = "suspicious_commands"
    description = "Flags processes executing known-suspicious command patterns"

    PATTERNS: list[tuple[str, str, str, str]] = [
        # (regex, title, severity, mitre_technique)
        (r"mimikatz|sekurlsa|lsadump|kerberos::list", "Mimikatz / Credential Dumping", "critical", "T1003"),
        (r"(?i)-enc\s+[A-Za-z0-9+/=]{40,}", "Encoded PowerShell command", "high", "T1059.001"),
        (r"(?i)invoke-(mimikatz|expression|webrequest|shellcode)", "Suspicious PowerShell Invoke", "high", "T1059.001"),
        (r"(?i)net\s+(user|localgroup|group)\s+/add", "Local account creation", "high", "T1136.001"),
        (r"(?i)schtasks\s+/create", "Scheduled task creation", "medium", "T1053.005"),
        (r"(?i)reg\s+add\s+.*\\run", "Registry Run key persistence", "high", "T1547.001"),
        (r"(?i)wmic\s+.*(process\s+call|shadowcopy\s+delete)", "WMI abuse / shadow copy deletion", "critical", "T1047"),
        (r"(?i)psexec|winrm|wmic\s+/node:", "Lateral movement tool", "high", "T1021"),
        (r"(?i)certutil\s+-urlcache", "Certutil download (LOLBin)", "high", "T1105"),
        (r"(?i)bitsadmin\s+/transfer", "BITSAdmin download", "medium", "T1197"),
        (r"(?i)vssadmin\s+delete\s+shadows", "VSS shadow deletion (ransomware)", "critical", "T1490"),
        (r"(?i)bcdedit.*recoveryenabled.*no", "Boot config tamper (ransomware)", "critical", "T1490"),
        (r"(?i)attrib\s+\+h\s+\+s", "Hidden file attribute set", "low", "T1564.001"),
        (r"(?i)netsh\s+advfirewall\s+.*disable", "Firewall disabled", "high", "T1562.004"),
        (r"(?i)whoami\s*/priv", "Privilege enumeration", "medium", "T1033"),
        (r"(?i)nltest\s+/dclist", "Domain controller enumeration", "medium", "T1018"),
        (r"(?i)dsquery|ldapsearch|adfind", "Active Directory enumeration", "medium", "T1087.002"),
        (r"(?i)procdump.*-ma\s+lsass", "LSASS memory dump", "critical", "T1003.001"),
        (r"(?i)rundll32.*comsvcs.*MiniDump", "LSASS dump via comsvcs", "critical", "T1003.001"),
    ]

    # Fields commonly carrying the command line, depending on log source.
    CMD_FIELDS = [
        "command_line", "commandline", "process_command_line", "cmdline",
        "parent_command_line", "powershell_command",
    ]

    # Compiled once at class-definition time — the previous implementation
    # recompiled all patterns on every analyze() call. IGNORECASE is kept as a
    # flag because the first pattern has no inline (?i).
    _COMPILED = [(re.compile(p, re.IGNORECASE), t, s, m) for p, t, s, m in PATTERNS]

    # Numeric score assigned per severity level (50 if severity is unknown).
    _SEVERITY_SCORE = {"critical": 95, "high": 80, "medium": 60, "low": 30}

    async def analyze(self, rows, config=None):
        """Match each command-line field of each row against every pattern.

        Returns one AlertCandidate per (row, field, pattern) match; ``config``
        is accepted for interface compatibility but unused.
        """
        alerts: list[AlertCandidate] = []
        for idx, row in enumerate(rows):
            for fld in self.CMD_FIELDS:
                val = str(row.get(fld, ""))
                if len(val) < 3:  # too short to be a meaningful command
                    continue
                for pattern, title, sev, mitre in self._COMPILED:
                    if pattern.search(val):
                        alerts.append(AlertCandidate(
                            analyzer=self.name,
                            title=title,
                            severity=sev,
                            description=f"Suspicious command pattern in row {idx}: {val[:200]}",
                            evidence=[{"row_index": idx, "field": fld, "value": val[:300]}],
                            mitre_technique=mitre,
                            tags=["command", "suspicious"],
                            score=self._SEVERITY_SCORE.get(sev, 50),
                        ))
        return alerts
class NetworkAnomalyAnalyzer(BaseAnalyzer):
    """Detects anomalous network patterns (beaconing, unusual ports, large transfers)."""

    name = "network_anomaly"
    description = "Flags anomalous network behavior (beaconing, unusual ports, large transfers)"

    # Ports commonly used by off-the-shelf C2 frameworks and backdoors.
    SUSPICIOUS_PORTS = {4444, 5555, 6666, 8888, 9999, 1234, 31337, 12345, 54321, 1337}
    C2_PORTS = {443, 8443, 8080, 4443, 9443}

    async def analyze(self, rows, config=None):
        """Single pass over rows, then three alert families: large transfers
        (emitted inline), beaconing, and suspicious-port connections."""
        cfg = config or {}
        xfer_limit = cfg.get("large_transfer_threshold", 10_000_000)
        findings: list[AlertCandidate] = []
        contacts_by_ip: dict[str, list[int]] = defaultdict(list)  # dst ip -> row indices
        odd_ports: list[tuple[int, str, int]] = []  # (row index, dst ip, port)

        for i, record in enumerate(rows):
            dest = str(record.get("dst_ip", record.get("destination_ip", record.get("dest_ip", ""))))
            port_raw = record.get("dst_port", record.get("destination_port", record.get("dest_port", "")))
            if dest:
                contacts_by_ip[dest].append(i)
            if port_raw:
                try:
                    port = int(port_raw)
                except (ValueError, TypeError):
                    pass
                else:
                    if port in self.SUSPICIOUS_PORTS:
                        odd_ports.append((i, dest, port))
            # Large outbound transfer check (field name varies by source).
            sent = record.get("bytes_sent", record.get("bytes_out", record.get("sent_bytes", 0)))
            try:
                too_big = int(sent or 0) > xfer_limit
            except (ValueError, TypeError):
                too_big = False
            if too_big:
                findings.append(AlertCandidate(
                    analyzer=self.name,
                    title="Large data transfer detected",
                    severity="medium",
                    description=f"Row {i}: {sent} bytes sent to {dest}",
                    evidence=[{"row_index": i, "dst_ip": dest, "bytes": str(sent)}],
                    mitre_technique="T1048",
                    tags=["exfiltration", "network"],
                    score=65,
                ))

        # Beaconing: any destination contacted at least `beacon_threshold` times.
        beacon_min = cfg.get("beacon_threshold", 20)
        for ip, hits in contacts_by_ip.items():
            if len(hits) < beacon_min:
                continue
            findings.append(AlertCandidate(
                analyzer=self.name,
                title=f"Possible beaconing to {ip}",
                severity="high",
                description=f"Destination {ip} contacted {len(hits)} times (threshold: {beacon_min})",
                evidence=[{"dst_ip": ip, "contact_count": len(hits), "sample_rows": hits[:10]}],
                mitre_technique="T1071",
                tags=["beaconing", "c2", "network"],
                score=min(95, 50 + len(hits)),
            ))

        # One alert per connection seen on a known-suspicious port.
        for i, ip, port in odd_ports:
            findings.append(AlertCandidate(
                analyzer=self.name,
                title=f"Connection on suspicious port {port}",
                severity="medium",
                description=f"Row {i}: connection to {ip}:{port}",
                evidence=[{"row_index": i, "dst_ip": ip, "dst_port": port}],
                mitre_technique="T1571",
                tags=["suspicious_port", "network"],
                score=55,
            ))
        return findings
class FrequencyAnomalyAnalyzer(BaseAnalyzer):
    """Detects statistically rare values that may indicate anomalies."""

    name = "frequency_anomaly"
    description = "Flags statistically rare field values (potential anomalies)"

    # Categorical fields where a rare value is interesting.
    FIELDS_TO_CHECK = [
        "process_name", "image_name", "parent_process_name",
        "user", "username", "user_name",
        "event_type", "action", "status",
    ]

    async def analyze(self, rows, config=None):
        """Flag values occurring in <= ``rarity_threshold`` of rows (and at
        most 3 times absolute). Skips datasets smaller than ``min_rows``.
        """
        config = config or {}
        rarity_threshold = config.get("rarity_threshold", 0.01)  # <1% occurrence
        min_rows = config.get("min_rows", 50)
        alerts: list[AlertCandidate] = []
        if len(rows) < min_rows:
            return alerts
        for fld in self.FIELDS_TO_CHECK:
            # Single pass: count truthy values AND remember row indices per
            # stringified value. The previous version re-scanned every row for
            # each rare value, which was O(values x rows).
            counts: Counter = Counter()
            index_map: dict[str, list[int]] = defaultdict(list)
            for i, row in enumerate(rows):
                raw = row.get(fld)
                if raw:
                    counts[str(raw)] += 1
                # Matches the old index search, which stringified with "" default.
                index_map[str(row.get(fld, ""))].append(i)
            total = sum(counts.values())
            if not total:
                continue
            for val, cnt in counts.items():
                pct = cnt / total
                if pct <= rarity_threshold and cnt <= 3:
                    indices = index_map.get(val, [])
                    alerts.append(AlertCandidate(
                        analyzer=self.name,
                        title=f"Rare {fld}: {val[:80]}",
                        severity="low",
                        description=f"'{val}' appears {cnt}/{total} times ({pct:.2%}) in field '{fld}'",
                        evidence=[{"field": fld, "value": val[:200], "count": cnt, "total": total, "rows": indices[:5]}],
                        tags=["anomaly", "rare"],
                        score=max(20, 50 - (pct * 5000)),
                    ))
        return alerts
class AuthAnomalyAnalyzer(BaseAnalyzer):
    """Detects authentication anomalies (brute force, unusual logon types)."""
    name = "auth_anomaly"
    description = "Flags authentication anomalies (failed logins, unusual logon types)"
    async def analyze(self, rows, config=None):
        # Pass 1: bucket failed logins per user and remote logon types per row,
        # then emit aggregate alerts from the buckets.
        config = config or {}
        alerts: list[AlertCandidate] = []
        # Track failed logins per user
        failed_by_user: dict[str, list[int]] = defaultdict(list)
        logon_types: dict[str, list[int]] = defaultdict(list)
        for idx, row in enumerate(rows):
            # Field names vary by log source; try common aliases.
            event_type = str(row.get("event_type", row.get("action", ""))).lower()
            status = str(row.get("status", row.get("result", ""))).lower()
            user = str(row.get("username", row.get("user", row.get("user_name", ""))))
            logon_type = str(row.get("logon_type", ""))
            if "logon" in event_type or "auth" in event_type or "login" in event_type:
                # Windows event id 4625 = "An account failed to log on".
                if "fail" in status or "4625" in str(row.get("event_id", "")):
                    if user:
                        failed_by_user[user].append(idx)
                if logon_type in ("3", "10"):  # Network/RemoteInteractive
                    logon_types[logon_type].append(idx)
        # Brute force: >5 failed logins for same user
        brute_thresh = config.get("brute_force_threshold", 5)
        for user, indices in failed_by_user.items():
            if len(indices) >= brute_thresh:
                alerts.append(AlertCandidate(
                    analyzer=self.name,
                    title=f"Possible brute force: {user}",
                    severity="high",
                    description=f"User '{user}' had {len(indices)} failed logins",
                    evidence=[{"user": user, "failed_count": len(indices), "rows": indices[:10]}],
                    mitre_technique="T1110",
                    tags=["brute_force", "authentication"],
                    score=min(90, 50 + len(indices) * 3),
                ))
        # Unusual logon types
        for ltype, indices in logon_types.items():
            label = "Network logon (Type 3)" if ltype == "3" else "Remote Desktop (Type 10)"
            # Require at least 3 occurrences to cut noise from one-off remote logons.
            if len(indices) >= 3:
                alerts.append(AlertCandidate(
                    analyzer=self.name,
                    title=f"{label} detected",
                    severity="medium" if ltype == "3" else "high",
                    description=f"{len(indices)} {label} events detected",
                    evidence=[{"logon_type": ltype, "count": len(indices), "rows": indices[:10]}],
                    mitre_technique="T1021",
                    tags=["authentication", "lateral_movement"],
                    score=55 if ltype == "3" else 70,
                ))
        return alerts
class PersistenceAnalyzer(BaseAnalyzer):
    """Detects persistence mechanisms (registry keys, services, scheduled tasks)."""

    name = "persistence"
    description = "Flags persistence mechanism installations"

    # (regex, title, mitre technique) — each pattern carries inline (?i).
    REGISTRY_PATTERNS = [
        (r"(?i)\\CurrentVersion\\Run", "Run key persistence", "T1547.001"),
        (r"(?i)\\Services\\", "Service installation", "T1543.003"),
        (r"(?i)\\Winlogon\\", "Winlogon persistence", "T1547.004"),
        (r"(?i)\\Image File Execution Options\\", "IFEO debugger persistence", "T1546.012"),
        (r"(?i)\\Explorer\\Shell Folders", "Shell folder hijack", "T1547.001"),
    ]

    async def analyze(self, rows, config=None):
        """Flag registry-based persistence paths and service-creation events."""
        matchers = [(re.compile(rx), label, technique) for rx, label, technique in self.REGISTRY_PATTERNS]
        findings: list[AlertCandidate] = []
        for row_idx, record in enumerate(rows):
            # Registry write target — field name varies by log source.
            key_path = str(record.get("registry_key", record.get("target_object", record.get("registry_path", ""))))
            for matcher, label, technique in matchers:
                if not matcher.search(key_path):
                    continue
                findings.append(AlertCandidate(
                    analyzer=self.name,
                    title=label,
                    severity="high",
                    description=f"Row {row_idx}: {key_path[:200]}",
                    evidence=[{"row_index": row_idx, "registry_key": key_path[:300]}],
                    mitre_technique=technique,
                    tags=["persistence", "registry"],
                    score=75,
                ))
            # Service-creation events (e.g. "service_created", "ServiceCreate").
            kind = str(record.get("event_type", "")).lower()
            if "service" in kind and "creat" in kind:
                svc_name = record.get("service_name", record.get("target_filename", "unknown"))
                findings.append(AlertCandidate(
                    analyzer=self.name,
                    title=f"Service created: {svc_name}",
                    severity="medium",
                    description=f"Row {row_idx}: New service '{svc_name}' created",
                    evidence=[{"row_index": row_idx, "service_name": str(svc_name)}],
                    mitre_technique="T1543.003",
                    tags=["persistence", "service"],
                    score=60,
                ))
        return findings
# ── Analyzer Registry ────────────────────────────────────────────────
# Singleton instances of every built-in analyzer. get_available_analyzers(),
# get_analyzer() and run_all_analyzers() all iterate this list, so appending
# a new instance here is how an analyzer is registered.
_ALL_ANALYZERS: list[BaseAnalyzer] = [
    EntropyAnalyzer(),
    SuspiciousCommandAnalyzer(),
    NetworkAnomalyAnalyzer(),
    FrequencyAnomalyAnalyzer(),
    AuthAnomalyAnalyzer(),
    PersistenceAnalyzer(),
]
def get_available_analyzers() -> list[dict[str, str]]:
    """Return metadata about all registered analyzers."""
    return [
        {"name": analyzer.name, "description": analyzer.description}
        for analyzer in _ALL_ANALYZERS
    ]
def get_analyzer(name: str) -> BaseAnalyzer | None:
    """Get an analyzer by name; None when no analyzer matches."""
    return next((a for a in _ALL_ANALYZERS if a.name == name), None)
async def run_all_analyzers(
    rows: list[dict[str, Any]],
    enabled: list[str] | None = None,
    config: dict[str, Any] | None = None,
) -> list[AlertCandidate]:
    """Run all (or selected) analyzers and return combined alert candidates.

    Args:
        rows: Flat list of row dicts (normalized_data or data from DatasetRow).
        enabled: Optional list of analyzer names to run. Runs all if None;
            an explicit empty list runs none.
        config: Optional config overrides passed to each analyzer.

    Returns:
        Combined list of AlertCandidate from all analyzers, sorted by score desc.
    """
    config = config or {}
    results: list[AlertCandidate] = []
    for analyzer in _ALL_ANALYZERS:
        # `is not None` (not truthiness): the documented None default runs
        # everything, while an explicit empty selection selects nothing.
        if enabled is not None and analyzer.name not in enabled:
            continue
        try:
            candidates = await analyzer.analyze(rows, config)
        except Exception:
            # One broken analyzer must not abort the whole run.
            logger.exception("Analyzer %s failed", analyzer.name)
            continue
        results.extend(candidates)
        logger.info("Analyzer %s produced %d alerts", analyzer.name, len(candidates))
    # Highest-scoring candidates first.
    results.sort(key=lambda a: a.score, reverse=True)
    return results

View File

@@ -0,0 +1,322 @@
"""LLM-powered dataset analysis — replaces manual IOC enrichment.
Loads dataset rows server-side, builds a concise summary, and sends it
to Wile (70B heavy) or Roadrunner (fast) for threat analysis.
Supports both single-dataset and hunt-wide analysis.
"""
import asyncio
import json
import logging
import time
from collections import Counter, defaultdict
from typing import Any, Optional
from pydantic import BaseModel, Field
from app.config import settings
from app.agents.providers_v2 import OllamaProvider
from app.agents.router import TaskType, task_router
from app.services.sans_rag import sans_rag
logger = logging.getLogger(__name__)
# ── Request / Response models ─────────────────────────────────────────
class AnalysisRequest(BaseModel):
    """Request for LLM-powered analysis of a dataset.

    Either ``dataset_id`` or ``hunt_id`` selects the rows to analyze;
    ``question`` and ``focus`` steer the prompt, and ``mode`` selects the
    model tier ("deep" routes to DEEP_ANALYSIS, anything else to QUICK_CHAT
    — see run_llm_analysis).
    """
    dataset_id: Optional[str] = None  # analyze a single dataset
    hunt_id: Optional[str] = None  # analyze every dataset in a hunt
    question: str = Field(
        default="Perform a comprehensive threat analysis of this dataset. "
        "Identify anomalies, suspicious patterns, potential IOCs, and recommend "
        "next steps for the analyst.",
        description="Specific question or general analysis request",
    )
    mode: str = Field(default="deep", description="quick | deep")
    # Must be a FOCUS_PROMPTS key to have any effect; unknown values are ignored.
    focus: Optional[str] = Field(
        None,
        description="Focus area: threats, anomalies, lateral_movement, exfil, persistence, recon",
    )
class AnalysisResult(BaseModel):
    """LLM analysis result.

    Structured fields are populated from the model's JSON reply when it
    parses (see _parse_analysis); otherwise only ``analysis`` carries the
    raw text. Telemetry fields are filled in by run_llm_analysis.
    """
    analysis: str = Field(..., description="Full analysis text (markdown)")
    confidence: float = Field(default=0.0, description="0-1 confidence")
    key_findings: list[str] = Field(default_factory=list)
    iocs_identified: list[dict] = Field(default_factory=list)
    recommended_actions: list[str] = Field(default_factory=list)
    mitre_techniques: list[str] = Field(default_factory=list)
    risk_score: int = Field(default=0, description="0-100 risk score")
    # Telemetry / provenance (set after generation, not by the LLM)
    model_used: str = ""
    node_used: str = ""
    latency_ms: int = 0
    rows_analyzed: int = 0
    dataset_summary: str = ""  # summary text that was actually sent to the LLM
# ── Analysis prompts ──────────────────────────────────────────────────
ANALYSIS_SYSTEM = """You are an expert threat hunter and incident response analyst.
You are analyzing CSV log data from forensic tools (Velociraptor, Sysmon, etc.).
Your task: Perform deep threat analysis of the data provided and produce actionable findings.
RESPOND WITH VALID JSON ONLY:
{
"analysis": "Detailed markdown analysis with headers and bullet points",
"confidence": 0.85,
"key_findings": ["Finding 1", "Finding 2"],
"iocs_identified": [{"type": "ip", "value": "1.2.3.4", "context": "C2 traffic"}],
"recommended_actions": ["Action 1", "Action 2"],
"mitre_techniques": ["T1059.001 - PowerShell", "T1071 - Application Layer Protocol"],
"risk_score": 65
}
"""
FOCUS_PROMPTS = {
"threats": "Focus on identifying active threats, malware indicators, and attack patterns.",
"anomalies": "Focus on statistical anomalies, outliers, and unusual behavior patterns.",
"lateral_movement": "Focus on evidence of lateral movement: PsExec, WMI, RDP, SMB, pass-the-hash.",
"exfil": "Focus on data exfiltration indicators: large transfers, DNS tunneling, unusual destinations.",
"persistence": "Focus on persistence mechanisms: scheduled tasks, services, registry, startup items.",
"recon": "Focus on reconnaissance activity: scanning, enumeration, discovery commands.",
}
# ── Data summarizer ───────────────────────────────────────────────────
def summarize_dataset_rows(
rows: list[dict],
columns: list[str] | None = None,
max_sample: int = 20,
max_chars: int = 6000,
) -> str:
"""Build a concise text summary of dataset rows for LLM consumption.
Includes:
- Column headers and types
- Statistical summary (unique values, top values per column)
- Sample rows (first N)
- Detected patterns of interest
"""
if not rows:
return "Empty dataset — no rows to analyze."
cols = columns or list(rows[0].keys())
n_rows = len(rows)
parts: list[str] = []
parts.append(f"## Dataset Summary: {n_rows} rows, {len(cols)} columns")
parts.append(f"Columns: {', '.join(cols)}")
# Per-column stats
parts.append("\n### Column Statistics:")
for col in cols[:30]: # limit to first 30 cols
values = [str(r.get(col, "")) for r in rows if r.get(col) not in (None, "", "N/A")]
if not values:
continue
unique = len(set(values))
counter = Counter(values)
top3 = counter.most_common(3)
top_str = ", ".join(f"{v} ({c}x)" for v, c in top3)
parts.append(f"- **{col}**: {len(values)} non-null, {unique} unique. Top: {top_str}")
# Sample rows
sample = rows[:max_sample]
parts.append(f"\n### Sample Rows (first {len(sample)}):")
for i, row in enumerate(sample):
row_str = " | ".join(f"{k}={v}" for k, v in row.items() if v not in (None, "", "N/A"))
parts.append(f"{i+1}. {row_str}")
# Detect interesting patterns
patterns: list[str] = []
all_cmds = [str(r.get("command_line", "")).lower() for r in rows if r.get("command_line")]
sus_cmds = [c for c in all_cmds if any(
k in c for k in ["powershell -enc", "certutil", "bitsadmin", "mshta",
"regsvr32", "invoke-", "mimikatz", "psexec"]
)]
if sus_cmds:
patterns.append(f"⚠️ {len(sus_cmds)} suspicious command lines detected")
all_ips = [str(r.get("dst_ip", "")) for r in rows if r.get("dst_ip")]
ext_ips = [ip for ip in all_ips if ip and not ip.startswith(("10.", "192.168.", "172.", "127."))]
if ext_ips:
unique_ext = len(set(ext_ips))
patterns.append(f"🌐 {unique_ext} unique external destination IPs")
if patterns:
parts.append("\n### Detected Patterns:")
for p in patterns:
parts.append(f"- {p}")
text = "\n".join(parts)
if len(text) > max_chars:
text = text[:max_chars] + "\n... (truncated)"
return text
# ── LLM analysis engine ──────────────────────────────────────────────
async def run_llm_analysis(
    rows: list[dict],
    request: AnalysisRequest,
    dataset_name: str = "unknown",
) -> AnalysisResult:
    """Run LLM analysis on dataset rows.

    Summarizes ``rows``, routes to a model via task_router ("deep" mode →
    DEEP_ANALYSIS, anything else → QUICK_CHAT), optionally enriches the
    prompt with SANS RAG context, then calls the provider with a 5-minute
    hard timeout. Never raises: timeouts and provider errors are returned
    as an AnalysisResult whose ``analysis`` text describes the failure.
    """
    start = time.monotonic()
    # Build a compact summary — raw rows would blow the context window.
    summary = summarize_dataset_rows(rows)
    # Route to appropriate model
    task_type = TaskType.DEEP_ANALYSIS if request.mode == "deep" else TaskType.QUICK_CHAT
    decision = task_router.route(task_type)
    # Build prompt
    focus_text = FOCUS_PROMPTS.get(request.focus or "", "")
    prompt = f"""Analyze the following forensic dataset from '{dataset_name}'.
{focus_text}
Analyst question: {request.question}
{summary}
"""
    # Enrich with SANS RAG (best-effort; analysis proceeds without it)
    try:
        rag_context = await sans_rag.enrich_prompt(
            request.question,
            investigation_context=f"Analyzing {len(rows)} rows from {dataset_name}",
        )
        if rag_context:
            prompt = f"{prompt}\n\n{rag_context}"
    except Exception as e:
        logger.warning(f"SANS RAG enrichment failed: {e}")
    # Call LLM. The provider takes prompt/system directly; no chat-message
    # list is needed (the previous unused `messages` local was removed).
    provider = task_router.get_provider(decision)
    try:
        raw = await asyncio.wait_for(
            provider.generate(
                prompt=prompt,
                system=ANALYSIS_SYSTEM,
                max_tokens=settings.AGENT_MAX_TOKENS * 2,  # longer for analysis
                temperature=0.3,
            ),
            timeout=300,  # 5 min hard limit
        )
    except asyncio.TimeoutError:
        logger.error("LLM analysis timed out after 300s")
        return AnalysisResult(
            analysis="Analysis timed out after 5 minutes. Try a smaller dataset or 'quick' mode.",
            model_used=decision.model,
            node_used=decision.node,
            latency_ms=int((time.monotonic() - start) * 1000),
            rows_analyzed=len(rows),
            dataset_summary=summary,
        )
    except Exception as e:
        logger.error(f"LLM analysis failed: {e}")
        return AnalysisResult(
            analysis=f"Analysis failed: {str(e)}",
            model_used=decision.model,
            node_used=decision.node,
            latency_ms=int((time.monotonic() - start) * 1000),
            rows_analyzed=len(rows),
            dataset_summary=summary,
        )
    elapsed = int((time.monotonic() - start) * 1000)
    # Parse JSON response (falls back to plain text inside _parse_analysis)
    result = _parse_analysis(raw)
    result.model_used = decision.model
    result.node_used = decision.node
    result.latency_ms = elapsed
    result.rows_analyzed = len(rows)
    result.dataset_summary = summary
    return result
def _parse_analysis(raw) -> AnalysisResult:
    """Try to parse LLM output as JSON, fall back to plain text.

    Args:
        raw: Either a dict from OllamaProvider.generate() whose "response"
            key holds the LLM text, or a plain string from other providers.

    Returns:
        AnalysisResult with structured fields when a JSON dict can be
        extracted; otherwise the raw text with 0.5 confidence.
    """
    # Ollama provider returns {"response": "<llm text>", "model": ..., ...}
    # Logging uses lazy %-style args so formatting cost is skipped when the
    # level is disabled.
    if isinstance(raw, dict):
        text = raw.get("response") or raw.get("analysis") or str(raw)
        logger.info("_parse_analysis: extracted text from dict, len=%d, first 200 chars: %s",
                    len(text), text[:200])
    else:
        text = str(raw)
        logger.info("_parse_analysis: raw is str, len=%d, first 200 chars: %s",
                    len(text), text[:200])
    text = text.strip()
    # Strip markdown code fences (``` / ```json) the model may wrap JSON in
    if text.startswith("```"):
        lines = text.split("\n")
        lines = [l for l in lines if not l.strip().startswith("```")]
        text = "\n".join(lines).strip()
    # Try progressively more aggressive JSON extraction. ValueError is kept
    # in the except: float()/int() on malformed fields must also fall through
    # to the next candidate.
    for candidate in _extract_json_candidates(text):
        try:
            data = json.loads(candidate)
            if isinstance(data, dict):
                logger.info("_parse_analysis: parsed JSON OK, keys=%s", list(data.keys()))
                return AnalysisResult(
                    analysis=data.get("analysis", text),
                    confidence=float(data.get("confidence", 0.5)),
                    key_findings=data.get("key_findings", []),
                    iocs_identified=data.get("iocs_identified", []),
                    recommended_actions=data.get("recommended_actions", []),
                    mitre_techniques=data.get("mitre_techniques", []),
                    risk_score=int(data.get("risk_score", 0)),
                )
        except (json.JSONDecodeError, ValueError) as e:
            logger.warning("_parse_analysis: JSON parse failed: %s, candidate len=%d, first 100: %s",
                           e, len(candidate), candidate[:100])
            continue
    # Fallback: plain text
    logger.warning("_parse_analysis: all JSON parse attempts failed, falling back to plain text")
    return AnalysisResult(
        analysis=text,
        confidence=0.5,
    )
def _extract_json_candidates(text: str):
"""Yield JSON candidate strings from text, trying progressively more aggressive extraction."""
import re
# 1. The whole text as-is
yield text
# 2. Find outermost { ... } block
start = text.find("{")
end = text.rfind("}")
if start != -1 and end > start:
block = text[start:end + 1]
yield block
# 3. Try to fix common LLM JSON issues:
# - trailing commas before ] or }
fixed = re.sub(r',\s*([}\]])', r'\1', block)
if fixed != block:
yield fixed

View File

@@ -0,0 +1,484 @@
"""
MITRE ATT&CK mapping service.
Maps dataset events to ATT&CK techniques using pattern-based heuristics.
Uses the enterprise-attack matrix (embedded patterns for offline use).
"""
import logging
import re
from collections import Counter, defaultdict
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import Dataset, DatasetRow
logger = logging.getLogger(__name__)
# ── ATT&CK Technique Patterns ────────────────────────────────────────
# Subset of enterprise-attack techniques with detection patterns.
# Each entry: (technique_id, name, tactic, patterns_list)
# Each tuple: (technique_id, name, tactic, regex patterns). Patterns are
# applied with re.IGNORECASE against a row's flattened lowercase text by
# map_to_attack() / build_knowledge_graph().
# NOTE(review): T1053.005 appears twice (execution + persistence). Consumers
# key metadata by technique_id, so the later (persistence) entry's name and
# tactic win and hits from both entries merge — confirm this is intended.
TECHNIQUE_PATTERNS: list[tuple[str, str, str, list[str]]] = [
    # Initial Access
    ("T1566", "Phishing", "initial-access", [
        r"phish", r"\.hta\b", r"\.lnk\b", r"mshta\.exe", r"outlook.*attachment",
    ]),
    ("T1190", "Exploit Public-Facing Application", "initial-access", [
        r"exploit", r"CVE-\d{4}", r"vulnerability", r"webshell",
    ]),
    # Execution
    ("T1059.001", "PowerShell", "execution", [
        r"powershell", r"pwsh", r"-enc\b", r"-encodedcommand",
        r"invoke-expression", r"iex\b", r"bypass\b.*execution",
    ]),
    ("T1059.003", "Windows Command Shell", "execution", [
        r"cmd\.exe", r"/c\s+", r"command\.com",
    ]),
    ("T1059.005", "Visual Basic", "execution", [
        r"wscript", r"cscript", r"\.vbs\b", r"\.vbe\b",
    ]),
    ("T1047", "Windows Management Instrumentation", "execution", [
        r"wmic\b", r"winmgmt", r"wmi\b",
    ]),
    ("T1053.005", "Scheduled Task", "execution", [
        r"schtasks", r"at\.exe", r"taskschd",
    ]),
    ("T1204", "User Execution", "execution", [
        r"user.*click", r"open.*attachment", r"macro",
    ]),
    # Persistence
    ("T1547.001", "Registry Run Keys", "persistence", [
        r"CurrentVersion\\Run", r"HKLM\\Software\\Microsoft\\Windows\\CurrentVersion\\Run",
        r"reg\s+add.*\\Run",
    ]),
    ("T1543.003", "Windows Service", "persistence", [
        r"sc\s+create", r"new-service", r"service.*install",
    ]),
    ("T1136", "Create Account", "persistence", [
        r"net\s+user\s+/add", r"new-localuser", r"useradd",
    ]),
    # NOTE(review): duplicate technique_id — see header note above.
    ("T1053.005", "Scheduled Task/Job", "persistence", [
        r"schtasks\s+/create", r"crontab",
    ]),
    # Privilege Escalation
    ("T1548.002", "Bypass User Access Control", "privilege-escalation", [
        r"eventvwr", r"fodhelper", r"uac.*bypass", r"computerdefaults",
    ]),
    ("T1134", "Access Token Manipulation", "privilege-escalation", [
        r"token.*impersonat", r"runas", r"adjusttokenprivileges",
    ]),
    # Defense Evasion
    ("T1070.001", "Clear Windows Event Logs", "defense-evasion", [
        r"wevtutil\s+cl", r"clear-eventlog", r"clearlog",
    ]),
    ("T1562.001", "Disable or Modify Tools", "defense-evasion", [
        r"tamper.*protection", r"disable.*defender", r"set-mppreference",
        r"disable.*firewall",
    ]),
    ("T1027", "Obfuscated Files or Information", "defense-evasion", [
        r"base64", r"-enc\b", r"certutil.*-decode", r"frombase64",
    ]),
    ("T1036", "Masquerading", "defense-evasion", [
        r"rename.*\.exe", r"masquerad", r"svchost.*unusual",
    ]),
    ("T1055", "Process Injection", "defense-evasion", [
        r"inject", r"createremotethread", r"ntcreatethreadex",
        r"virtualalloc", r"writeprocessmemory",
    ]),
    # Credential Access
    ("T1003.001", "LSASS Memory", "credential-access", [
        r"mimikatz", r"sekurlsa", r"lsass", r"procdump.*lsass",
    ]),
    ("T1003.003", "NTDS", "credential-access", [
        r"ntds\.dit", r"vssadmin.*shadow", r"ntdsutil",
    ]),
    ("T1110", "Brute Force", "credential-access", [
        r"brute.*force", r"failed.*login.*\d{3,}", r"hydra", r"medusa",
    ]),
    ("T1558.003", "Kerberoasting", "credential-access", [
        r"kerberoast", r"invoke-kerberoast", r"GetUserSPNs",
    ]),
    # Discovery
    ("T1087", "Account Discovery", "discovery", [
        r"net\s+user", r"net\s+localgroup", r"get-aduser",
    ]),
    ("T1082", "System Information Discovery", "discovery", [
        r"systeminfo", r"hostname", r"ver\b",
    ]),
    ("T1083", "File and Directory Discovery", "discovery", [
        r"dir\s+/s", r"tree\s+/f", r"get-childitem.*-recurse",
    ]),
    ("T1057", "Process Discovery", "discovery", [
        r"tasklist", r"get-process", r"ps\s+aux",
    ]),
    ("T1018", "Remote System Discovery", "discovery", [
        r"net\s+view", r"ping\s+-", r"arp\s+-a", r"nslookup",
    ]),
    ("T1016", "System Network Configuration Discovery", "discovery", [
        r"ipconfig", r"ifconfig", r"netstat",
    ]),
    # Lateral Movement
    ("T1021.001", "Remote Desktop Protocol", "lateral-movement", [
        r"rdp\b", r"mstsc", r"3389", r"remote\s+desktop",
    ]),
    ("T1021.002", "SMB/Windows Admin Shares", "lateral-movement", [
        r"\\\\.*\\(c|admin)\$", r"psexec", r"smbclient", r"net\s+use",
    ]),
    ("T1021.006", "Windows Remote Management", "lateral-movement", [
        r"winrm", r"enter-pssession", r"invoke-command.*-computername",
        r"wsman", r"5985|5986",
    ]),
    ("T1570", "Lateral Tool Transfer", "lateral-movement", [
        r"copy.*\\\\", r"xcopy.*\\\\", r"robocopy",
    ]),
    # Collection
    ("T1560", "Archive Collected Data", "collection", [
        r"compress-archive", r"7z\.exe", r"rar\s+a", r"tar\s+-[cz]",
    ]),
    ("T1005", "Data from Local System", "collection", [
        r"type\s+.*password", r"findstr.*password", r"select-string.*credential",
    ]),
    # Command and Control
    ("T1071.001", "Web Protocols", "command-and-control", [
        r"http[s]?://\d+\.\d+\.\d+\.\d+", r"curl\b", r"wget\b",
        r"invoke-webrequest", r"beacon",
    ]),
    ("T1573", "Encrypted Channel", "command-and-control", [
        r"ssl\b", r"tls\b", r"encrypted.*tunnel", r"stunnel",
    ]),
    ("T1105", "Ingress Tool Transfer", "command-and-control", [
        r"certutil.*-urlcache", r"bitsadmin.*transfer",
        r"downloadfile", r"invoke-webrequest.*-outfile",
    ]),
    ("T1219", "Remote Access Software", "command-and-control", [
        r"teamviewer", r"anydesk", r"logmein", r"vnc",
    ]),
    # Exfiltration
    ("T1048", "Exfiltration Over Alternative Protocol", "exfiltration", [
        r"dns.*tunnel", r"exfil", r"icmp.*tunnel",
    ]),
    ("T1041", "Exfiltration Over C2 Channel", "exfiltration", [
        r"upload.*c2", r"exfil.*http",
    ]),
    ("T1567", "Exfiltration Over Web Service", "exfiltration", [
        r"mega\.nz", r"dropbox", r"pastebin", r"transfer\.sh",
    ]),
    # Impact
    ("T1486", "Data Encrypted for Impact", "impact", [
        r"ransomware", r"encrypt.*files", r"\.locked\b", r"ransom",
    ]),
    ("T1489", "Service Stop", "impact", [
        r"sc\s+stop", r"net\s+stop", r"stop-service",
    ]),
    ("T1529", "System Shutdown/Reboot", "impact", [
        r"shutdown\s+/[rs]", r"restart-computer",
    ]),
]
# Tactic display names and kill-chain order
# Order drives the column order of the returned matrix in map_to_attack().
TACTIC_ORDER = [
    "initial-access", "execution", "persistence", "privilege-escalation",
    "defense-evasion", "credential-access", "discovery", "lateral-movement",
    "collection", "command-and-control", "exfiltration", "impact",
]
# Human-readable labels for the tactic keys above.
TACTIC_NAMES = {
    "initial-access": "Initial Access",
    "execution": "Execution",
    "persistence": "Persistence",
    "privilege-escalation": "Privilege Escalation",
    "defense-evasion": "Defense Evasion",
    "credential-access": "Credential Access",
    "discovery": "Discovery",
    "lateral-movement": "Lateral Movement",
    "collection": "Collection",
    "command-and-control": "Command and Control",
    "exfiltration": "Exfiltration",
    "impact": "Impact",
}
# ── Row fetcher ───────────────────────────────────────────────────────
async def _fetch_rows(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    limit: int = 5000,
) -> list[dict[str, Any]]:
    """Fetch raw row payloads for one dataset (preferred) or a whole hunt.

    With neither filter set, returns up to ``limit`` rows across all datasets.
    """
    query = select(DatasetRow).join(Dataset)
    if dataset_id:
        query = query.where(DatasetRow.dataset_id == dataset_id)
    elif hunt_id:
        query = query.where(Dataset.hunt_id == hunt_id)
    fetched = (await db.execute(query.limit(limit))).scalars().all()
    return [row.data for row in fetched]
# ── Main functions ────────────────────────────────────────────────────
async def map_to_attack(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
) -> dict[str, Any]:
    """
    Map dataset rows to MITRE ATT&CK techniques.

    Flattens each row to lowercase text, matches it against
    TECHNIQUE_PATTERNS (case-insensitive), and returns a matrix-style
    structure: tactics in kill-chain order, per-technique hit counts
    (full counts — stored evidence is capped at 10 rows per technique),
    and coverage stats.
    """
    rows = await _fetch_rows(db, dataset_id, hunt_id)
    if not rows:
        return {"tactics": [], "techniques": [], "evidence": [], "coverage": {}, "total_rows": 0}
    # Flatten all string values per row for matching
    row_texts: list[str] = [
        " ".join(str(v).lower() for v in row.values() if v is not None)
        for row in rows
    ]
    # Match techniques
    technique_hits: dict[str, list[dict]] = defaultdict(list)  # tech_id -> evidence rows (<=10)
    technique_counts: Counter = Counter()  # tech_id -> total matching rows (uncapped)
    technique_meta: dict[str, tuple[str, str]] = {}  # tech_id -> (name, tactic)
    for tech_id, tech_name, tactic, patterns in TECHNIQUE_PATTERNS:
        compiled = [re.compile(p, re.IGNORECASE) for p in patterns]
        technique_meta[tech_id] = (tech_name, tactic)
        for i, text in enumerate(row_texts):
            for pat in compiled:
                if not pat.search(text):
                    continue
                # Count every matching row; previously the count itself was
                # silently capped at 10 because it reused the evidence list.
                technique_counts[tech_id] += 1
                if len(technique_hits[tech_id]) < 10:  # limit stored evidence
                    # find the field that actually matched
                    matched_field = ""
                    matched_value = ""
                    for k, v in rows[i].items():
                        if v and pat.search(str(v).lower()):
                            matched_field = k
                            matched_value = str(v)[:200]
                            break
                    technique_hits[tech_id].append({
                        "row_index": i,
                        "field": matched_field,
                        "value": matched_value,
                        "pattern": pat.pattern,
                    })
                break  # one pattern match per technique per row is enough
    # Build tactic → technique structure
    tactic_techniques: dict[str, list[dict]] = defaultdict(list)
    for tech_id, evidence_list in technique_hits.items():
        name, tactic = technique_meta[tech_id]
        tactic_techniques[tactic].append({
            "id": tech_id,
            "name": name,
            "count": technique_counts[tech_id],
            "evidence": evidence_list[:5],
        })
    # Build ordered tactics list (kill-chain order; empty tactics included)
    tactics = []
    for tactic_key in TACTIC_ORDER:
        techs = tactic_techniques.get(tactic_key, [])
        tactics.append({
            "id": tactic_key,
            "name": TACTIC_NAMES.get(tactic_key, tactic_key),
            "techniques": sorted(techs, key=lambda t: -t["count"]),
            "total_hits": sum(t["count"] for t in techs),
        })
    # Coverage stats
    covered_tactics = sum(1 for t in tactics if t["total_hits"] > 0)
    total_technique_hits = sum(t["total_hits"] for t in tactics)
    return {
        "tactics": tactics,
        "coverage": {
            "tactics_covered": covered_tactics,
            "tactics_total": len(TACTIC_ORDER),
            "techniques_matched": len(technique_hits),
            "total_evidence": total_technique_hits,
        },
        "total_rows": len(rows),
    }
async def build_knowledge_graph(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
) -> dict[str, Any]:
    """
    Build a knowledge graph connecting entities (hosts, users, processes, IPs)
    to MITRE techniques and tactics.

    Entities are extracted by field-name pattern, techniques by
    TECHNIQUE_PATTERNS matching, and edges by per-row co-occurrence.
    Returns Cytoscape-compatible nodes + edges plus summary stats.
    """
    rows = await _fetch_rows(db, dataset_id, hunt_id)
    if not rows:
        return {"nodes": [], "edges": [], "stats": {}}
    # Extract entities
    entities: dict[str, set[str]] = defaultdict(set)  # type -> set of values
    row_entity_map: list[list[tuple[str, str]]] = []  # per-row list of (type, value)
    # Field name patterns for entity extraction
    HOST_FIELDS = re.compile(r"hostname|computer|host|machine", re.I)
    USER_FIELDS = re.compile(r"user|account|logon.*name|subject.*name", re.I)
    IP_FIELDS = re.compile(r"src.*ip|dst.*ip|ip.*addr|source.*ip|dest.*ip|remote.*addr", re.I)
    PROC_FIELDS = re.compile(r"process.*name|image|parent.*image|executable|command", re.I)
    for row in rows:
        row_ents: list[tuple[str, str]] = []
        for k, v in row.items():
            if not v or str(v).strip() in ('', '-', 'N/A', 'None'):
                continue
            val = str(v).strip()
            if HOST_FIELDS.search(k):
                entities["host"].add(val)
                row_ents.append(("host", val))
            elif USER_FIELDS.search(k):
                entities["user"].add(val)
                row_ents.append(("user", val))
            elif IP_FIELDS.search(k):
                entities["ip"].add(val)
                row_ents.append(("ip", val))
            elif PROC_FIELDS.search(k):
                # Clean process name: strip directory part, cap length
                pname = val.split("\\")[-1].split("/")[-1][:60]
                entities["process"].add(pname)
                row_ents.append(("process", pname))
        row_entity_map.append(row_ents)
    # Map rows to techniques
    row_texts = [" ".join(str(v).lower() for v in row.values() if v) for row in rows]
    row_techniques: list[set[str]] = [set() for _ in rows]
    tech_meta: dict[str, tuple[str, str]] = {}
    for tech_id, tech_name, tactic, patterns in TECHNIQUE_PATTERNS:
        compiled = [re.compile(p, re.I) for p in patterns]
        tech_meta[tech_id] = (tech_name, tactic)
        for i, text in enumerate(row_texts):
            for pat in compiled:
                if pat.search(text):
                    row_techniques[i].add(tech_id)
                    break
    # Build graph
    nodes: list[dict] = []
    edges: list[dict] = []
    node_ids: set[str] = set()
    edge_counter: Counter = Counter()
    # Entity node styling for the Cytoscape frontend
    TYPE_COLORS = {
        "host": "#3b82f6",
        "user": "#10b981",
        "ip": "#8b5cf6",
        "process": "#f59e0b",
        "technique": "#ef4444",
        "tactic": "#6366f1",
    }
    TYPE_SHAPES = {
        "host": "roundrectangle",
        "user": "ellipse",
        "ip": "diamond",
        "process": "hexagon",
        "technique": "tag",
        "tactic": "round-rectangle",
    }
    for ent_type, values in entities.items():
        # Sort before truncating so the kept 50 nodes per type are
        # deterministic (set iteration order varies between runs).
        for val in sorted(values)[:50]:
            nid = f"{ent_type}:{val}"
            if nid not in node_ids:
                node_ids.add(nid)
                nodes.append({
                    "data": {
                        "id": nid,
                        "label": val[:40],
                        "type": ent_type,
                        "color": TYPE_COLORS.get(ent_type, "#666"),
                        "shape": TYPE_SHAPES.get(ent_type, "ellipse"),
                    },
                })
    # Technique nodes
    seen_techniques: set[str] = set()
    for tech_set in row_techniques:
        seen_techniques.update(tech_set)
    for tech_id in seen_techniques:
        name, tactic = tech_meta.get(tech_id, (tech_id, "unknown"))
        nid = f"technique:{tech_id}"
        if nid not in node_ids:
            node_ids.add(nid)
            nodes.append({
                "data": {
                    "id": nid,
                    "label": f"{tech_id}\n{name}",
                    "type": "technique",
                    "color": TYPE_COLORS["technique"],
                    "shape": TYPE_SHAPES["technique"],
                    "tactic": tactic,
                },
            })
    # Edges: entity → technique (based on co-occurrence in rows)
    for i, row_ents in enumerate(row_entity_map):
        for ent_type, ent_val in row_ents:
            for tech_id in row_techniques[i]:
                src = f"{ent_type}:{ent_val}"
                tgt = f"technique:{tech_id}"
                if src in node_ids and tgt in node_ids:
                    edge_counter[(src, tgt)] += 1
    # Edges: entity → entity (based on co-occurrence)
    for row_ents in row_entity_map:
        for j in range(len(row_ents)):
            for k in range(j + 1, len(row_ents)):
                src = f"{row_ents[j][0]}:{row_ents[j][1]}"
                tgt = f"{row_ents[k][0]}:{row_ents[k][1]}"
                if src in node_ids and tgt in node_ids and src != tgt:
                    edge_counter[(src, tgt)] += 1
    # Build edge list — keep the 500 heaviest edges. (Counter values are
    # always >= 1, so no extra low-weight filter is needed.)
    for (src, tgt), weight in edge_counter.most_common(500):
        edges.append({
            "data": {
                "source": src,
                "target": tgt,
                "weight": weight,
                "label": str(weight) if weight > 2 else "",
            },
        })
    return {
        "nodes": nodes,
        "edges": edges,
        "stats": {
            "total_nodes": len(nodes),
            "total_edges": len(edges),
            "entity_counts": {t: len(v) for t, v in entities.items()},
            "techniques_found": len(seen_techniques),
        },
    }

View File

@@ -0,0 +1,252 @@
"""Network Picture — deduplicated host inventory built from dataset rows.
Scans all datasets in a hunt, extracts host-identifying fields from
normalized data, and groups by hostname (or src_ip fallback) to produce
a clean one-row-per-host inventory. Uses sets for deduplication —
if an IP appears 900 times, it shows once.
"""
import logging
from datetime import datetime
from typing import Any, Sequence
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import Dataset, DatasetRow
logger = logging.getLogger(__name__)
# Canonical column names we extract per row
_HOST_KEYS = ("hostname",)
_IP_KEYS = ("src_ip", "dst_ip", "ip_address")
_USER_KEYS = ("username",)
_OS_KEYS = ("os",)
_MAC_KEYS = ("mac_address",)
_PORT_SRC_KEYS = ("src_port",)
_PORT_DST_KEYS = ("dst_port",)  # NOTE(review): not consumed by _HostBucket.ingest — confirm intended
_PROTO_KEYS = ("protocol",)
_STATE_KEYS = ("connection_state",)  # NOTE(review): not consumed by _HostBucket.ingest — confirm intended
_TS_KEYS = ("timestamp",)
# Junk values to skip — compared case-insensitively by _clean()
_JUNK = frozenset({"", "-", "0.0.0.0", "::", "0", "127.0.0.1", "::1", "localhost", "unknown", "n/a", "none", "null"})
ROW_BATCH = 1000  # rows fetched per DB query
MAX_HOSTS = 1000  # hard cap on returned hosts
def _clean(val: Any) -> str:
    """Normalise a cell value to a clean string or empty."""
    if val is None:
        return ""
    text = (val if isinstance(val, str) else str(val)).strip()
    return text if text.lower() not in _JUNK else ""
def _try_parse_ts(val: str) -> datetime | None:
"""Best-effort timestamp parse (subset of common formats)."""
for fmt in (
"%Y-%m-%dT%H:%M:%S.%fZ",
"%Y-%m-%dT%H:%M:%SZ",
"%Y-%m-%dT%H:%M:%S.%f",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%d %H:%M:%S.%f",
"%Y-%m-%d %H:%M:%S",
):
try:
return datetime.strptime(val.strip(), fmt)
except ValueError:
continue
return None
class _HostBucket:
    """Mutable accumulator for a single host.

    Collects deduplicated attributes (IPs, users, ports, ...) across every
    row attributed to one host, plus a connection count and seen-time range.
    """
    __slots__ = (
        "hostname", "ips", "users", "os_versions", "mac_addresses",
        "protocols", "open_ports", "remote_targets", "datasets",
        "connection_count", "first_seen", "last_seen",
    )

    def __init__(self, hostname: str):
        self.hostname = hostname
        # One set per deduplicated attribute
        for attr in ("ips", "users", "os_versions", "mac_addresses",
                     "protocols", "open_ports", "remote_targets", "datasets"):
            setattr(self, attr, set())
        self.connection_count = 0
        self.first_seen = None
        self.last_seen = None

    def ingest(self, row: dict[str, Any], ds_name: str) -> None:
        """Merge one normalised row into this bucket."""
        self.connection_count += 1
        self.datasets.add(ds_name)
        # Simple field → set accumulations (protocols stored upper-case)
        collectors = (
            (_IP_KEYS, self.ips, False),
            (_USER_KEYS, self.users, False),
            (_OS_KEYS, self.os_versions, False),
            (_MAC_KEYS, self.mac_addresses, False),
            (_PROTO_KEYS, self.protocols, True),
        )
        for keys, bucket, uppercase in collectors:
            for key in keys:
                value = _clean(row.get(key))
                if value:
                    bucket.add(value.upper() if uppercase else value)
        # Open ports = local (src) ports; "0" is meaningless
        for key in _PORT_SRC_KEYS:
            port = _clean(row.get(key))
            if port and port != "0":
                self.open_ports.add(port)
        # Remote targets = destination IPs
        dst = _clean(row.get("dst_ip"))
        if dst:
            self.remote_targets.add(dst)
        # Track the earliest / latest timestamp seen for this host
        for key in _TS_KEYS:
            raw_ts = _clean(row.get(key))
            if not raw_ts:
                continue
            ts = _try_parse_ts(raw_ts)
            if ts is None:
                continue
            if self.first_seen is None or ts < self.first_seen:
                self.first_seen = ts
            if self.last_seen is None or ts > self.last_seen:
                self.last_seen = ts

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a JSON-friendly dict with sorted, deduplicated lists."""
        def _port_key(p: str) -> int:
            # Numeric sort for real ports; non-numeric values sort first
            return int(p) if p.isdigit() else 0
        return {
            "hostname": self.hostname,
            "ips": sorted(self.ips),
            "users": sorted(self.users),
            "os": sorted(self.os_versions),
            "mac_addresses": sorted(self.mac_addresses),
            "protocols": sorted(self.protocols),
            "open_ports": sorted(self.open_ports, key=_port_key),
            "remote_targets": sorted(self.remote_targets),
            "datasets": sorted(self.datasets),
            "connection_count": self.connection_count,
            "first_seen": self.first_seen.isoformat() if self.first_seen else None,
            "last_seen": self.last_seen.isoformat() if self.last_seen else None,
        }
async def build_network_picture(
    db: AsyncSession,
    hunt_id: str,
) -> dict[str, Any]:
    """Build a deduplicated host inventory for all datasets in a hunt.

    Streams every dataset's rows in ROW_BATCH-sized pages, groups rows by
    hostname (falling back to src_ip / ip_address), and accumulates one
    _HostBucket per host. Hosts are capped at MAX_HOSTS (highest connection
    count first); summary stats are computed over the returned host list.

    Returns:
        {
          "hosts": [ {hostname, ips[], users[], os[], ...}, ... ],
          "summary": { total_hosts, total_connections, total_unique_ips, datasets_scanned }
        }
    """
    # 1. Get all datasets in this hunt (oldest first)
    ds_result = await db.execute(
        select(Dataset)
        .where(Dataset.hunt_id == hunt_id)
        .order_by(Dataset.created_at)
    )
    ds_list: Sequence[Dataset] = ds_result.scalars().all()
    if not ds_list:
        return {
            "hosts": [],
            "summary": {
                "total_hosts": 0,
                "total_connections": 0,
                "total_unique_ips": 0,
                "datasets_scanned": 0,
            },
        }
    # 2. Stream rows and aggregate into host buckets
    buckets: dict[str, _HostBucket] = {}  # key = uppercase hostname or IP
    for ds in ds_list:
        ds_name = ds.name or ds.filename
        offset = 0
        while True:
            stmt = (
                select(DatasetRow)
                .where(DatasetRow.dataset_id == ds.id)
                .order_by(DatasetRow.row_index)
                .limit(ROW_BATCH)
                .offset(offset)
            )
            result = await db.execute(stmt)
            rows: Sequence[DatasetRow] = result.scalars().all()
            if not rows:
                break
            for dr in rows:
                norm = dr.normalized_data or dr.data or {}
                # Determine grouping key: hostname preferred, else src_ip/ip_address
                host_val = ""
                for k in _HOST_KEYS:
                    host_val = _clean(norm.get(k))
                    if host_val:
                        break
                if not host_val:
                    for k in ("src_ip", "ip_address"):
                        host_val = _clean(norm.get(k))
                        if host_val:
                            break
                if not host_val:
                    # Row has no host identifier — skip
                    continue
                bucket_key = host_val.upper()  # case-insensitive host grouping
                if bucket_key not in buckets:
                    buckets[bucket_key] = _HostBucket(host_val)
                buckets[bucket_key].ingest(norm, ds_name)
            if len(rows) < ROW_BATCH:
                # A short page is the last page — skip the extra empty query
                # the previous version issued per dataset.
                break
            offset += ROW_BATCH
    # 3. Convert to sorted list (by connection count descending), apply cap
    hosts_raw = sorted(buckets.values(), key=lambda b: b.connection_count, reverse=True)
    if len(hosts_raw) > MAX_HOSTS:
        hosts_raw = hosts_raw[:MAX_HOSTS]
    hosts = [b.to_dict() for b in hosts_raw]
    # 4. Summary stats (over the returned, possibly capped, host list)
    all_ips: set[str] = set()
    total_conns = 0
    for b in hosts_raw:
        all_ips.update(b.ips)
        total_conns += b.connection_count
    return {
        "hosts": hosts,
        "summary": {
            "total_hosts": len(hosts),
            "total_connections": total_conns,
            "total_unique_ips": len(all_ips),
            "datasets_scanned": len(ds_list),
        },
    }

View File

@@ -23,12 +23,12 @@ COLUMN_MAPPINGS: list[tuple[str, str]] = [
# Operating system
(r"^(os|operating_?system|os_?version|os_?name|platform|os_?type)$", "os"),
# Source / destination IPs
(r"^(source_?ip|src_?ip|srcaddr|local_?address|sourceaddress)$", "src_ip"),
(r"^(dest_?ip|dst_?ip|dstaddr|remote_?address|destinationaddress|destaddress)$", "dst_ip"),
(r"^(source_?ip|src_?ip|srcaddr|local_?address|sourceaddress|sourceip|laddr\.?ip)$", "src_ip"),
(r"^(dest_?ip|dst_?ip|dstaddr|remote_?address|destinationaddress|destaddress|destination_?ip|destinationip|raddr\.?ip)$", "dst_ip"),
(r"^(ip_?address|ipaddress|ip)$", "ip_address"),
# Ports
(r"^(source_?port|src_?port|localport)$", "src_port"),
(r"^(dest_?port|dst_?port|remoteport|destinationport)$", "dst_port"),
(r"^(source_?port|src_?port|localport|laddr\.?port)$", "src_port"),
(r"^(dest_?port|dst_?port|remoteport|destinationport|raddr\.?port)$", "dst_port"),
# Process info
(r"^(process_?name|name|image|exe|executable|binary)$", "process_name"),
(r"^(pid|process_?id)$", "pid"),
@@ -51,6 +51,10 @@ COLUMN_MAPPINGS: list[tuple[str, str]] = [
(r"^(protocol|proto)$", "protocol"),
(r"^(domain|dns_?name|query_?name|queriedname)$", "domain"),
(r"^(url|uri|request_?url)$", "url"),
# MAC address
(r"^(mac|mac_?address|physical_?address|mac_?addr|hw_?addr|ethernet)$", "mac_address"),
# Connection state (netstat)
(r"^(state|status|tcp_?state|conn_?state)$", "connection_state"),
# Event info
(r"^(event_?id|eventid|eid)$", "event_id"),
(r"^(event_?type|eventtype|category|action)$", "event_type"),
@@ -120,13 +124,27 @@ def detect_ioc_columns(
"domain": "domain",
}
# Canonical names that should NEVER be treated as IOCs even if values
# match a pattern (e.g. process_name "svchost.exe" matching domain regex).
_non_ioc_canonicals = frozenset({
"process_name", "file_name", "file_path", "command_line",
"parent_command_line", "description", "event_type", "registry_key",
"registry_value", "severity", "os",
"title", "netmask", "gateway", "connection_state",
})
for col in columns:
canonical = column_mapping.get(col, "")
# Skip columns whose canonical meaning is obviously not an IOC
if canonical in _non_ioc_canonicals:
continue
col_type = column_types.get(col)
if col_type in ioc_type_map:
ioc_columns[col] = ioc_type_map[col_type]
# Also check canonical name
canonical = column_mapping.get(col, "")
if canonical in ("src_ip", "dst_ip", "ip_address"):
ioc_columns[col] = "ip"
elif canonical == "hash_md5":
@@ -139,6 +157,8 @@ def detect_ioc_columns(
ioc_columns[col] = "domain"
elif canonical == "url":
ioc_columns[col] = "url"
elif canonical == "hostname":
ioc_columns[col] = "hostname"
return ioc_columns

View File

@@ -0,0 +1,269 @@
"""Investigation Notebook & Playbook Engine for ThreatHunt.
Notebooks: Analyst-facing, cell-based documents (markdown + code cells)
stored as JSON in the database. Each cell can contain free-form
markdown notes *or* a structured query/command that the backend
evaluates against datasets.
Playbooks: Pre-defined, step-by-step investigation workflows. Each step
defines an action (query, analyze, enrich, tag) and expected
outcomes. Analysts can run through playbooks interactively or
trigger them automatically on new alerts.
"""
from __future__ import annotations
import json
import logging
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Optional
logger = logging.getLogger(__name__)
# ── Notebook helpers ──────────────────────────────────────────────────
@dataclass
class NotebookCell:
    """A single notebook cell: free-form markdown or an executable query/code entry."""
    id: str
    cell_type: str  # markdown | query | code
    source: str
    # Result of evaluating the cell, if any (filled in by the backend); None for pure markdown.
    output: Optional[str] = None
    metadata: dict = field(default_factory=dict)
def validate_notebook_cells(cells: list[dict]) -> list[dict]:
    """Ensure each cell has required keys.

    Missing keys are filled with defaults: a positional id (``cell-<i>``),
    ``markdown`` cell type, empty source, no output, and empty metadata.
    Unknown keys are dropped.
    """
    return [
        {
            "id": cell.get("id", f"cell-{idx}"),
            "cell_type": cell.get("cell_type", "markdown"),
            "source": cell.get("source", ""),
            "output": cell.get("output"),
            "metadata": cell.get("metadata", {}),
        }
        for idx, cell in enumerate(cells)
    ]
# ── Built-in Playbook Templates ──────────────────────────────────────
# Built-in playbook templates seeded into the UI. Each playbook is a plain
# dict: name / description / category / tags, plus an ordered list of steps.
# Each step carries an `action` keyword and an `action_config` dict that the
# frontend dispatches on (search, analyze, process_tree, mitre_map, ...).
BUILT_IN_PLAYBOOKS: list[dict[str, Any]] = [
    # 1. Process-focused incident-response walkthrough.
    {
        "name": "Suspicious Process Investigation",
        "description": "Step-by-step investigation of a potentially malicious process execution.",
        "category": "incident_response",
        "tags": ["process", "malware", "T1059"],
        "steps": [
            {
                "order": 1,
                "title": "Identify the suspicious process",
                "description": "Search for the process name/PID across all datasets. Note the command line, parent, and user context.",
                "action": "search",
                "action_config": {"fields": ["process_name", "command_line", "parent_process_name", "username"]},
                "expected_outcome": "Process details, parent chain, and execution context identified.",
            },
            {
                "order": 2,
                "title": "Build process tree",
                "description": "View the full parent→child process tree to understand the execution chain.",
                "action": "process_tree",
                "action_config": {},
                "expected_outcome": "Complete process lineage showing how the suspicious process was spawned.",
            },
            {
                "order": 3,
                "title": "Check network connections",
                "description": "Search for network events associated with the same host and timeframe.",
                "action": "search",
                "action_config": {"fields": ["src_ip", "dst_ip", "dst_port", "protocol"]},
                "expected_outcome": "Network connections revealing potential C2 or data exfiltration.",
            },
            {
                "order": 4,
                "title": "Run analyzers",
                "description": "Execute the suspicious_commands and entropy analyzers against the dataset.",
                "action": "analyze",
                "action_config": {"analyzers": ["suspicious_commands", "entropy"]},
                "expected_outcome": "Automated detection of known-bad patterns.",
            },
            {
                "order": 5,
                "title": "Map to MITRE ATT&CK",
                "description": "Check which MITRE techniques the process behavior maps to.",
                "action": "mitre_map",
                "action_config": {},
                "expected_outcome": "MITRE technique mappings for the suspicious activity.",
            },
            {
                "order": 6,
                "title": "Document findings & create case",
                "description": "Summarize investigation findings, annotate key evidence, and create a case if warranted.",
                "action": "create_case",
                "action_config": {},
                "expected_outcome": "Investigation documented with annotations and optionally escalated.",
            },
        ],
    },
    # 2. Proactive hunt for lateral-movement tradecraft.
    {
        "name": "Lateral Movement Hunt",
        "description": "Systematic hunt for lateral movement indicators across the environment.",
        "category": "threat_hunting",
        "tags": ["lateral_movement", "T1021", "T1047"],
        "steps": [
            {
                "order": 1,
                "title": "Search for remote access tools",
                "description": "Look for PsExec, WMI, WinRM, RDP, and SSH usage across datasets.",
                "action": "search",
                "action_config": {"query": "psexec|wmic|winrm|rdp|ssh"},
                "expected_outcome": "Identify all remote access tool usage.",
            },
            {
                "order": 2,
                "title": "Analyze authentication events",
                "description": "Run the auth anomaly analyzer to find brute force, unusual logon types.",
                "action": "analyze",
                "action_config": {"analyzers": ["auth_anomaly"]},
                "expected_outcome": "Authentication anomalies detected.",
            },
            {
                "order": 3,
                "title": "Check network anomalies",
                "description": "Run network anomaly analyzer for beaconing and suspicious connections.",
                "action": "analyze",
                "action_config": {"analyzers": ["network_anomaly"]},
                "expected_outcome": "Beaconing or unusual network patterns identified.",
            },
            {
                "order": 4,
                "title": "Build knowledge graph",
                "description": "Visualize entity relationships to identify pivot paths.",
                "action": "knowledge_graph",
                "action_config": {},
                "expected_outcome": "Entity relationship graph showing lateral movement paths.",
            },
            {
                "order": 5,
                "title": "Document and escalate",
                "description": "Create annotations for key findings and escalate to case if needed.",
                "action": "create_case",
                "action_config": {"tags": ["lateral_movement"]},
                "expected_outcome": "Findings documented and case created.",
            },
        ],
    },
    # 3. Exfiltration-focused incident response.
    {
        "name": "Data Exfiltration Check",
        "description": "Investigate potential data exfiltration activity.",
        "category": "incident_response",
        "tags": ["exfiltration", "T1048", "T1567"],
        "steps": [
            {
                "order": 1,
                "title": "Identify large transfers",
                "description": "Search for network events with unusually high byte counts.",
                "action": "analyze",
                "action_config": {"analyzers": ["network_anomaly"], "config": {"large_transfer_threshold": 5000000}},
                "expected_outcome": "Large data transfers identified.",
            },
            {
                "order": 2,
                "title": "Check DNS anomalies",
                "description": "Look for DNS tunneling or unusual DNS query patterns.",
                "action": "search",
                "action_config": {"fields": ["dns_query", "query_length"]},
                "expected_outcome": "Suspicious DNS activity identified.",
            },
            {
                "order": 3,
                "title": "Timeline analysis",
                "description": "Examine the timeline for data staging and exfiltration windows.",
                "action": "timeline",
                "action_config": {},
                "expected_outcome": "Time windows of suspicious activity identified.",
            },
            {
                "order": 4,
                "title": "Correlate with process activity",
                "description": "Match network exfiltration with process execution times.",
                "action": "search",
                "action_config": {"fields": ["process_name", "dst_ip", "bytes_sent"]},
                "expected_outcome": "Process responsible for data transfer identified.",
            },
            {
                "order": 5,
                "title": "MITRE mapping & documentation",
                "description": "Map findings to MITRE exfiltration techniques and document.",
                "action": "mitre_map",
                "action_config": {},
                "expected_outcome": "Complete exfiltration investigation documented.",
            },
        ],
    },
    # 4. Fast triage flow for suspected ransomware.
    {
        "name": "Ransomware Triage",
        "description": "Rapid triage of potential ransomware activity.",
        "category": "incident_response",
        "tags": ["ransomware", "T1486", "T1490"],
        "steps": [
            {
                "order": 1,
                "title": "Search for ransomware indicators",
                "description": "Look for shadow copy deletion, boot config changes, encryption activity.",
                "action": "search",
                "action_config": {"query": "vssadmin|bcdedit|cipher|.encrypted|.locked|ransom"},
                "expected_outcome": "Ransomware indicators identified.",
            },
            {
                "order": 2,
                "title": "Run all analyzers",
                "description": "Execute all analyzers to get comprehensive threat picture.",
                "action": "analyze",
                "action_config": {},
                "expected_outcome": "Full automated analysis of ransomware indicators.",
            },
            {
                "order": 3,
                "title": "Check persistence mechanisms",
                "description": "Look for persistence that may indicate pre-ransomware staging.",
                "action": "analyze",
                "action_config": {"analyzers": ["persistence"]},
                "expected_outcome": "Persistence mechanisms identified.",
            },
            {
                "order": 4,
                "title": "LLM deep analysis",
                "description": "Run deep LLM analysis for comprehensive ransomware assessment.",
                "action": "llm_analyze",
                "action_config": {"mode": "deep", "focus": "ransomware"},
                "expected_outcome": "AI-powered ransomware analysis with recommendations.",
            },
            {
                "order": 5,
                "title": "Create critical case",
                "description": "Immediately create a critical-severity case for the ransomware incident.",
                "action": "create_case",
                "action_config": {"severity": "critical", "tags": ["ransomware"]},
                "expected_outcome": "Critical case created for incident response.",
            },
        ],
    },
]
def get_builtin_playbooks() -> list[dict]:
    """Return the built-in playbook templates.

    Returns a shallow copy of the registry list so callers cannot
    accidentally mutate the module-level ``BUILT_IN_PLAYBOOKS`` by
    appending or removing entries. (The playbook dicts themselves are
    still shared — callers must not mutate them in place.)
    """
    return list(BUILT_IN_PLAYBOOKS)
def get_playbook_template(name: str) -> dict | None:
    """Get a specific built-in playbook by name.

    Returns None when no template matches *name* exactly.
    """
    matches = (pb for pb in BUILT_IN_PLAYBOOKS if pb["name"] == name)
    return next(matches, None)

View File

@@ -0,0 +1,447 @@
"""Process tree and storyline graph builder.
Extracts parent→child process relationships from dataset rows and builds
hierarchical trees. Also builds attack-storyline graphs connecting events
by host → process → network activity → file activity chains.
"""
import logging
from collections import defaultdict
from typing import Any, Sequence
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.db.models import Dataset, DatasetRow
logger = logging.getLogger(__name__)
# ── Helpers ───────────────────────────────────────────────────────────
_JUNK = frozenset({"", "N/A", "n/a", "-", "", "null", "None", "none", "unknown"})
def _clean(val: Any) -> str | None:
"""Return cleaned string or None for junk values."""
if val is None:
return None
s = str(val).strip()
return None if s in _JUNK else s
# ── Process Tree ──────────────────────────────────────────────────────
class ProcessNode:
    """A single process observed in the data, plus its spawned children."""
    __slots__ = (
        "pid", "ppid", "name", "command_line", "username", "hostname",
        "timestamp", "dataset_name", "row_index", "children", "extra",
    )
    def __init__(self, **kw: Any):
        # All string attributes default to "" when not supplied.
        for attr in ("pid", "ppid", "name", "command_line", "username",
                     "hostname", "timestamp", "dataset_name"):
            setattr(self, attr, kw.get(attr, ""))
        self.row_index = kw.get("row_index", -1)
        self.children = []
        self.extra = kw.get("extra", {})
    def to_dict(self) -> dict:
        """Serialize this node — and recursively its children — to a dict."""
        # First nine slots are the scalar attributes, in declaration order;
        # children/extra are appended last to keep the original key order.
        out = {attr: getattr(self, attr) for attr in self.__slots__[:9]}
        out["children"] = [child.to_dict() for child in self.children]
        out["extra"] = self.extra
        return out
async def build_process_tree(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    hostname_filter: str | None = None,
) -> list[dict]:
    """Build process trees from dataset rows.

    Groups rows by (hostname, pid), links children to parents via ppid,
    and returns a list of root-level process nodes (forest) as dicts.

    Args:
        db: Async database session.
        dataset_id: Restrict to a single dataset (takes precedence over hunt_id).
        hunt_id: Restrict to all datasets in a hunt.
        hostname_filter: Case-insensitive exact hostname match, if given.
    """
    rows = await _fetch_rows(db, dataset_id=dataset_id, hunt_id=hunt_id)
    if not rows:
        return []
    # Group processes by (hostname, pid) → node
    nodes_by_key: dict[tuple[str, str], ProcessNode] = {}
    nodes_list: list[ProcessNode] = []
    for row_obj in rows:
        # Guard against rows where both payloads are NULL.
        data = row_obj.normalized_data or row_obj.data or {}
        pid = _clean(data.get("pid"))
        if not pid:
            continue  # rows without a PID cannot participate in the tree
        ppid = _clean(data.get("ppid")) or ""
        hostname = _clean(data.get("hostname")) or "unknown"
        if hostname_filter and hostname.lower() != hostname_filter.lower():
            continue
        node = ProcessNode(
            pid=pid,
            ppid=ppid,
            name=_clean(data.get("process_name")) or _clean(data.get("name")) or "",
            command_line=_clean(data.get("command_line")) or "",
            username=_clean(data.get("username")) or "",
            hostname=hostname,
            timestamp=_clean(data.get("timestamp")) or "",
            dataset_name=row_obj.dataset.name if row_obj.dataset else "",
            row_index=row_obj.row_index,
            extra={
                k: str(v)
                for k, v in data.items()
                if k not in {"pid", "ppid", "process_name", "name", "command_line",
                             "username", "hostname", "timestamp"}
                and v is not None and str(v).strip() not in _JUNK
            },
        )
        key = (hostname, pid)
        # Keep only the first occurrence of each (host, pid); later duplicates
        # (e.g. repeated snapshots of the same process) are ignored.
        if key not in nodes_by_key:
            nodes_by_key[key] = node
            nodes_list.append(node)
    # Link parent → child; nodes whose parent is absent (or self) become roots.
    roots: list[ProcessNode] = []
    for node in nodes_list:
        parent_key = (node.hostname, node.ppid)
        parent = nodes_by_key.get(parent_key)
        if parent and parent is not node:
            parent.children.append(node)
        else:
            roots.append(node)
    return [r.to_dict() for r in roots]
# ── Storyline / Attack Graph ─────────────────────────────────────────
async def build_storyline(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    hostname_filter: str | None = None,
) -> dict:
    """Build a CrowdStrike-style storyline graph.

    Nodes represent events (process start, network connection, file write, etc.)
    Edges represent causal (ppid → pid "spawned") or temporal relationships.
    Returns a Cytoscape-compatible elements dict {nodes: [...], edges: [...]}
    plus a summary block.

    Args:
        db: Async database session.
        dataset_id: Restrict to a single dataset (takes precedence over hunt_id).
        hunt_id: Restrict to all datasets in a hunt.
        hostname_filter: Case-insensitive exact hostname match, if given.
    """
    rows = await _fetch_rows(db, dataset_id=dataset_id, hunt_id=hunt_id)
    if not rows:
        return {"nodes": [], "edges": [], "summary": {}}
    nodes: list[dict] = []
    edges: list[dict] = []
    seen_ids: set[str] = set()
    host_events: dict[str, list[dict]] = defaultdict(list)
    for row_obj in rows:
        # Guard against rows where both payloads are NULL.
        data = row_obj.normalized_data or row_obj.data or {}
        hostname = _clean(data.get("hostname")) or "unknown"
        if hostname_filter and hostname.lower() != hostname_filter.lower():
            continue
        event_type = _classify_event(data)
        node_id = f"{row_obj.dataset_id}_{row_obj.row_index}"
        if node_id in seen_ids:
            continue
        seen_ids.add(node_id)
        label = _build_label(data, event_type)
        severity = _estimate_severity(data, event_type)
        node = {
            "data": {
                "id": node_id,
                "label": label,
                "event_type": event_type,
                "hostname": hostname,
                "timestamp": _clean(data.get("timestamp")) or "",
                "pid": _clean(data.get("pid")) or "",
                "ppid": _clean(data.get("ppid")) or "",
                "process_name": _clean(data.get("process_name")) or "",
                "command_line": _clean(data.get("command_line")) or "",
                "username": _clean(data.get("username")) or "",
                "src_ip": _clean(data.get("src_ip")) or "",
                "dst_ip": _clean(data.get("dst_ip")) or "",
                "dst_port": _clean(data.get("dst_port")) or "",
                "file_path": _clean(data.get("file_path")) or "",
                "severity": severity,
                "dataset_id": row_obj.dataset_id,
                "row_index": row_obj.row_index,
            },
        }
        nodes.append(node)
        host_events[hostname].append(node["data"])
    # Build edges: parent→child (by pid/ppid) and temporal sequence per host.
    # edge_ids gives O(1) duplicate detection — the previous implementation
    # linearly scanned the whole edge list for every temporal edge (O(E²)).
    edge_ids: set[str] = set()
    pid_lookup: dict[tuple[str, str], str] = {}  # (host, pid) → node_id
    for node in nodes:
        d = node["data"]
        if d["pid"]:
            pid_lookup[(d["hostname"], d["pid"])] = d["id"]
    for node in nodes:
        d = node["data"]
        if d["ppid"]:
            parent_id = pid_lookup.get((d["hostname"], d["ppid"]))
            if parent_id and parent_id != d["id"]:
                edge_id = f"e_{parent_id}_{d['id']}"
                if edge_id not in edge_ids:
                    edge_ids.add(edge_id)
                    edges.append({
                        "data": {
                            "id": edge_id,
                            "source": parent_id,
                            "target": d["id"],
                            "relationship": "spawned",
                        }
                    })
    # Temporal edges within each host (sorted by timestamp string)
    for hostname, events in host_events.items():
        sorted_events = sorted(events, key=lambda e: e.get("timestamp", ""))
        for i in range(len(sorted_events) - 1):
            src = sorted_events[i]
            tgt = sorted_events[i + 1]
            edge_id = f"t_{src['id']}_{tgt['id']}"
            if edge_id not in edge_ids:
                edge_ids.add(edge_id)
                edges.append({
                    "data": {
                        "id": edge_id,
                        "source": src["id"],
                        "target": tgt["id"],
                        "relationship": "temporal",
                    }
                })
    # Summary stats
    type_counts: dict[str, int] = defaultdict(int)
    for n in nodes:
        type_counts[n["data"]["event_type"]] += 1
    summary = {
        "total_events": len(nodes),
        "total_edges": len(edges),
        "hosts": list(host_events.keys()),
        "event_types": dict(type_counts),
    }
    return {"nodes": nodes, "edges": edges, "summary": summary}
# ── Risk scoring for dashboard ────────────────────────────────────────
async def compute_risk_scores(
    db: AsyncSession,
    hunt_id: str | None = None,
) -> dict:
    """Compute per-host risk scores from anomaly signals in datasets.

    Scores are normalized to 0-100 relative to the riskiest host seen.
    Returns {hosts: [{hostname, score, signals, ...}], overall_score,
    total_events, severity_breakdown}.
    """
    rows = await _fetch_rows(db, hunt_id=hunt_id)
    if not rows:
        return {"hosts": [], "overall_score": 0, "total_events": 0,
                "severity_breakdown": {}}
    # (substring, score weight, signal label) — hoisted out of the per-row
    # loop so the list is built once instead of once per row.
    sus_patterns = [
        ("powershell -enc", 8, "Encoded PowerShell"),
        ("invoke-expression", 7, "PowerShell IEX"),
        ("invoke-webrequest", 6, "PowerShell WebRequest"),
        ("certutil -urlcache", 8, "Certutil download"),
        ("bitsadmin /transfer", 7, "BITS transfer"),
        ("regsvr32 /s /n /u", 8, "Regsvr32 squiblydoo"),
        ("mshta ", 7, "MSHTA execution"),
        ("wmic process", 6, "WMIC process enum"),
        ("net user", 5, "User enumeration"),
        ("whoami", 4, "Whoami recon"),
        ("mimikatz", 10, "Mimikatz detected"),
        ("procdump", 7, "Process dumping"),
        ("psexec", 7, "PsExec lateral movement"),
    ]
    host_signals: dict[str, dict] = defaultdict(
        lambda: {"hostname": "", "score": 0, "signals": [],
                 "event_count": 0, "process_count": 0,
                 "network_count": 0, "file_count": 0}
    )
    severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}
    for row_obj in rows:
        # Guard against rows where both payloads are NULL.
        data = row_obj.normalized_data or row_obj.data or {}
        hostname = _clean(data.get("hostname")) or "unknown"
        entry = host_signals[hostname]
        entry["hostname"] = hostname
        entry["event_count"] += 1
        event_type = _classify_event(data)
        severity = _estimate_severity(data, event_type)
        severity_counts[severity] = severity_counts.get(severity, 0) + 1
        # Count event types
        if event_type == "process":
            entry["process_count"] += 1
        elif event_type == "network":
            entry["network_count"] += 1
        elif event_type == "file":
            entry["file_count"] += 1
        # Risk signals from command line / process name
        cmd = (_clean(data.get("command_line")) or "").lower()
        proc = (_clean(data.get("process_name")) or "").lower()
        for pattern, score_add, signal_name in sus_patterns:
            if pattern in cmd or pattern in proc:
                entry["score"] += score_add
                if signal_name not in entry["signals"]:
                    entry["signals"].append(signal_name)
        # External connections score.
        # NOTE(review): "172." also matches public 172.x addresses outside the
        # private 172.16.0.0/12 range — presumably a deliberate rough
        # heuristic; confirm before tightening.
        dst_ip = _clean(data.get("dst_ip")) or ""
        if dst_ip and not dst_ip.startswith(("10.", "192.168.", "172.")):
            entry["score"] += 1
            if "External connections" not in entry["signals"]:
                entry["signals"].append("External connections")
    # Normalize scores (0-100) relative to the highest-scoring host.
    max_score = max((h["score"] for h in host_signals.values()), default=1)
    if max_score > 0:
        for entry in host_signals.values():
            entry["score"] = min(round((entry["score"] / max_score) * 100), 100)
    hosts = sorted(host_signals.values(), key=lambda h: h["score"], reverse=True)
    overall = round(sum(h["score"] for h in hosts) / max(len(hosts), 1))
    return {
        "hosts": hosts,
        "overall_score": overall,
        "total_events": sum(h["event_count"] for h in hosts),
        "severity_breakdown": severity_counts,
    }
# ── Internal helpers ──────────────────────────────────────────────────
async def _fetch_rows(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    limit: int = 50_000,
) -> Sequence[DatasetRow]:
    """Fetch dataset rows, optionally filtered by dataset or hunt.

    The dataset relationship is eagerly loaded; with no filter the hard
    `limit` still caps memory use.
    """
    query = (
        select(DatasetRow)
        .join(Dataset)
        .options(selectinload(DatasetRow.dataset))
    )
    if dataset_id:
        query = query.where(DatasetRow.dataset_id == dataset_id)
    elif hunt_id:
        query = query.where(Dataset.hunt_id == hunt_id)
    final_stmt = query.order_by(DatasetRow.row_index).limit(limit)
    result = await db.execute(final_stmt)
    return result.scalars().all()
def _classify_event(data: dict) -> str:
    """Classify a row as process / network / file / registry / other."""
    def present(key: str) -> bool:
        # _clean never returns the empty string, so non-None implies truthy.
        return _clean(data.get(key)) is not None

    if present("pid") or present("process_name"):
        # A process row with network/file detail is classified by that detail.
        if present("dst_ip") or present("dst_port"):
            return "network"
        if present("file_path"):
            return "file"
        return "process"
    if present("dst_ip") or present("src_ip"):
        return "network"
    if present("file_path"):
        return "file"
    if present("registry_key"):
        return "registry"
    return "other"
def _build_label(data: dict, event_type: str) -> str:
    """Build a concise node label for storyline display.

    Falls back to the event type name when no identifying fields survive
    cleaning.
    """
    name = _clean(data.get("process_name")) or ""
    pid = _clean(data.get("pid")) or ""
    dst = _clean(data.get("dst_ip")) or ""
    port = _clean(data.get("dst_port")) or ""
    fpath = _clean(data.get("file_path")) or ""
    if event_type == "process":
        return f"{name} (PID {pid})" if pid else name or "process"
    elif event_type == "network":
        target = f"{dst}:{port}" if dst and port else dst or port
        # NOTE(review): the separator between name and target below may be a
        # mis-encoded character (renders empty here) — verify display output.
        return f"{name}{target}" if name else target or "network"
    elif event_type == "file":
        # Basename of the path, handling both Windows and POSIX separators.
        fname = fpath.split("\\")[-1].split("/")[-1] if fpath else ""
        return f"{name}{fname}" if name else fname or "file"
    elif event_type == "registry":
        return _clean(data.get("registry_key")) or "registry"
    return name or "event"
def _estimate_severity(data: dict, event_type: str) -> str:
    """Rough heuristic severity estimate."""
    cmd = (_clean(data.get("command_line")) or "").lower()
    proc = (_clean(data.get("process_name")) or "").lower()
    # Critical tooling is matched against both command line and process name.
    critical_kw = ("mimikatz", "cobalt", "meterpreter", "empire", "bloodhound")
    if any(kw in cmd or kw in proc for kw in critical_kw):
        return "critical"
    # Remaining tiers match the command line only, highest severity first.
    tiers = (
        ("high", ("powershell -enc", "certutil -urlcache", "regsvr32", "mshta",
                  "bitsadmin", "psexec", "procdump")),
        ("medium", ("invoke-", "wmic", "net user", "net group", "schtasks",
                    "reg add", "sc create")),
        ("low", ("whoami", "ipconfig", "systeminfo", "tasklist", "netstat")),
    )
    for level, keywords in tiers:
        if any(kw in cmd for kw in keywords):
            return level
    return "info"

View File

@@ -0,0 +1,254 @@
"""Timeline and field-statistics service.
Provides temporal histogram bins and per-field distribution stats
for dataset rows — used by the TimelineScrubber and QueryBar components.
"""
import logging
from collections import Counter, defaultdict
from datetime import datetime
from typing import Any, Sequence
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import Dataset, DatasetRow
logger = logging.getLogger(__name__)
# ── Timeline bins ─────────────────────────────────────────────────────
async def build_timeline_bins(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    bins: int = 60,
) -> dict:
    """Create histogram bins of events over time.

    Returns {bins: [{start, end, count, events_by_type}], total, range}.

    Each event is assigned to exactly one bin in a single pass. (The
    previous implementation rescanned all events per bin — O(bins*events) —
    and could silently drop events falling between the last rounded bin
    boundary and ts_max, since timedelta division rounds to microseconds.)
    """
    rows = await _fetch_rows(db, dataset_id=dataset_id, hunt_id=hunt_id)
    if not rows:
        return {"bins": [], "total": 0, "range": None}
    # Extract parseable timestamps; rows without one are excluded.
    events: list[dict] = []
    for r in rows:
        data = r.normalized_data or r.data or {}
        ts_str = data.get("timestamp", "")
        if not ts_str:
            continue
        ts = _parse_ts(str(ts_str))
        if ts:
            events.append({
                "timestamp": ts,
                "event_type": _classify_type(data),
                "hostname": data.get("hostname", ""),
            })
    if not events:
        return {"bins": [], "total": len(rows), "range": None}
    events.sort(key=lambda e: e["timestamp"])
    ts_min = events[0]["timestamp"]
    ts_max = events[-1]["timestamp"]
    if ts_min == ts_max:
        # Degenerate range: everything lands in a single bin.
        return {
            "bins": [{"start": ts_min.isoformat(), "end": ts_max.isoformat(),
                      "count": len(events), "events_by_type": {}}],
            "total": len(events),
            "range": {"start": ts_min.isoformat(), "end": ts_max.isoformat()},
        }
    delta = (ts_max - ts_min) / bins
    span = (ts_max - ts_min).total_seconds()
    # Single pass: compute each event's bin index directly.
    binned: list[list[dict]] = [[] for _ in range(bins)]
    for e in events:
        pos = (e["timestamp"] - ts_min).total_seconds() / span  # 0.0 .. 1.0
        idx = min(int(pos * bins), bins - 1)  # clamp so ts_max lands in the last bin
        binned[idx].append(e)
    result_bins: list[dict] = []
    for i, bin_events in enumerate(binned):
        type_counts = Counter(e["event_type"] for e in bin_events)
        result_bins.append({
            "start": (ts_min + delta * i).isoformat(),
            "end": (ts_min + delta * (i + 1)).isoformat(),
            "count": len(bin_events),
            "events_by_type": dict(type_counts),
        })
    return {
        "bins": result_bins,
        "total": len(events),
        "range": {"start": ts_min.isoformat(), "end": ts_max.isoformat()},
    }
# ── Field stats ───────────────────────────────────────────────────────
async def compute_field_stats(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    fields: list[str] | None = None,
    top_n: int = 20,
) -> dict:
    """Compute per-field value distributions.

    Returns {fields: {field_name: {total, unique, top: [{value, count}]}},
    total_rows, available_fields}.

    Improvements over the naive version: field discovery scans every row
    (previously only row 0, which missed keys absent from the first row),
    and counting is a single pass over the rows instead of one full pass
    per field.
    """
    rows = await _fetch_rows(db, dataset_id=dataset_id, hunt_id=hunt_id)
    if not rows:
        return {"fields": {}, "total_rows": 0}
    # Discover available fields across all rows, preserving first-seen order.
    all_fields: list[str] = []
    seen_fields: set[str] = set()
    for r in rows:
        data = r.normalized_data or r.data or {}
        for key in data:
            if key not in seen_fields:
                seen_fields.add(key)
                all_fields.append(key)
    target_fields = fields if fields else all_fields[:30]
    target_set = set(target_fields)
    # Single pass: tally values for every targeted field at once.
    junk = ("", "N/A", "n/a", "-", "None")
    counters: dict[str, Counter] = defaultdict(Counter)
    for r in rows:
        data = r.normalized_data or r.data or {}
        for key, v in data.items():
            if key in target_set and v is not None and str(v).strip() not in junk:
                counters[key][str(v)] += 1
    stats: dict[str, dict] = {}
    for field_name in target_fields:
        counter = counters.get(field_name, Counter())
        stats[field_name] = {
            "total": sum(counter.values()),
            "unique": len(counter),
            "top": [{"value": v, "count": c} for v, c in counter.most_common(top_n)],
        }
    return {
        "fields": stats,
        "total_rows": len(rows),
        "available_fields": all_fields,
    }
# ── Row search with filters ──────────────────────────────────────────
async def search_rows(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    query: str = "",
    filters: dict[str, str] | None = None,
    time_start: str | None = None,
    time_end: str | None = None,
    limit: int = 500,
    offset: int = 0,
) -> dict:
    """Search/filter dataset rows.

    Supports:
      - Free-text search across all fields (case-insensitive substring)
      - Field-specific substring filters {field: value}
      - Time range filters; rows with a missing or unparseable timestamp
        pass the time filter (best-effort behavior, preserved deliberately)
    """
    rows = await _fetch_rows(db, dataset_id=dataset_id, hunt_id=hunt_id, limit=50000)
    if not rows:
        return {"rows": [], "total": 0, "offset": offset, "limit": limit}
    ts_start = _parse_ts(time_start) if time_start else None
    ts_end = _parse_ts(time_end) if time_end else None
    # Hoist case-folding out of the per-row loop (previously recomputed
    # per row and per filter field).
    q = query.lower() if query else ""
    lowered_filters = {f: v.lower() for f, v in filters.items()} if filters else {}
    results: list[dict] = []
    for r in rows:
        data = r.normalized_data or r.data or {}
        # Time filter
        if ts_start or ts_end:
            ts = _parse_ts(str(data.get("timestamp", "")))
            if ts:
                if ts_start and ts < ts_start:
                    continue
                if ts_end and ts > ts_end:
                    continue
        # Field filters: every filter must match as a substring.
        if lowered_filters:
            matched = all(
                needle in str(data.get(field_name, "")).lower()
                for field_name, needle in lowered_filters.items()
            )
            if not matched:
                continue
        # Free-text search across all values
        if q and not any(q in str(v).lower() for v in data.values()):
            continue
        results.append(data)
    total = len(results)
    paged = results[offset:offset + limit]
    return {"rows": paged, "total": total, "offset": offset, "limit": limit}
# ── Internal helpers ──────────────────────────────────────────────────
async def _fetch_rows(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    limit: int = 50_000,
) -> Sequence[DatasetRow]:
    """Load dataset rows, optionally scoped to one dataset or one hunt."""
    query = select(DatasetRow).join(Dataset)
    if dataset_id:
        query = query.where(DatasetRow.dataset_id == dataset_id)
    elif hunt_id:
        query = query.where(Dataset.hunt_id == hunt_id)
    ordered = query.order_by(DatasetRow.row_index).limit(limit)
    result = await db.execute(ordered)
    return result.scalars().all()
def _parse_ts(ts_str: str | None) -> datetime | None:
"""Best-effort timestamp parsing."""
if not ts_str:
return None
for fmt in (
"%Y-%m-%dT%H:%M:%S.%fZ",
"%Y-%m-%dT%H:%M:%SZ",
"%Y-%m-%dT%H:%M:%S.%f",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%d %H:%M:%S.%f",
"%Y-%m-%d %H:%M:%S",
"%m/%d/%Y %H:%M:%S",
"%m/%d/%Y %I:%M:%S %p",
):
try:
return datetime.strptime(ts_str.strip(), fmt)
except (ValueError, AttributeError):
continue
return None
def _classify_type(data: dict) -> str:
if data.get("pid") or data.get("process_name"):
if data.get("dst_ip") or data.get("dst_port"):
return "network"
return "process"
if data.get("dst_ip") or data.get("src_ip"):
return "network"
if data.get("file_path"):
return "file"
return "other"