mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 14:00:20 -05:00
version 0.4.0
This commit is contained in:
404
backend/app/api/routes/alerts.py
Normal file
404
backend/app/api/routes/alerts.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""API routes for alerts — CRUD, analyze triggers, and alert rules."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Alert, AlertRule, _new_id, _utcnow
|
||||
from app.db.repositories.datasets import DatasetRepository
|
||||
from app.services.analyzers import (
|
||||
get_available_analyzers,
|
||||
get_analyzer,
|
||||
run_all_analyzers,
|
||||
AlertCandidate,
|
||||
)
|
||||
from app.services.process_tree import _fetch_rows
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/alerts", tags=["alerts"])
|
||||
|
||||
|
||||
# ── Pydantic models ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class AlertUpdate(BaseModel):
    """Partial-update payload for an alert; unset fields are left untouched."""

    status: str | None = None
    severity: str | None = None
    assignee: str | None = None
    case_id: str | None = None
    tags: list[str] | None = None
|
||||
|
||||
|
||||
class RuleCreate(BaseModel):
    """Creation payload for an alert rule bound to a specific analyzer."""

    name: str
    description: str | None = None
    analyzer: str
    config: dict | None = None
    severity_override: str | None = None
    enabled: bool = True
    hunt_id: str | None = None
|
||||
|
||||
|
||||
class RuleUpdate(BaseModel):
    """Partial-update payload for an alert rule; unset fields are untouched."""

    name: str | None = None
    description: str | None = None
    config: dict | None = None
    severity_override: str | None = None
    enabled: bool | None = None
|
||||
|
||||
|
||||
class AnalyzeRequest(BaseModel):
    """Request body for the /analyze endpoint."""

    dataset_id: str | None = None
    hunt_id: str | None = None
    analyzers: list[str] | None = None  # None = run all
    config: dict | None = None
    auto_create: bool = True  # automatically persist alerts
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _alert_to_dict(a: Alert) -> dict:
    """Serialize an Alert ORM row into a JSON-safe dict.

    Datetimes become ISO-8601 strings (or None); list columns default to [].
    """

    def _iso(dt):
        # None-safe ISO formatting for the four timestamp columns.
        return dt.isoformat() if dt else None

    return {
        "id": a.id,
        "title": a.title,
        "description": a.description,
        "severity": a.severity,
        "status": a.status,
        "analyzer": a.analyzer,
        "score": a.score,
        "evidence": a.evidence or [],
        "mitre_technique": a.mitre_technique,
        "tags": a.tags or [],
        "hunt_id": a.hunt_id,
        "dataset_id": a.dataset_id,
        "case_id": a.case_id,
        "assignee": a.assignee,
        "acknowledged_at": _iso(a.acknowledged_at),
        "resolved_at": _iso(a.resolved_at),
        "created_at": _iso(a.created_at),
        "updated_at": _iso(a.updated_at),
    }
|
||||
|
||||
|
||||
def _rule_to_dict(r: AlertRule) -> dict:
    """Serialize an AlertRule ORM row into a JSON-safe dict."""

    def _iso(dt):
        # None-safe ISO formatting for timestamps.
        return dt.isoformat() if dt else None

    return {
        "id": r.id,
        "name": r.name,
        "description": r.description,
        "analyzer": r.analyzer,
        "config": r.config,
        "severity_override": r.severity_override,
        "enabled": r.enabled,
        "hunt_id": r.hunt_id,
        "created_at": _iso(r.created_at),
        "updated_at": _iso(r.updated_at),
    }
|
||||
|
||||
|
||||
# ── Alert CRUD ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("", summary="List alerts")
async def list_alerts(
    status: str | None = Query(None),
    severity: str | None = Query(None),
    analyzer: str | None = Query(None),
    hunt_id: str | None = Query(None),
    dataset_id: str | None = Query(None),
    limit: int = Query(100, ge=1, le=500),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List alerts with optional filters, ordered by score then recency.

    Returns ``{"alerts": [...], "total": <count matching filters>}`` where
    ``total`` ignores pagination.
    """
    # Build the filter conditions exactly once and apply them to both the
    # page query and the count query — the original repeated each WHERE
    # clause twice, which is drift-prone when a filter is added or changed.
    conditions = [
        column == value
        for column, value in (
            (Alert.status, status),
            (Alert.severity, severity),
            (Alert.analyzer, analyzer),
            (Alert.hunt_id, hunt_id),
            (Alert.dataset_id, dataset_id),
        )
        if value
    ]

    total = (
        await db.execute(select(func.count(Alert.id)).where(*conditions))
    ).scalar() or 0

    page = (
        await db.execute(
            select(Alert)
            .where(*conditions)
            .order_by(desc(Alert.score), desc(Alert.created_at))
            .offset(offset)
            .limit(limit)
        )
    ).scalars().all()

    return {"alerts": [_alert_to_dict(a) for a in page], "total": total}
|
||||
|
||||
|
||||
@router.get("/stats", summary="Alert statistics dashboard")
async def alert_stats(
    hunt_id: str | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Return aggregated alert statistics.

    Includes severity/status/analyzer breakdowns and the top-10 MITRE
    techniques, all optionally scoped to one hunt. (The original built an
    unused base ``select(Alert)`` and repeated the hunt filter + group-by
    pattern four times; both are consolidated here.)
    """

    def _scoped(stmt):
        # Apply the optional hunt filter uniformly to every aggregate query.
        return stmt.where(Alert.hunt_id == hunt_id) if hunt_id else stmt

    async def _grouped(column) -> dict:
        # Generic "count alerts grouped by <column>" aggregate.
        stmt = _scoped(select(column, func.count(Alert.id)).group_by(column))
        return {key: count for key, count in (await db.execute(stmt)).all()}

    severity_counts = await _grouped(Alert.severity)
    status_counts = await _grouped(Alert.status)
    analyzer_counts = await _grouped(Alert.analyzer)

    # Top 10 MITRE techniques; rows with no technique are excluded.
    mitre_stmt = _scoped(
        select(Alert.mitre_technique, func.count(Alert.id))
        .where(Alert.mitre_technique.isnot(None))
        .group_by(Alert.mitre_technique)
        .order_by(desc(func.count(Alert.id)))
        .limit(10)
    )
    top_mitre = [
        {"technique": technique, "count": count}
        for technique, count in (await db.execute(mitre_stmt)).all()
    ]

    return {
        # Every alert has a severity, so the severity breakdown sums to the total.
        "total": sum(severity_counts.values()),
        "severity_counts": severity_counts,
        "status_counts": status_counts,
        "analyzer_counts": analyzer_counts,
        "top_mitre": top_mitre,
    }
|
||||
|
||||
|
||||
@router.get("/{alert_id}", summary="Get alert detail")
async def get_alert(alert_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one alert by primary key; 404 if it does not exist."""
    alert = await db.get(Alert, alert_id)
    if alert is None:
        raise HTTPException(status_code=404, detail="Alert not found")
    return _alert_to_dict(alert)
|
||||
|
||||
|
||||
@router.put("/{alert_id}", summary="Update alert (status, assignee, etc.)")
async def update_alert(
    alert_id: str, body: AlertUpdate, db: AsyncSession = Depends(get_db)
):
    """Apply a partial update to an alert.

    The first transition into "acknowledged" stamps acknowledged_at; the
    first transition into "resolved"/"false-positive" stamps resolved_at.
    """
    alert = await db.get(Alert, alert_id)
    if alert is None:
        raise HTTPException(status_code=404, detail="Alert not found")

    if body.status is not None:
        alert.status = body.status
        # Only stamp the lifecycle timestamps on the first transition.
        if body.status == "acknowledged" and not alert.acknowledged_at:
            alert.acknowledged_at = _utcnow()
        if body.status in ("resolved", "false-positive") and not alert.resolved_at:
            alert.resolved_at = _utcnow()

    # Copy over the remaining simple fields when they were supplied.
    for field_name in ("severity", "assignee", "case_id", "tags"):
        value = getattr(body, field_name)
        if value is not None:
            setattr(alert, field_name, value)

    await db.commit()
    await db.refresh(alert)
    return _alert_to_dict(alert)
|
||||
|
||||
|
||||
@router.delete("/{alert_id}", summary="Delete alert")
async def delete_alert(alert_id: str, db: AsyncSession = Depends(get_db)):
    """Permanently remove an alert; 404 if it does not exist."""
    alert = await db.get(Alert, alert_id)
    if alert is None:
        raise HTTPException(status_code=404, detail="Alert not found")
    await db.delete(alert)
    await db.commit()
    return {"ok": True}
|
||||
|
||||
|
||||
# ── Bulk operations ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/bulk-update", summary="Bulk update alert statuses")
async def bulk_update_alerts(
    alert_ids: list[str],
    status: str = Query(...),
    db: AsyncSession = Depends(get_db),
):
    """Set the same status on many alerts in one commit.

    IDs that do not resolve to an alert are skipped silently; the response
    reports how many alerts were actually updated.
    """
    changed = 0
    for alert_id in alert_ids:
        alert = await db.get(Alert, alert_id)
        if alert is None:
            continue
        alert.status = status
        # Same first-transition timestamp semantics as the single-alert update.
        if status == "acknowledged" and not alert.acknowledged_at:
            alert.acknowledged_at = _utcnow()
        if status in ("resolved", "false-positive") and not alert.resolved_at:
            alert.resolved_at = _utcnow()
        changed += 1
    await db.commit()
    return {"updated": changed}
|
||||
|
||||
|
||||
# ── Run Analyzers ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/analyzers/list", summary="List available analyzers")
async def list_analyzers():
    """Return metadata for every registered analyzer."""
    return {"analyzers": get_available_analyzers()}
|
||||
|
||||
|
||||
@router.post("/analyze", summary="Run analyzers on a dataset/hunt and optionally create alerts")
async def run_analysis(
    request: AnalyzeRequest, db: AsyncSession = Depends(get_db)
):
    """Run the requested analyzers over dataset/hunt rows.

    When ``request.auto_create`` is true, every candidate is persisted as
    an Alert row. The response reports candidate/alert counts plus a
    by-severity / by-analyzer summary.
    """
    if not (request.dataset_id or request.hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")

    # Pull up to 10k rows for the selected scope.
    row_records = await _fetch_rows(
        db, dataset_id=request.dataset_id, hunt_id=request.hunt_id, limit=10000,
    )
    if not row_records:
        raise HTTPException(status_code=404, detail="No rows found")

    # Prefer the normalized form of each row when present.
    rows = [rec.normalized_data or rec.data for rec in row_records]

    candidates = await run_all_analyzers(
        rows, enabled=request.analyzers, config=request.config
    )

    created: list[dict] = []
    if request.auto_create and candidates:
        for cand in candidates:
            record = Alert(
                id=_new_id(),
                title=cand.title,
                description=cand.description,
                severity=cand.severity,
                analyzer=cand.analyzer,
                score=cand.score,
                evidence=cand.evidence,
                mitre_technique=cand.mitre_technique,
                tags=cand.tags,
                hunt_id=request.hunt_id,
                dataset_id=request.dataset_id,
            )
            db.add(record)
            created.append(_alert_to_dict(record))
        await db.commit()

    return {
        "candidates_found": len(candidates),
        "alerts_created": len(created),
        "alerts": created,
        "summary": {
            "by_severity": _count_by(candidates, "severity"),
            "by_analyzer": _count_by(candidates, "analyzer"),
            "rows_analyzed": len(rows),
        },
    }
|
||||
|
||||
|
||||
def _count_by(items: list[AlertCandidate], attr: str) -> dict[str, int]:
    """Tally candidates by the value of *attr* ("unknown" when missing)."""
    tally: dict[str, int] = {}
    for candidate in items:
        bucket = getattr(candidate, attr, "unknown")
        tally[bucket] = tally.get(bucket, 0) + 1
    return tally
|
||||
|
||||
|
||||
# ── Alert Rules CRUD ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/rules/list", summary="List alert rules")
async def list_rules(
    enabled: bool | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """List alert rules in creation order, optionally filtered by enabled flag."""
    query = select(AlertRule)
    if enabled is not None:
        query = query.where(AlertRule.enabled == enabled)
    rules = (
        await db.execute(query.order_by(AlertRule.created_at))
    ).scalars().all()
    return {"rules": [_rule_to_dict(rule) for rule in rules]}
|
||||
|
||||
|
||||
@router.post("/rules", summary="Create alert rule")
async def create_rule(body: RuleCreate, db: AsyncSession = Depends(get_db)):
    """Create an alert rule; 400 if the named analyzer is not registered."""
    if not get_analyzer(body.analyzer):
        raise HTTPException(status_code=400, detail=f"Unknown analyzer: {body.analyzer}")

    new_rule = AlertRule(
        id=_new_id(),
        name=body.name,
        description=body.description,
        analyzer=body.analyzer,
        config=body.config,
        severity_override=body.severity_override,
        enabled=body.enabled,
        hunt_id=body.hunt_id,
    )
    db.add(new_rule)
    await db.commit()
    await db.refresh(new_rule)
    return _rule_to_dict(new_rule)
|
||||
|
||||
|
||||
@router.put("/rules/{rule_id}", summary="Update alert rule")
async def update_rule(
    rule_id: str, body: RuleUpdate, db: AsyncSession = Depends(get_db)
):
    """Apply a partial update to a rule; fields left unset are untouched."""
    rule = await db.get(AlertRule, rule_id)
    if rule is None:
        raise HTTPException(status_code=404, detail="Rule not found")

    # Copy over only the fields the caller actually supplied.
    for field_name in ("name", "description", "config", "severity_override", "enabled"):
        value = getattr(body, field_name)
        if value is not None:
            setattr(rule, field_name, value)

    await db.commit()
    await db.refresh(rule)
    return _rule_to_dict(rule)
|
||||
|
||||
|
||||
@router.delete("/rules/{rule_id}", summary="Delete alert rule")
async def delete_rule(rule_id: str, db: AsyncSession = Depends(get_db)):
    """Permanently remove an alert rule; 404 if it does not exist."""
    rule = await db.get(AlertRule, rule_id)
    if rule is None:
        raise HTTPException(status_code=404, detail="Rule not found")
    await db.delete(rule)
    await db.commit()
    return {"ok": True}
|
||||
295
backend/app/api/routes/analysis.py
Normal file
295
backend/app/api/routes/analysis.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""API routes for process trees, storyline graphs, risk scoring, LLM analysis, timeline, and field stats."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.repositories.datasets import DatasetRepository
|
||||
from app.services.process_tree import (
|
||||
build_process_tree,
|
||||
build_storyline,
|
||||
compute_risk_scores,
|
||||
_fetch_rows,
|
||||
)
|
||||
from app.services.llm_analysis import (
|
||||
AnalysisRequest,
|
||||
AnalysisResult,
|
||||
run_llm_analysis,
|
||||
)
|
||||
from app.services.timeline import (
|
||||
build_timeline_bins,
|
||||
compute_field_stats,
|
||||
search_rows,
|
||||
)
|
||||
from app.services.mitre import (
|
||||
map_to_attack,
|
||||
build_knowledge_graph,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/analysis", tags=["analysis"])
|
||||
|
||||
|
||||
# ── Response models ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ProcessTreeResponse(BaseModel):
    """Response payload for /process-tree: a forest of process nodes."""

    # Each tree is a nested dict; children live under a "children" key.
    trees: list[dict] = Field(default_factory=list)
    # Total number of process nodes across all trees.
    total_processes: int = 0
|
||||
|
||||
|
||||
class StorylineResponse(BaseModel):
    """Response payload for /storyline: a Cytoscape-compatible graph."""

    # Graph nodes and edges as plain dicts (shape defined by build_storyline).
    nodes: list[dict] = Field(default_factory=list)
    edges: list[dict] = Field(default_factory=list)
    # Free-form summary statistics about the graph.
    summary: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class RiskHostEntry(BaseModel):
    """Per-host risk score with the signals and event counts behind it."""

    hostname: str
    # Aggregate risk score for this host (computed by compute_risk_scores).
    score: int = 0
    # Human-readable names of the suspicious signals that contributed.
    signals: list[str] = Field(default_factory=list)
    event_count: int = 0
    process_count: int = 0
    network_count: int = 0
    file_count: int = 0
|
||||
|
||||
|
||||
class RiskSummaryResponse(BaseModel):
    """Response payload for /risk-summary: per-host scores plus hunt-wide totals."""

    hosts: list[RiskHostEntry] = Field(default_factory=list)
    # Overall score across all hosts (semantics defined by compute_risk_scores).
    overall_score: int = 0
    total_events: int = 0
    # Count of events per severity label.
    severity_breakdown: dict[str, int] = Field(default_factory=dict)
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
    "/process-tree",
    response_model=ProcessTreeResponse,
    summary="Build process tree from dataset rows",
    description=(
        "Extracts parent→child process relationships from dataset rows "
        "and returns a hierarchical forest of process nodes."
    ),
)
async def get_process_tree(
    dataset_id: str | None = Query(None, description="Dataset ID"),
    hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
    hostname: str | None = Query(None, description="Filter by hostname"),
    db: AsyncSession = Depends(get_db),
):
    """Return process tree(s) for a dataset or hunt."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")

    trees = await build_process_tree(
        db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
    )

    # Count every node in the forest iteratively (no recursion-depth limit).
    total = 0
    pending = list(trees)
    while pending:
        node = pending.pop()
        total += 1
        pending.extend(node.get("children", []))

    return ProcessTreeResponse(trees=trees, total_processes=total)
|
||||
|
||||
|
||||
@router.get(
    "/storyline",
    response_model=StorylineResponse,
    summary="Build CrowdStrike-style storyline attack graph",
    description=(
        "Creates a Cytoscape-compatible graph of events connected by "
        "process lineage (spawned) and temporal sequence within each host."
    ),
)
async def get_storyline(
    dataset_id: str | None = Query(None, description="Dataset ID"),
    hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
    hostname: str | None = Query(None, description="Filter by hostname"),
    db: AsyncSession = Depends(get_db),
):
    """Return a storyline graph for a dataset or hunt."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")

    graph = await build_storyline(
        db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
    )
    return StorylineResponse(**graph)
|
||||
|
||||
|
||||
@router.get(
    "/risk-summary",
    response_model=RiskSummaryResponse,
    summary="Compute risk scores per host",
    description=(
        "Analyzes dataset rows for suspicious patterns (encoded PowerShell, "
        "credential dumping, lateral movement) and produces per-host risk scores."
    ),
)
async def get_risk_summary(
    hunt_id: str | None = Query(None, description="Hunt ID"),
    db: AsyncSession = Depends(get_db),
):
    """Return risk scores for all hosts in a hunt."""
    summary = await compute_risk_scores(db, hunt_id=hunt_id)
    return RiskSummaryResponse(**summary)
|
||||
|
||||
|
||||
# ── LLM Analysis ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/llm-analyze",
    response_model=AnalysisResult,
    summary="Run LLM-powered threat analysis on dataset",
    description=(
        "Loads dataset rows server-side, builds a summary, and sends to "
        "Wile (deep analysis) or Roadrunner (quick) for comprehensive "
        "threat analysis. Returns structured findings, IOCs, MITRE techniques."
    ),
)
async def llm_analyze(
    request: AnalysisRequest,
    db: AsyncSession = Depends(get_db),
):
    """Run LLM analysis on a dataset or hunt."""
    if not (request.dataset_id or request.hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")

    # Cap the context fed to the model at 2k rows.
    row_records = await _fetch_rows(
        db,
        dataset_id=request.dataset_id,
        hunt_id=request.hunt_id,
        limit=2000,
    )
    if not row_records:
        raise HTTPException(status_code=404, detail="No rows found for analysis")

    # Prefer the normalized form of each row when present.
    rows = [rec.normalized_data or rec.data for rec in row_records]

    # Resolve a friendly dataset label for the prompt; hunt scope gets a
    # generic one.
    ds_name = "hunt datasets"
    if request.dataset_id:
        dataset = await DatasetRepository(db).get_dataset(request.dataset_id)
        if dataset:
            ds_name = dataset.name

    return await run_llm_analysis(rows, request, dataset_name=ds_name)
|
||||
|
||||
|
||||
# ── Timeline ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
    "/timeline",
    summary="Get event timeline histogram bins",
)
async def get_timeline(
    dataset_id: str | None = Query(None),
    hunt_id: str | None = Query(None),
    bins: int = Query(60, ge=10, le=200),
    db: AsyncSession = Depends(get_db),
):
    """Return event-count histogram bins for a dataset or hunt."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    return await build_timeline_bins(db, dataset_id=dataset_id, hunt_id=hunt_id, bins=bins)
|
||||
|
||||
|
||||
@router.get(
    "/field-stats",
    summary="Get per-field value distributions",
)
async def get_field_stats(
    dataset_id: str | None = Query(None),
    hunt_id: str | None = Query(None),
    fields: str | None = Query(None, description="Comma-separated field names"),
    top_n: int = Query(20, ge=5, le=100),
    db: AsyncSession = Depends(get_db),
):
    """Return value distributions per field for a dataset or hunt.

    ``fields`` is a comma-separated list; when omitted (or when it contains
    only separators/whitespace) all fields are analyzed.
    """
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")

    # Drop empty entries so "a,,b" or a trailing comma never sends an
    # empty-string field name downstream (the original kept them).
    field_list = None
    if fields:
        field_list = [name.strip() for name in fields.split(",") if name.strip()] or None

    return await compute_field_stats(
        db, dataset_id=dataset_id, hunt_id=hunt_id,
        fields=field_list, top_n=top_n,
    )
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
    """Request body for /search: free-text query plus structured filters."""

    dataset_id: str | None = None
    hunt_id: str | None = None
    query: str = ""
    filters: dict[str, str] = Field(default_factory=dict)
    time_start: str | None = None
    time_end: str | None = None
    limit: int = 500
    offset: int = 0
|
||||
|
||||
|
||||
@router.post(
    "/search",
    summary="Search and filter dataset rows",
)
async def search_dataset_rows(
    request: SearchRequest,
    db: AsyncSession = Depends(get_db),
):
    """Search rows of a dataset or hunt; delegates entirely to search_rows."""
    if not (request.dataset_id or request.hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    return await search_rows(
        db,
        dataset_id=request.dataset_id,
        hunt_id=request.hunt_id,
        query=request.query,
        filters=request.filters,
        time_start=request.time_start,
        time_end=request.time_end,
        limit=request.limit,
        offset=request.offset,
    )
|
||||
|
||||
|
||||
# ── MITRE ATT&CK ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
    "/mitre-map",
    summary="Map dataset events to MITRE ATT&CK techniques",
)
async def get_mitre_map(
    dataset_id: str | None = Query(None),
    hunt_id: str | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Return the MITRE ATT&CK technique mapping for a dataset or hunt."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    return await map_to_attack(db, dataset_id=dataset_id, hunt_id=hunt_id)
|
||||
|
||||
|
||||
@router.get(
    "/knowledge-graph",
    summary="Build entity-technique knowledge graph",
)
async def get_knowledge_graph(
    dataset_id: str | None = Query(None),
    hunt_id: str | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Return the entity/technique knowledge graph for a dataset or hunt."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    return await build_knowledge_graph(db, dataset_id=dataset_id, hunt_id=hunt_id)
|
||||
296
backend/app/api/routes/cases.py
Normal file
296
backend/app/api/routes/cases.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""API routes for case management — CRUD for cases, tasks, and activity logs."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Case, CaseTask, ActivityLog, _new_id, _utcnow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/cases", tags=["cases"])
|
||||
|
||||
|
||||
# ── Pydantic models ──────────────────────────────────────────────────
|
||||
|
||||
class CaseCreate(BaseModel):
    """Creation payload for a case (TLP/PAP default to amber, priority 2)."""

    title: str
    description: str | None = None
    severity: str = "medium"
    tlp: str = "amber"
    pap: str = "amber"
    priority: int = 2
    assignee: str | None = None
    tags: list[str] | None = None
    hunt_id: str | None = None
    mitre_techniques: list[str] | None = None
    iocs: list[dict] | None = None
|
||||
|
||||
|
||||
class CaseUpdate(BaseModel):
    """Partial-update payload for a case; unset fields are left untouched."""

    title: str | None = None
    description: str | None = None
    severity: str | None = None
    tlp: str | None = None
    pap: str | None = None
    status: str | None = None
    priority: int | None = None
    assignee: str | None = None
    tags: list[str] | None = None
    mitre_techniques: list[str] | None = None
    iocs: list[dict] | None = None
|
||||
|
||||
|
||||
class TaskCreate(BaseModel):
    """Creation payload for a task attached to a case."""

    title: str
    description: str | None = None
    assignee: str | None = None
|
||||
|
||||
|
||||
class TaskUpdate(BaseModel):
    """Partial-update payload for a case task; unset fields are untouched."""

    title: str | None = None
    description: str | None = None
    status: str | None = None
    assignee: str | None = None
    order: int | None = None
|
||||
|
||||
|
||||
# ── Helper: log activity ─────────────────────────────────────────────
|
||||
|
||||
async def _log_activity(
    db: AsyncSession,
    entity_type: str,
    entity_id: str,
    action: str,
    details: dict | None = None,
):
    """Queue an ActivityLog entry on the session; the caller commits."""
    db.add(
        ActivityLog(
            entity_type=entity_type,
            entity_id=entity_id,
            action=action,
            details=details,
            created_at=_utcnow(),
        )
    )
|
||||
|
||||
|
||||
# ── Case CRUD ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("", summary="Create a case")
async def create_case(body: CaseCreate, db: AsyncSession = Depends(get_db)):
    """Create a case and record a 'created' activity entry in one commit."""
    now = _utcnow()
    new_case = Case(
        id=_new_id(),
        title=body.title,
        description=body.description,
        severity=body.severity,
        tlp=body.tlp,
        pap=body.pap,
        priority=body.priority,
        assignee=body.assignee,
        tags=body.tags,
        hunt_id=body.hunt_id,
        mitre_techniques=body.mitre_techniques,
        iocs=body.iocs,
        created_at=now,
        updated_at=now,
    )
    db.add(new_case)
    # Log creation in the same transaction as the case itself.
    await _log_activity(db, "case", new_case.id, "created", {"title": body.title})
    await db.commit()
    await db.refresh(new_case)
    return _case_to_dict(new_case)
|
||||
|
||||
|
||||
@router.get("", summary="List cases")
async def list_cases(
    status: Optional[str] = Query(None),
    hunt_id: Optional[str] = Query(None),
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List cases ordered by most recent update, with optional filters.

    Returns ``{"cases": [...], "total": <count matching filters>}`` where
    ``total`` ignores pagination.
    """
    # Build the filter conditions once and apply them to both the page
    # query and the count query — the original repeated each WHERE clause,
    # which can drift when a filter is added or changed.
    conditions = []
    if status:
        conditions.append(Case.status == status)
    if hunt_id:
        conditions.append(Case.hunt_id == hunt_id)

    page_q = (
        select(Case)
        .where(*conditions)
        .order_by(desc(Case.updated_at))
        .offset(offset)
        .limit(limit)
    )
    cases = (await db.execute(page_q)).scalars().all()

    total = (
        await db.execute(select(func.count(Case.id)).where(*conditions))
    ).scalar() or 0

    return {"cases": [_case_to_dict(c) for c in cases], "total": total}
|
||||
|
||||
|
||||
@router.get("/{case_id}", summary="Get case detail")
async def get_case(case_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one case by primary key; 404 if it does not exist."""
    found = await db.get(Case, case_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Case not found")
    return _case_to_dict(found)
|
||||
|
||||
|
||||
@router.put("/{case_id}", summary="Update a case")
async def update_case(case_id: str, body: CaseUpdate, db: AsyncSession = Depends(get_db)):
    """Apply a partial update, logging old/new values to the activity log.

    Moving to "in-progress" stamps started_at once; moving to
    "resolved"/"closed" stamps resolved_at.
    """
    case = await db.get(Case, case_id)
    if case is None:
        raise HTTPException(status_code=404, detail="Case not found")

    mutable_fields = (
        "title", "description", "severity", "tlp", "pap", "status",
        "priority", "assignee", "tags", "mitre_techniques", "iocs",
    )
    changes: dict = {}
    for name in mutable_fields:
        new_value = getattr(body, name)
        if new_value is None:
            continue
        # Record before/after for the audit trail, then apply.
        changes[name] = {"old": getattr(case, name), "new": new_value}
        setattr(case, name, new_value)

    new_status = changes.get("status", {}).get("new")
    if new_status == "in-progress" and not case.started_at:
        case.started_at = _utcnow()
    if new_status in ("resolved", "closed"):
        case.resolved_at = _utcnow()

    case.updated_at = _utcnow()
    await _log_activity(db, "case", case.id, "updated", changes)
    await db.commit()
    await db.refresh(case)
    return _case_to_dict(case)
|
||||
|
||||
|
||||
@router.delete("/{case_id}", summary="Delete a case")
async def delete_case(case_id: str, db: AsyncSession = Depends(get_db)):
    """Permanently remove a case; 404 if it does not exist."""
    case = await db.get(Case, case_id)
    if case is None:
        raise HTTPException(status_code=404, detail="Case not found")
    await db.delete(case)
    await db.commit()
    return {"deleted": True}
|
||||
|
||||
|
||||
# ── Task CRUD ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/{case_id}/tasks", summary="Add task to case")
async def create_task(case_id: str, body: TaskCreate, db: AsyncSession = Depends(get_db)):
    """Create a task under a case; 404 if the case does not exist."""
    parent = await db.get(Case, case_id)
    if parent is None:
        raise HTTPException(status_code=404, detail="Case not found")

    now = _utcnow()
    new_task = CaseTask(
        id=_new_id(),
        case_id=case_id,
        title=body.title,
        description=body.description,
        assignee=body.assignee,
        created_at=now,
        updated_at=now,
    )
    db.add(new_task)
    # Task creation is logged against the parent case.
    await _log_activity(db, "case", case_id, "task_created", {"title": body.title})
    await db.commit()
    await db.refresh(new_task)
    return _task_to_dict(new_task)
|
||||
|
||||
|
||||
@router.put("/{case_id}/tasks/{task_id}", summary="Update a task")
async def update_task(case_id: str, task_id: str, body: TaskUpdate, db: AsyncSession = Depends(get_db)):
    """Apply the non-null fields of *body* to an existing task of this case."""
    task = await db.get(CaseTask, task_id)
    if task is None or task.case_id != case_id:
        raise HTTPException(status_code=404, detail="Task not found")

    # Only explicitly provided (non-None) fields overwrite stored values.
    updates = {
        name: value
        for name in ("title", "description", "status", "assignee", "order")
        if (value := getattr(body, name)) is not None
    }
    for name, value in updates.items():
        setattr(task, name, value)

    task.updated_at = _utcnow()
    await _log_activity(db, "case", case_id, "task_updated", {"task_id": task_id})
    await db.commit()
    await db.refresh(task)
    return _task_to_dict(task)
|
||||
|
||||
|
||||
@router.delete("/{case_id}/tasks/{task_id}", summary="Delete a task")
async def delete_task(case_id: str, task_id: str, db: AsyncSession = Depends(get_db)):
    """Delete one task; 404 unless it exists and belongs to the given case."""
    found = await db.get(CaseTask, task_id)
    if found is None or found.case_id != case_id:
        raise HTTPException(status_code=404, detail="Task not found")
    await db.delete(found)
    await db.commit()
    return {"deleted": True}
|
||||
|
||||
|
||||
# ── Activity Log ──────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{case_id}/activity", summary="Get case activity log")
async def get_activity(
    case_id: str,
    limit: int = Query(50, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
):
    """Return the most recent activity-log entries for a case, newest first."""
    stmt = (
        select(ActivityLog)
        .where(ActivityLog.entity_type == "case", ActivityLog.entity_id == case_id)
        .order_by(desc(ActivityLog.created_at))
        .limit(limit)
    )
    entries = (await db.execute(stmt)).scalars().all()
    return {
        "logs": [
            {
                "id": entry.id,
                "action": entry.action,
                "details": entry.details,
                "user_id": entry.user_id,
                "created_at": entry.created_at.isoformat() if entry.created_at else None,
            }
            for entry in entries
        ]
    }
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _case_to_dict(c: Case) -> dict:
    """Serialize a Case ORM row (tasks included) into a JSON-safe dict."""
    def _iso(dt):
        # Datetimes become ISO-8601 strings; missing timestamps stay None.
        return dt.isoformat() if dt else None

    return {
        "id": c.id,
        "title": c.title,
        "description": c.description,
        "severity": c.severity,
        "tlp": c.tlp,
        "pap": c.pap,
        "status": c.status,
        "priority": c.priority,
        "assignee": c.assignee,
        "tags": c.tags or [],
        "hunt_id": c.hunt_id,
        "owner_id": c.owner_id,
        "mitre_techniques": c.mitre_techniques or [],
        "iocs": c.iocs or [],
        "started_at": _iso(c.started_at),
        "resolved_at": _iso(c.resolved_at),
        "created_at": _iso(c.created_at),
        "updated_at": _iso(c.updated_at),
        "tasks": [_task_to_dict(t) for t in (c.tasks or [])],
    }
|
||||
|
||||
|
||||
def _task_to_dict(t: CaseTask) -> dict:
    """Serialize a CaseTask ORM row into a JSON-safe dict."""
    def _iso(dt):
        # Datetimes become ISO-8601 strings; missing timestamps stay None.
        return dt.isoformat() if dt else None

    return {
        "id": t.id,
        "case_id": t.case_id,
        "title": t.title,
        "description": t.description,
        "status": t.status,
        "assignee": t.assignee,
        "order": t.order,
        "created_at": _iso(t.created_at),
        "updated_at": _iso(t.updated_at),
    }
|
||||
@@ -293,3 +293,30 @@ async def delete_dataset(
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
return {"message": "Dataset deleted", "id": dataset_id}
|
||||
|
||||
|
||||
@router.post(
    "/rescan-ioc",
    summary="Re-scan IOC columns for all datasets",
)
async def rescan_ioc_columns(
    db: AsyncSession = Depends(get_db),
):
    """Re-run detect_ioc_columns on every dataset using current detection logic."""
    repo = DatasetRepository(db)
    datasets = await repo.list_datasets(limit=10000)

    changed = 0
    for ds in datasets:
        schema = ds.column_schema or {}
        names = list(schema.keys())
        if not names:
            # Nothing recorded to scan for this dataset.
            continue
        detected = detect_ioc_columns(
            names,
            schema,
            ds.normalized_columns or {},
        )
        # Only write back (and count) datasets whose detection result changed.
        if detected != (ds.ioc_columns or {}):
            ds.ioc_columns = detected
            changed += 1

    await db.commit()
    return {"message": f"Rescanned {len(datasets)} datasets, updated {changed}"}
|
||||
|
||||
69
backend/app/api/routes/network.py
Normal file
69
backend/app/api/routes/network.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""API routes for Network Picture — deduplicated host inventory."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.services.network_inventory import build_network_picture
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/network", tags=["network"])
|
||||
|
||||
|
||||
# ── Response models ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class HostEntry(BaseModel):
    """One deduplicated host in the network picture (one row per hostname)."""

    hostname: str
    # Identifiers/attributes observed for this host, aggregated across datasets.
    ips: list[str] = Field(default_factory=list)
    users: list[str] = Field(default_factory=list)
    os: list[str] = Field(default_factory=list)
    mac_addresses: list[str] = Field(default_factory=list)
    protocols: list[str] = Field(default_factory=list)
    open_ports: list[str] = Field(default_factory=list)
    remote_targets: list[str] = Field(default_factory=list)
    # Datasets this host appeared in — presumably names or ids; confirm in
    # build_network_picture.
    datasets: list[str] = Field(default_factory=list)
    connection_count: int = 0
    # Timestamps as strings when known; None when the source had no times.
    first_seen: str | None = None
    last_seen: str | None = None
|
||||
|
||||
|
||||
class PictureSummary(BaseModel):
    """Aggregate counts for a network-picture scan."""

    total_hosts: int = 0
    total_connections: int = 0
    total_unique_ips: int = 0
    # Number of datasets examined to build the picture.
    datasets_scanned: int = 0
|
||||
|
||||
|
||||
class NetworkPictureResponse(BaseModel):
    """Response envelope for GET /api/network/picture."""

    hosts: list[HostEntry]
    summary: PictureSummary
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
    "/picture",
    response_model=NetworkPictureResponse,
    summary="Build deduplicated host inventory for a hunt",
    description=(
        "Scans all datasets in the specified hunt, extracts host-identifying "
        "fields (hostname, IP, username, OS, MAC, ports), deduplicates by "
        "hostname, and returns a clean one-row-per-host network picture."
    ),
)
async def get_network_picture(
    hunt_id: str = Query(..., description="Hunt ID to scan"),
    db: AsyncSession = Depends(get_db),
):
    """Return a deduplicated network picture for a hunt."""
    # Guard against an explicitly empty ?hunt_id= value.
    if not hunt_id:
        raise HTTPException(status_code=400, detail="hunt_id is required")
    return await build_network_picture(db, hunt_id)
|
||||
360
backend/app/api/routes/notebooks.py
Normal file
360
backend/app/api/routes/notebooks.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""API routes for investigation notebooks and playbooks."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Notebook, PlaybookRun, _new_id, _utcnow
|
||||
from app.services.playbook import (
|
||||
get_builtin_playbooks,
|
||||
get_playbook_template,
|
||||
validate_notebook_cells,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/notebooks", tags=["notebooks"])
|
||||
|
||||
|
||||
# ── Pydantic models ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class NotebookCreate(BaseModel):
    """Payload for creating a notebook; cells default to a seed markdown cell."""

    title: str
    description: Optional[str] = None
    # Cell dicts: {id, cell_type, source, output, metadata}.
    cells: Optional[list[dict]] = None
    # Optional links to a hunt and/or case.
    hunt_id: Optional[str] = None
    case_id: Optional[str] = None
    tags: Optional[list[str]] = None
|
||||
|
||||
|
||||
class NotebookUpdate(BaseModel):
    """Partial update payload; None fields are left unchanged."""

    title: Optional[str] = None
    description: Optional[str] = None
    # Full replacement of the cell list when provided.
    cells: Optional[list[dict]] = None
    tags: Optional[list[str]] = None
|
||||
|
||||
|
||||
class CellUpdate(BaseModel):
    """Update a single cell or add a new one.

    cell_id selects the target; when no cell with that id exists, the
    route appends a new cell. None fields leave existing values untouched.
    """
    cell_id: str
    cell_type: Optional[str] = None
    source: Optional[str] = None
    output: Optional[str] = None
    metadata: Optional[dict] = None
|
||||
|
||||
|
||||
class PlaybookStart(BaseModel):
    """Payload for starting a playbook run from a built-in template."""

    # Must match a built-in template name; 404 otherwise.
    playbook_name: str
    hunt_id: Optional[str] = None
    case_id: Optional[str] = None
    started_by: Optional[str] = None
|
||||
|
||||
|
||||
class StepComplete(BaseModel):
    """Outcome of the current playbook step."""

    notes: Optional[str] = None
    status: str = "completed"  # completed | skipped
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _notebook_to_dict(nb: Notebook) -> dict:
    """Serialize a Notebook ORM row into a JSON-safe dict."""
    def _iso(dt):
        # Datetimes become ISO-8601 strings; missing timestamps stay None.
        return dt.isoformat() if dt else None

    cell_list = nb.cells or []
    return {
        "id": nb.id,
        "title": nb.title,
        "description": nb.description,
        "cells": cell_list,
        "hunt_id": nb.hunt_id,
        "case_id": nb.case_id,
        "owner_id": nb.owner_id,
        "tags": nb.tags or [],
        "cell_count": len(cell_list),
        "created_at": _iso(nb.created_at),
        "updated_at": _iso(nb.updated_at),
    }
|
||||
|
||||
|
||||
def _run_to_dict(run: PlaybookRun) -> dict:
    """Serialize a PlaybookRun ORM row into a JSON-safe dict."""
    def _iso(dt):
        # Datetimes become ISO-8601 strings; missing timestamps stay None.
        return dt.isoformat() if dt else None

    return {
        "id": run.id,
        "playbook_name": run.playbook_name,
        "status": run.status,
        "current_step": run.current_step,
        "total_steps": run.total_steps,
        "step_results": run.step_results or [],
        "hunt_id": run.hunt_id,
        "case_id": run.case_id,
        "started_by": run.started_by,
        "created_at": _iso(run.created_at),
        "updated_at": _iso(run.updated_at),
        "completed_at": _iso(run.completed_at),
    }
|
||||
|
||||
|
||||
# ── Notebook CRUD ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("", summary="List notebooks")
async def list_notebooks(
    hunt_id: str | None = Query(None),
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """Page through notebooks, most recently updated first, optionally by hunt."""
    conditions = []
    if hunt_id:
        conditions.append(Notebook.hunt_id == hunt_id)

    total = (
        await db.execute(select(func.count(Notebook.id)).where(*conditions))
    ).scalar() or 0
    rows = (
        await db.execute(
            select(Notebook)
            .where(*conditions)
            .order_by(desc(Notebook.updated_at))
            .offset(offset)
            .limit(limit)
        )
    ).scalars().all()

    return {"notebooks": [_notebook_to_dict(n) for n in rows], "total": total}
|
||||
|
||||
|
||||
@router.get("/{notebook_id}", summary="Get notebook")
async def get_notebook(notebook_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single notebook by id; 404 when missing."""
    found = await db.get(Notebook, notebook_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Notebook not found")
    return _notebook_to_dict(found)
|
||||
|
||||
|
||||
@router.post("", summary="Create notebook")
async def create_notebook(body: NotebookCreate, db: AsyncSession = Depends(get_db)):
    """Create a notebook; an empty cell list is seeded with one markdown cell."""
    validated = validate_notebook_cells(body.cells or [])
    if not validated:
        # Start with a default markdown cell
        validated = [{"id": "cell-0", "cell_type": "markdown", "source": "# Investigation Notes\n\nStart documenting your findings here.", "output": None, "metadata": {}}]

    notebook = Notebook(
        id=_new_id(),
        title=body.title,
        description=body.description,
        cells=validated,
        hunt_id=body.hunt_id,
        case_id=body.case_id,
        tags=body.tags,
    )
    db.add(notebook)
    await db.commit()
    await db.refresh(notebook)
    return _notebook_to_dict(notebook)
|
||||
|
||||
|
||||
@router.put("/{notebook_id}", summary="Update notebook")
async def update_notebook(
    notebook_id: str, body: NotebookUpdate, db: AsyncSession = Depends(get_db)
):
    """Patch title/description/cells/tags; only non-null fields are applied."""
    notebook = await db.get(Notebook, notebook_id)
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found")

    for attr in ("title", "description", "tags"):
        value = getattr(body, attr)
        if value is not None:
            setattr(notebook, attr, value)
    if body.cells is not None:
        # Cells are normalized/validated before being stored.
        notebook.cells = validate_notebook_cells(body.cells)

    await db.commit()
    await db.refresh(notebook)
    return _notebook_to_dict(notebook)
|
||||
|
||||
|
||||
@router.post("/{notebook_id}/cells", summary="Add or update a cell")
async def upsert_cell(
    notebook_id: str, body: CellUpdate, db: AsyncSession = Depends(get_db)
):
    """Update the cell matching body.cell_id in place, or append a new cell."""
    notebook = await db.get(Notebook, notebook_id)
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found")

    cells = list(notebook.cells or [])
    target = next((c for c in cells if c.get("id") == body.cell_id), None)

    if target is not None:
        # Patch only the attributes the client actually supplied (non-None).
        for attr in ("cell_type", "source", "output", "metadata"):
            value = getattr(body, attr)
            if value is not None:
                target[attr] = value
    else:
        cells.append({
            "id": body.cell_id,
            "cell_type": body.cell_type or "markdown",
            "source": body.source or "",
            "output": body.output,
            "metadata": body.metadata or {},
        })

    # Reassign a fresh list so the ORM's change tracking registers the edit.
    notebook.cells = cells
    await db.commit()
    await db.refresh(notebook)
    return _notebook_to_dict(notebook)
|
||||
|
||||
|
||||
@router.delete("/{notebook_id}/cells/{cell_id}", summary="Delete a cell")
async def delete_cell(
    notebook_id: str, cell_id: str, db: AsyncSession = Depends(get_db)
):
    """Remove one cell by id; deleting an unknown cell id is a silent no-op."""
    notebook = await db.get(Notebook, notebook_id)
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found")

    remaining = [cell for cell in (notebook.cells or []) if cell.get("id") != cell_id]
    notebook.cells = remaining
    await db.commit()
    return {"ok": True, "remaining_cells": len(remaining)}
|
||||
|
||||
|
||||
@router.delete("/{notebook_id}", summary="Delete notebook")
async def delete_notebook(notebook_id: str, db: AsyncSession = Depends(get_db)):
    """Permanently delete a notebook; 404 when the id is unknown."""
    target = await db.get(Notebook, notebook_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Notebook not found")
    await db.delete(target)
    await db.commit()
    return {"ok": True}
|
||||
|
||||
|
||||
# ── Playbooks ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/playbooks/templates", summary="List built-in playbook templates")
async def list_playbook_templates():
    """Summarize the built-in playbook templates (steps counted, not returned)."""
    summaries = []
    for tpl in get_builtin_playbooks():
        summaries.append({
            "name": tpl["name"],
            "description": tpl["description"],
            "category": tpl["category"],
            "tags": tpl["tags"],
            "step_count": len(tpl["steps"]),
        })
    return {"templates": summaries}
|
||||
|
||||
|
||||
@router.get("/playbooks/templates/{name}", summary="Get playbook template detail")
async def get_playbook_template_detail(name: str):
    """Return the full template (including its steps) for one built-in playbook."""
    found = get_playbook_template(name)
    if not found:
        raise HTTPException(status_code=404, detail="Playbook template not found")
    return found
|
||||
|
||||
|
||||
@router.post("/playbooks/start", summary="Start a playbook run")
async def start_playbook(body: PlaybookStart, db: AsyncSession = Depends(get_db)):
    """Create a new PlaybookRun for a named template, positioned at step 1."""
    template = get_playbook_template(body.playbook_name)
    if not template:
        raise HTTPException(status_code=404, detail="Playbook template not found")

    new_run = PlaybookRun(
        id=_new_id(),
        playbook_name=body.playbook_name,
        status="in-progress",
        current_step=1,
        total_steps=len(template["steps"]),
        step_results=[],
        hunt_id=body.hunt_id,
        case_id=body.case_id,
        started_by=body.started_by,
    )
    db.add(new_run)
    await db.commit()
    await db.refresh(new_run)
    return _run_to_dict(new_run)
|
||||
|
||||
|
||||
@router.get("/playbooks/runs", summary="List playbook runs")
async def list_playbook_runs(
    status: str | None = Query(None),
    hunt_id: str | None = Query(None),
    limit: int = Query(50, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
):
    """List playbook runs, newest first, optionally filtered by status/hunt."""
    conditions = []
    if status:
        conditions.append(PlaybookRun.status == status)
    if hunt_id:
        conditions.append(PlaybookRun.hunt_id == hunt_id)

    stmt = (
        select(PlaybookRun)
        .where(*conditions)
        .order_by(desc(PlaybookRun.created_at))
        .limit(limit)
    )
    runs = (await db.execute(stmt)).scalars().all()
    return {"runs": [_run_to_dict(r) for r in runs]}
|
||||
|
||||
|
||||
@router.get("/playbooks/runs/{run_id}", summary="Get playbook run detail")
async def get_playbook_run(run_id: str, db: AsyncSession = Depends(get_db)):
    """Return a run plus the step definitions from its source template."""
    run = await db.get(PlaybookRun, run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")

    payload = _run_to_dict(run)
    # Attach the template's steps so clients can render step details;
    # empty when the template can no longer be resolved by name.
    template = get_playbook_template(run.playbook_name)
    payload["steps"] = template["steps"] if template else []
    return payload
|
||||
|
||||
|
||||
@router.post("/playbooks/runs/{run_id}/complete-step", summary="Complete current playbook step")
async def complete_step(
    run_id: str, body: StepComplete, db: AsyncSession = Depends(get_db)
):
    """Record the outcome of the current step and advance (or finish) the run."""
    run = await db.get(PlaybookRun, run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")
    if run.status != "in-progress":
        raise HTTPException(status_code=400, detail="Run is not in progress")

    # Append this step's result; assign a new list so the JSON column is
    # registered as changed by the ORM.
    run.step_results = [
        *(run.step_results or []),
        {
            "step": run.current_step,
            "status": body.status,
            "notes": body.notes,
            "completed_at": _utcnow().isoformat(),
        },
    ]

    if run.current_step >= run.total_steps:
        # Final step completed — close out the run.
        run.status = "completed"
        run.completed_at = _utcnow()
    else:
        run.current_step += 1

    await db.commit()
    await db.refresh(run)
    return _run_to_dict(run)
|
||||
|
||||
|
||||
@router.post("/playbooks/runs/{run_id}/abort", summary="Abort a playbook run")
async def abort_run(run_id: str, db: AsyncSession = Depends(get_db)):
    """Mark a run as aborted and stamp its completion time."""
    run = await db.get(PlaybookRun, run_id)
    if run is None:
        raise HTTPException(status_code=404, detail="Run not found")
    run.status = "aborted"
    run.completed_at = _utcnow()
    await db.commit()
    return _run_to_dict(run)
|
||||
@@ -15,7 +15,7 @@ class AppConfig(BaseSettings):
|
||||
|
||||
# ── General ────────────────────────────────────────────────────────
|
||||
APP_NAME: str = "ThreatHunt"
|
||||
APP_VERSION: str = "0.3.0"
|
||||
APP_VERSION: str = "0.4.0"
|
||||
DEBUG: bool = Field(default=False, description="Enable debug mode")
|
||||
|
||||
# ── Database ───────────────────────────────────────────────────────
|
||||
|
||||
@@ -45,8 +45,20 @@ async def get_db() -> AsyncSession: # type: ignore[misc]
|
||||
|
||||
async def init_db() -> None:
    """Create any missing tables (for dev / first-run). In production use Alembic.

    Only tables that do not already exist in the database are created, so this
    is safe to run alongside an Alembic-managed schema: existing tables are
    never touched. (The previous unconditional ``Base.metadata.create_all``
    call was removed — it pre-empted the selective logic below.)
    """
    from sqlalchemy import inspect as sa_inspect

    async with engine.begin() as conn:
        # Only create tables that don't already exist (safe alongside Alembic)
        def _create_missing(sync_conn):
            # Inspect the live schema and create only what is absent.
            inspector = sa_inspect(sync_conn)
            existing = set(inspector.get_table_names())
            tables_to_create = [
                t for t in Base.metadata.sorted_tables
                if t.name not in existing
            ]
            Base.metadata.create_all(sync_conn, tables=tables_to_create)

        await conn.run_sync(_create_missing)
|
||||
|
||||
|
||||
async def dispose_db() -> None:
|
||||
|
||||
@@ -326,3 +326,221 @@ class Keyword(Base):
|
||||
Index("ix_keywords_theme", "theme_id"),
|
||||
Index("ix_keywords_value", "value"),
|
||||
)
|
||||
|
||||
|
||||
# ── Cases ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Case(Base):
    """Incident / investigation case, inspired by TheHive."""
    __tablename__ = "cases"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    title: Mapped[str] = mapped_column(String(512), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    severity: Mapped[str] = mapped_column(String(16), default="medium")  # info|low|medium|high|critical
    tlp: Mapped[str] = mapped_column(String(16), default="amber")  # white|green|amber|red
    pap: Mapped[str] = mapped_column(String(16), default="amber")  # white|green|amber|red
    status: Mapped[str] = mapped_column(String(24), default="open")  # open|in-progress|resolved|closed
    priority: Mapped[int] = mapped_column(Integer, default=2)  # 1(urgent)..4(low)
    assignee: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    tags: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    # Optional link back to the hunt this case originated from.
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    owner_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("users.id"), nullable=True
    )
    mitre_techniques: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    iocs: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)  # [{type, value, description}]
    # Lifecycle timestamps, set by the update route on status transitions.
    started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    resolved_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    # Tasks are eagerly loaded (selectin) and deleted together with the case.
    tasks: Mapped[list["CaseTask"]] = relationship(
        back_populates="case", lazy="selectin", cascade="all, delete-orphan"
    )

    __table_args__ = (
        Index("ix_cases_hunt", "hunt_id"),
        Index("ix_cases_status", "status"),
    )
|
||||
|
||||
|
||||
class CaseTask(Base):
    """Task within a case (Kanban board item)."""
    __tablename__ = "case_tasks"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    # Deleting the parent case cascades at the DB level as well.
    case_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("cases.id", ondelete="CASCADE"), nullable=False
    )
    title: Mapped[str] = mapped_column(String(512), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    status: Mapped[str] = mapped_column(String(24), default="todo")  # todo|in-progress|done
    assignee: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    # Position within the Kanban column (lower = earlier).
    order: Mapped[int] = mapped_column(Integer, default=0)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    case: Mapped["Case"] = relationship(back_populates="tasks")

    __table_args__ = (
        Index("ix_case_tasks_case", "case_id"),
    )
|
||||
|
||||
|
||||
# ── Activity Log ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ActivityLog(Base):
    """Audit trail / activity log for cases and hunts."""
    __tablename__ = "activity_logs"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Polymorphic reference: (entity_type, entity_id) identifies the subject.
    entity_type: Mapped[str] = mapped_column(String(32), nullable=False)  # case|hunt|annotation
    entity_id: Mapped[str] = mapped_column(String(32), nullable=False)
    action: Mapped[str] = mapped_column(String(64), nullable=False)  # created|updated|status_changed|etc
    # Free-form JSON payload describing the change (e.g. old/new values).
    details: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    user_id: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    __table_args__ = (
        Index("ix_activity_entity", "entity_type", "entity_id"),
    )
|
||||
|
||||
|
||||
# ── Alerts ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Alert(Base):
    """Security alert generated by analyzers or rules."""
    __tablename__ = "alerts"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    title: Mapped[str] = mapped_column(String(512), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    severity: Mapped[str] = mapped_column(String(16), default="medium")  # critical|high|medium|low|info
    status: Mapped[str] = mapped_column(String(24), default="new")  # new|acknowledged|in-progress|resolved|false-positive
    analyzer: Mapped[str] = mapped_column(String(64), nullable=False)  # which analyzer produced it
    # Numeric confidence/priority score assigned by the producing analyzer.
    score: Mapped[float] = mapped_column(Float, default=0.0)
    evidence: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)  # [{row_index, field, value, ...}]
    mitre_technique: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    tags: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    # Optional provenance/triage links: originating hunt and dataset, and the
    # case the alert was escalated into.
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    dataset_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("datasets.id"), nullable=True
    )
    case_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("cases.id"), nullable=True
    )
    assignee: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    # Triage lifecycle timestamps.
    acknowledged_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    resolved_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    __table_args__ = (
        Index("ix_alerts_severity", "severity"),
        Index("ix_alerts_status", "status"),
        Index("ix_alerts_hunt", "hunt_id"),
        Index("ix_alerts_dataset", "dataset_id"),
    )
|
||||
|
||||
|
||||
class AlertRule(Base):
    """User-defined alert rule (triggers analyzers automatically on upload)."""
    __tablename__ = "alert_rules"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    name: Mapped[str] = mapped_column(String(256), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    analyzer: Mapped[str] = mapped_column(String(64), nullable=False)  # analyzer name
    config: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # analyzer config overrides
    # When set, overrides the severity the analyzer would otherwise assign.
    severity_override: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )  # None = global
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    __table_args__ = (
        Index("ix_alert_rules_analyzer", "analyzer"),
    )
|
||||
|
||||
|
||||
# ── Notebooks ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Notebook(Base):
    """Investigation notebook — cell-based document for analyst notes and queries."""
    __tablename__ = "notebooks"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    title: Mapped[str] = mapped_column(String(512), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    cells: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)  # [{id, cell_type, source, output, metadata}]
    # Optional links to the hunt/case the notebook documents, and its author.
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    case_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("cases.id"), nullable=True
    )
    owner_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("users.id"), nullable=True
    )
    tags: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    __table_args__ = (
        Index("ix_notebooks_hunt", "hunt_id"),
    )
|
||||
|
||||
|
||||
# ── Playbook Runs ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class PlaybookRun(Base):
    """Record of a playbook execution (links a template to a hunt/case)."""
    __tablename__ = "playbook_runs"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    # Name of the built-in template this run was started from (resolved by
    # name at read time; templates are not stored in the DB).
    playbook_name: Mapped[str] = mapped_column(String(256), nullable=False)
    status: Mapped[str] = mapped_column(String(24), default="in-progress")  # in-progress | completed | aborted
    # 1-based pointer to the step currently being worked.
    current_step: Mapped[int] = mapped_column(Integer, default=1)
    total_steps: Mapped[int] = mapped_column(Integer, default=0)
    step_results: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)  # [{step, status, notes, completed_at}]
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    case_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("cases.id"), nullable=True
    )
    started_by: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )
    # Set when the run reaches "completed" or "aborted".
    completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)

    __table_args__ = (
        Index("ix_playbook_runs_hunt", "hunt_id"),
        Index("ix_playbook_runs_status", "status"),
    )
|
||||
|
||||
@@ -21,6 +21,11 @@ from app.api.routes.correlation import router as correlation_router
|
||||
from app.api.routes.reports import router as reports_router
|
||||
from app.api.routes.auth import router as auth_router
|
||||
from app.api.routes.keywords import router as keywords_router
|
||||
from app.api.routes.network import router as network_router
|
||||
from app.api.routes.analysis import router as analysis_router
|
||||
from app.api.routes.cases import router as cases_router
|
||||
from app.api.routes.alerts import router as alerts_router
|
||||
from app.api.routes.notebooks import router as notebooks_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -74,6 +79,11 @@ app.include_router(enrichment_router)
|
||||
app.include_router(correlation_router)
|
||||
app.include_router(reports_router)
|
||||
app.include_router(keywords_router)
|
||||
app.include_router(network_router)
|
||||
app.include_router(analysis_router)
|
||||
app.include_router(cases_router)
|
||||
app.include_router(alerts_router)
|
||||
app.include_router(notebooks_router)
|
||||
|
||||
|
||||
@app.get("/", tags=["health"])
|
||||
|
||||
464
backend/app/services/analyzers.py
Normal file
464
backend/app/services/analyzers.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""Pluggable Analyzer Framework for ThreatHunt.
|
||||
|
||||
Each analyzer implements a simple protocol:
|
||||
- name / description properties
|
||||
- async analyze(rows, config) -> list[AlertCandidate]
|
||||
|
||||
The AnalyzerRegistry discovers and runs all enabled analyzers against
|
||||
a dataset, producing alert candidates that the alert system can persist.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter, defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Alert Candidate DTO ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class AlertCandidate:
    """A single finding from an analyzer, before it becomes a persisted Alert."""

    # Name of the analyzer that produced this finding (BaseAnalyzer.name).
    analyzer: str
    # Short human-readable headline for the finding.
    title: str
    severity: str  # critical | high | medium | low | info
    # Longer explanation of what was detected and where.
    description: str
    # Supporting data points, e.g. [{row_index, field, value, ...}].
    evidence: list[dict] = field(default_factory=list)
    # Optional MITRE ATT&CK technique ID (e.g. "T1059.001").
    mitre_technique: Optional[str] = None
    tags: list[str] = field(default_factory=list)
    # Heuristic priority score, 0-100; run_all_analyzers sorts by it.
    score: float = 0.0
|
||||
|
||||
|
||||
# ── Base Analyzer ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class BaseAnalyzer(ABC):
    """Interface every analyzer must implement.

    Concrete analyzers supply a unique ``name``, a human-readable
    ``description``, and an async ``analyze`` method that scans flat row
    dicts and returns zero or more AlertCandidate findings. Built-in
    subclasses satisfy the abstract properties with plain class attributes.
    """

    @property
    @abstractmethod
    def name(self) -> str: ...

    @property
    @abstractmethod
    def description(self) -> str: ...

    @abstractmethod
    async def analyze(
        self, rows: list[dict[str, Any]], config: dict[str, Any] | None = None
    ) -> list[AlertCandidate]: ...
|
||||
|
||||
|
||||
# ── Built-in Analyzers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class EntropyAnalyzer(BaseAnalyzer):
    """Detects high-entropy strings (encoded payloads, obfuscated commands)."""

    name = "entropy"
    description = "Flags fields with high Shannon entropy (possible encoding/obfuscation)"

    # Field names (across common log schemas) worth measuring.
    ENTROPY_FIELDS = [
        "command_line", "commandline", "process_command_line", "cmdline",
        "powershell_command", "script_block", "url", "uri", "path",
        "file_path", "target_filename", "query", "dns_query",
    ]
    # Entropy (bits/char) above which a value is flagged; encoded/random
    # text tends to score higher than natural language.
    DEFAULT_THRESHOLD = 4.5

    @staticmethod
    def _shannon(s: str) -> float:
        """Shannon entropy in bits per character; 0.0 for strings under 8 chars."""
        if not s or len(s) < 8:
            return 0.0
        freq = Counter(s)
        length = len(s)
        return -sum((c / length) * math.log2(c / length) for c in freq.values())

    async def analyze(self, rows, config=None):
        """Flag any known field whose value meets the entropy threshold.

        Config keys: ``entropy_threshold`` (default 4.5) and ``min_length``
        (default 20; shorter values are skipped as statistically noisy).
        """
        config = config or {}
        threshold = config.get("entropy_threshold", self.DEFAULT_THRESHOLD)
        min_length = config.get("min_length", 20)
        alerts: list[AlertCandidate] = []

        for idx, row in enumerate(rows):
            for field_name in self.ENTROPY_FIELDS:
                val = str(row.get(field_name, ""))
                if len(val) < min_length:
                    continue
                ent = self._shannon(val)
                if ent >= threshold:
                    # Severity scales with how far above threshold the value is.
                    sev = "critical" if ent > 5.5 else "high" if ent > 5.0 else "medium"
                    alerts.append(AlertCandidate(
                        analyzer=self.name,
                        title=f"High-entropy string in {field_name}",
                        severity=sev,
                        description=f"Shannon entropy {ent:.2f} (threshold {threshold}) in row {idx}, field '{field_name}'",
                        evidence=[{"row_index": idx, "field": field_name, "value": val[:200], "entropy": round(ent, 3)}],
                        mitre_technique="T1027",  # Obfuscated Files or Information
                        tags=["obfuscation", "entropy"],
                        score=min(100, ent * 18),
                    ))
        return alerts
|
||||
|
||||
|
||||
class SuspiciousCommandAnalyzer(BaseAnalyzer):
    """Detects known-bad command patterns (credential dumping, lateral movement, persistence)."""

    name = "suspicious_commands"
    description = "Flags processes executing known-suspicious command patterns"

    # (regex, title, severity, mitre_technique). Every pattern is compiled
    # with re.IGNORECASE below, so the inline (?i) flags are redundant but
    # harmless; kept so the raw strings read unambiguously.
    PATTERNS: list[tuple[str, str, str, str]] = [
        (r"mimikatz|sekurlsa|lsadump|kerberos::list", "Mimikatz / Credential Dumping", "critical", "T1003"),
        (r"(?i)-enc\s+[A-Za-z0-9+/=]{40,}", "Encoded PowerShell command", "high", "T1059.001"),
        (r"(?i)invoke-(mimikatz|expression|webrequest|shellcode)", "Suspicious PowerShell Invoke", "high", "T1059.001"),
        (r"(?i)net\s+(user|localgroup|group)\s+/add", "Local account creation", "high", "T1136.001"),
        (r"(?i)schtasks\s+/create", "Scheduled task creation", "medium", "T1053.005"),
        (r"(?i)reg\s+add\s+.*\\run", "Registry Run key persistence", "high", "T1547.001"),
        (r"(?i)wmic\s+.*(process\s+call|shadowcopy\s+delete)", "WMI abuse / shadow copy deletion", "critical", "T1047"),
        (r"(?i)psexec|winrm|wmic\s+/node:", "Lateral movement tool", "high", "T1021"),
        (r"(?i)certutil\s+-urlcache", "Certutil download (LOLBin)", "high", "T1105"),
        (r"(?i)bitsadmin\s+/transfer", "BITSAdmin download", "medium", "T1197"),
        (r"(?i)vssadmin\s+delete\s+shadows", "VSS shadow deletion (ransomware)", "critical", "T1490"),
        (r"(?i)bcdedit.*recoveryenabled.*no", "Boot config tamper (ransomware)", "critical", "T1490"),
        (r"(?i)attrib\s+\+h\s+\+s", "Hidden file attribute set", "low", "T1564.001"),
        (r"(?i)netsh\s+advfirewall\s+.*disable", "Firewall disabled", "high", "T1562.004"),
        (r"(?i)whoami\s*/priv", "Privilege enumeration", "medium", "T1033"),
        (r"(?i)nltest\s+/dclist", "Domain controller enumeration", "medium", "T1018"),
        (r"(?i)dsquery|ldapsearch|adfind", "Active Directory enumeration", "medium", "T1087.002"),
        (r"(?i)procdump.*-ma\s+lsass", "LSASS memory dump", "critical", "T1003.001"),
        (r"(?i)rundll32.*comsvcs.*MiniDump", "LSASS dump via comsvcs", "critical", "T1003.001"),
    ]

    # Command-line field names across common log schemas.
    CMD_FIELDS = [
        "command_line", "commandline", "process_command_line", "cmdline",
        "parent_command_line", "powershell_command",
    ]

    # Fix: compiled once at import time. Previously the entire pattern list
    # was re-compiled on every analyze() call.
    _COMPILED: list[tuple[re.Pattern[str], str, str, str]] = [
        (re.compile(p, re.IGNORECASE), t, s, m) for p, t, s, m in PATTERNS
    ]

    async def analyze(self, rows, config=None):
        """Scan command-line fields; emit one alert per (row, field, pattern) hit.

        ``config`` is accepted for interface parity but currently unused.
        """
        alerts: list[AlertCandidate] = []

        for idx, row in enumerate(rows):
            for fld in self.CMD_FIELDS:
                val = str(row.get(fld, ""))
                if len(val) < 3:
                    continue
                for pattern, title, sev, mitre in self._COMPILED:
                    if pattern.search(val):
                        alerts.append(AlertCandidate(
                            analyzer=self.name,
                            title=title,
                            severity=sev,
                            description=f"Suspicious command pattern in row {idx}: {val[:200]}",
                            evidence=[{"row_index": idx, "field": fld, "value": val[:300]}],
                            mitre_technique=mitre,
                            tags=["command", "suspicious"],
                            # Map severity label to a numeric priority score.
                            score={"critical": 95, "high": 80, "medium": 60, "low": 30}.get(sev, 50),
                        ))
        return alerts
|
||||
|
||||
|
||||
class NetworkAnomalyAnalyzer(BaseAnalyzer):
    """Detects anomalous network patterns (beaconing, unusual ports, large transfers)."""

    name = "network_anomaly"
    description = "Flags anomalous network behavior (beaconing, unusual ports, large transfers)"

    # Ports commonly used by off-the-shelf backdoors / C2 tooling.
    SUSPICIOUS_PORTS = {4444, 5555, 6666, 8888, 9999, 1234, 31337, 12345, 54321, 1337}
    # NOTE(review): currently unused in this analyzer — presumably reserved
    # for future TLS/C2 heuristics; confirm before removing.
    C2_PORTS = {443, 8443, 8080, 4443, 9443}

    @staticmethod
    def _first(row, keys):
        """Return the first non-None, non-empty value among ``keys``, else "".

        Fix: the previous nested ``row.get(k, row.get(...))`` chains only
        fell back when a key was *absent* — a key present with value None
        produced the literal string "None", which was then counted as a
        real destination IP.
        """
        for key in keys:
            val = row.get(key)
            if val not in (None, ""):
                return val
        return ""

    async def analyze(self, rows, config=None):
        """Run three heuristics over network-ish rows.

        Config keys: ``large_transfer_threshold`` (bytes, default 10 MB)
        and ``beacon_threshold`` (contact count per destination, default 20).
        """
        config = config or {}
        alerts: list[AlertCandidate] = []

        # Row indices per destination IP, for beaconing detection.
        dst_freq: dict[str, list[int]] = defaultdict(list)
        port_hits: list[tuple[int, str, int]] = []

        for idx, row in enumerate(rows):
            dst_ip = str(self._first(row, ("dst_ip", "destination_ip", "dest_ip")))
            dst_port = self._first(row, ("dst_port", "destination_port", "dest_port"))

            if dst_ip:
                dst_freq[dst_ip].append(idx)

            if dst_port:
                try:
                    port_num = int(dst_port)
                    if port_num in self.SUSPICIOUS_PORTS:
                        port_hits.append((idx, dst_ip, port_num))
                except (ValueError, TypeError):
                    pass  # non-numeric port value: ignore

            # Large transfer detection
            bytes_val = self._first(row, ("bytes_sent", "bytes_out", "sent_bytes")) or 0
            try:
                if int(bytes_val) > config.get("large_transfer_threshold", 10_000_000):
                    alerts.append(AlertCandidate(
                        analyzer=self.name,
                        title="Large data transfer detected",
                        severity="medium",
                        description=f"Row {idx}: {bytes_val} bytes sent to {dst_ip}",
                        evidence=[{"row_index": idx, "dst_ip": dst_ip, "bytes": str(bytes_val)}],
                        mitre_technique="T1048",
                        tags=["exfiltration", "network"],
                        score=65,
                    ))
            except (ValueError, TypeError):
                pass  # non-numeric byte count: ignore

        # Beaconing: IPs contacted more than threshold times
        beacon_thresh = config.get("beacon_threshold", 20)
        for ip, indices in dst_freq.items():
            if len(indices) >= beacon_thresh:
                alerts.append(AlertCandidate(
                    analyzer=self.name,
                    title=f"Possible beaconing to {ip}",
                    severity="high",
                    description=f"Destination {ip} contacted {len(indices)} times (threshold: {beacon_thresh})",
                    evidence=[{"dst_ip": ip, "contact_count": len(indices), "sample_rows": indices[:10]}],
                    mitre_technique="T1071",
                    tags=["beaconing", "c2", "network"],
                    score=min(95, 50 + len(indices)),
                ))

        # Suspicious ports
        for idx, ip, port in port_hits:
            alerts.append(AlertCandidate(
                analyzer=self.name,
                title=f"Connection on suspicious port {port}",
                severity="medium",
                description=f"Row {idx}: connection to {ip}:{port}",
                evidence=[{"row_index": idx, "dst_ip": ip, "dst_port": port}],
                mitre_technique="T1571",
                tags=["suspicious_port", "network"],
                score=55,
            ))

        return alerts
|
||||
|
||||
|
||||
class FrequencyAnomalyAnalyzer(BaseAnalyzer):
    """Detects statistically rare values that may indicate anomalies."""

    name = "frequency_anomaly"
    description = "Flags statistically rare field values (potential anomalies)"

    # Fields whose value distribution is worth profiling.
    FIELDS_TO_CHECK = [
        "process_name", "image_name", "parent_process_name",
        "user", "username", "user_name",
        "event_type", "action", "status",
    ]

    async def analyze(self, rows, config=None):
        """Flag values occurring in <= ``rarity_threshold`` of rows (default
        1%) and at most 3 times absolute. Datasets smaller than ``min_rows``
        (default 50) are skipped — frequencies are not meaningful there.
        """
        config = config or {}
        rarity_threshold = config.get("rarity_threshold", 0.01)  # <1% occurrence
        min_rows = config.get("min_rows", 50)
        alerts: list[AlertCandidate] = []

        if len(rows) < min_rows:
            return alerts

        for fld in self.FIELDS_TO_CHECK:
            # Perf fix: single pass bucketing row indices per value. The
            # previous code rescanned the entire dataset once per rare value
            # (O(rows × rare values)) just to recover indices.
            rows_by_value: dict[str, list[int]] = defaultdict(list)
            for i, r in enumerate(rows):
                if r.get(fld):
                    rows_by_value[str(r.get(fld, ""))].append(i)
            if not rows_by_value:
                continue
            total = sum(len(ix) for ix in rows_by_value.values())

            for val, indices in rows_by_value.items():
                cnt = len(indices)
                pct = cnt / total
                if pct <= rarity_threshold and cnt <= 3:
                    alerts.append(AlertCandidate(
                        analyzer=self.name,
                        title=f"Rare {fld}: {val[:80]}",
                        severity="low",
                        description=f"'{val}' appears {cnt}/{total} times ({pct:.2%}) in field '{fld}'",
                        evidence=[{"field": fld, "value": val[:200], "count": cnt, "total": total, "rows": indices[:5]}],
                        tags=["anomaly", "rare"],
                        # Rarer values score higher, floored at 20.
                        score=max(20, 50 - (pct * 5000)),
                    ))

        return alerts
|
||||
|
||||
|
||||
class AuthAnomalyAnalyzer(BaseAnalyzer):
    """Detects authentication anomalies (brute force, unusual logon types)."""

    name = "auth_anomaly"
    description = "Flags authentication anomalies (failed logins, unusual logon types)"

    async def analyze(self, rows, config=None):
        """Two heuristics over authentication-ish rows.

        1. Brute force: at least ``brute_force_threshold`` (default 5)
           failed logons for the same user — "fail" in status or Windows
           event ID 4625.
        2. Unusual logon types: 3+ rows with Windows logon_type 3 (network)
           or 10 (RemoteInteractive).
        """
        config = config or {}
        alerts: list[AlertCandidate] = []

        # Row indices of failed logins per user / per interesting logon type.
        failed_by_user: dict[str, list[int]] = defaultdict(list)
        logon_types: dict[str, list[int]] = defaultdict(list)

        for idx, row in enumerate(rows):
            event_type = str(row.get("event_type", row.get("action", ""))).lower()
            status = str(row.get("status", row.get("result", ""))).lower()
            user = str(row.get("username", row.get("user", row.get("user_name", ""))))
            logon_type = str(row.get("logon_type", ""))

            if "logon" in event_type or "auth" in event_type or "login" in event_type:
                if "fail" in status or "4625" in str(row.get("event_id", "")):
                    if user:
                        failed_by_user[user].append(idx)

            # NOTE(review): this tally is not gated on the event-type check
            # above, so any row carrying logon_type 3/10 is counted — confirm
            # that is intended.
            if logon_type in ("3", "10"):  # Network/RemoteInteractive
                logon_types[logon_type].append(idx)

        # Brute force: threshold-or-more failed logins for the same user.
        brute_thresh = config.get("brute_force_threshold", 5)
        for user, indices in failed_by_user.items():
            if len(indices) >= brute_thresh:
                alerts.append(AlertCandidate(
                    analyzer=self.name,
                    title=f"Possible brute force: {user}",
                    severity="high",
                    description=f"User '{user}' had {len(indices)} failed logins",
                    evidence=[{"user": user, "failed_count": len(indices), "rows": indices[:10]}],
                    mitre_technique="T1110",
                    tags=["brute_force", "authentication"],
                    score=min(90, 50 + len(indices) * 3),
                ))

        # Unusual logon types
        for ltype, indices in logon_types.items():
            label = "Network logon (Type 3)" if ltype == "3" else "Remote Desktop (Type 10)"
            if len(indices) >= 3:
                alerts.append(AlertCandidate(
                    analyzer=self.name,
                    title=f"{label} detected",
                    severity="medium" if ltype == "3" else "high",
                    description=f"{len(indices)} {label} events detected",
                    evidence=[{"logon_type": ltype, "count": len(indices), "rows": indices[:10]}],
                    mitre_technique="T1021",
                    tags=["authentication", "lateral_movement"],
                    score=55 if ltype == "3" else 70,
                ))

        return alerts
|
||||
|
||||
|
||||
class PersistenceAnalyzer(BaseAnalyzer):
    """Detects persistence mechanisms (registry keys, services, scheduled tasks)."""

    name = "persistence"
    description = "Flags persistence mechanism installations"

    # (regex, title, mitre_technique) for well-known persistence registry paths.
    REGISTRY_PATTERNS = [
        (r"(?i)\\CurrentVersion\\Run", "Run key persistence", "T1547.001"),
        (r"(?i)\\Services\\", "Service installation", "T1543.003"),
        (r"(?i)\\Winlogon\\", "Winlogon persistence", "T1547.004"),
        (r"(?i)\\Image File Execution Options\\", "IFEO debugger persistence", "T1546.012"),
        (r"(?i)\\Explorer\\Shell Folders", "Shell folder hijack", "T1547.001"),
    ]

    # Fix: compiled once at import time. Previously the pattern list was
    # re-compiled on every analyze() call.
    _COMPILED: list[tuple[re.Pattern[str], str, str]] = [
        (re.compile(p), t, m) for p, t, m in REGISTRY_PATTERNS
    ]

    async def analyze(self, rows, config=None):
        """Flag registry-path matches and service-creation events.

        ``config`` is accepted for interface parity but currently unused.
        """
        alerts: list[AlertCandidate] = []

        for idx, row in enumerate(rows):
            # Check registry paths across common field names.
            reg_path = str(row.get("registry_key", row.get("target_object", row.get("registry_path", ""))))
            for pattern, title, mitre in self._COMPILED:
                if pattern.search(reg_path):
                    alerts.append(AlertCandidate(
                        analyzer=self.name,
                        title=title,
                        severity="high",
                        description=f"Row {idx}: {reg_path[:200]}",
                        evidence=[{"row_index": idx, "registry_key": reg_path[:300]}],
                        mitre_technique=mitre,
                        tags=["persistence", "registry"],
                        score=75,
                    ))

            # Check for service creation events ("creat" matches create/created/creating).
            event_type = str(row.get("event_type", "")).lower()
            if "service" in event_type and "creat" in event_type:
                svc_name = row.get("service_name", row.get("target_filename", "unknown"))
                alerts.append(AlertCandidate(
                    analyzer=self.name,
                    title=f"Service created: {svc_name}",
                    severity="medium",
                    description=f"Row {idx}: New service '{svc_name}' created",
                    evidence=[{"row_index": idx, "service_name": str(svc_name)}],
                    mitre_technique="T1543.003",
                    tags=["persistence", "service"],
                    score=60,
                ))

        return alerts
|
||||
|
||||
|
||||
# ── Analyzer Registry ────────────────────────────────────────────────
|
||||
|
||||
|
||||
# Singleton instances of every built-in analyzer, created once at import
# time. List order is also the execution order in run_all_analyzers().
_ALL_ANALYZERS: list[BaseAnalyzer] = [
    EntropyAnalyzer(),
    SuspiciousCommandAnalyzer(),
    NetworkAnomalyAnalyzer(),
    FrequencyAnomalyAnalyzer(),
    AuthAnomalyAnalyzer(),
    PersistenceAnalyzer(),
]
|
||||
|
||||
|
||||
def get_available_analyzers() -> list[dict[str, str]]:
    """Describe every registered analyzer as {"name": ..., "description": ...}."""
    return [
        dict(name=analyzer.name, description=analyzer.description)
        for analyzer in _ALL_ANALYZERS
    ]
|
||||
|
||||
|
||||
def get_analyzer(name: str) -> BaseAnalyzer | None:
    """Look up a registered analyzer by its unique name; None if unknown."""
    return next((candidate for candidate in _ALL_ANALYZERS if candidate.name == name), None)
|
||||
|
||||
|
||||
async def run_all_analyzers(
    rows: list[dict[str, Any]],
    enabled: list[str] | None = None,
    config: dict[str, Any] | None = None,
) -> list[AlertCandidate]:
    """Run all (or selected) analyzers and return combined alert candidates.

    A failing analyzer is logged and skipped — it never aborts the batch.

    Args:
        rows: Flat list of row dicts (normalized_data or data from DatasetRow).
        enabled: Optional list of analyzer names to run. Runs all if None.
        config: Optional config overrides passed to each analyzer.

    Returns:
        Combined list of AlertCandidate from all analyzers, sorted by score desc.
    """
    config = config or {}
    combined: list[AlertCandidate] = []

    # Empty/None `enabled` means "run everything".
    selected = [a for a in _ALL_ANALYZERS if not enabled or a.name in enabled]
    for analyzer in selected:
        try:
            found = await analyzer.analyze(rows, config)
        except Exception:
            logger.exception("Analyzer %s failed", analyzer.name)
            continue
        combined.extend(found)
        logger.info("Analyzer %s produced %d alerts", analyzer.name, len(found))

    # Highest-priority findings first.
    return sorted(combined, key=lambda candidate: candidate.score, reverse=True)
|
||||
322
backend/app/services/llm_analysis.py
Normal file
322
backend/app/services/llm_analysis.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""LLM-powered dataset analysis — replaces manual IOC enrichment.
|
||||
|
||||
Loads dataset rows server-side, builds a concise summary, and sends it
|
||||
to Wile (70B heavy) or Roadrunner (fast) for threat analysis.
|
||||
Supports both single-dataset and hunt-wide analysis.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.config import settings
|
||||
from app.agents.providers_v2 import OllamaProvider
|
||||
from app.agents.router import TaskType, task_router
|
||||
from app.services.sans_rag import sans_rag
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
|
||||
|
||||
class AnalysisRequest(BaseModel):
    """Request for LLM-powered analysis of a dataset."""

    # Target dataset for single-dataset analysis. NOTE(review): presumably
    # exactly one of dataset_id/hunt_id is supplied — enforced by the route,
    # not this model; confirm against the caller.
    dataset_id: Optional[str] = None
    # Target hunt for hunt-wide analysis.
    hunt_id: Optional[str] = None
    question: str = Field(
        default="Perform a comprehensive threat analysis of this dataset. "
        "Identify anomalies, suspicious patterns, potential IOCs, and recommend "
        "next steps for the analyst.",
        description="Specific question or general analysis request",
    )
    # "deep" routes to the heavy model; anything else to the fast one
    # (see run_llm_analysis).
    mode: str = Field(default="deep", description="quick | deep")
    # When set, the matching FOCUS_PROMPTS entry is appended to the prompt.
    focus: Optional[str] = Field(
        None,
        description="Focus area: threats, anomalies, lateral_movement, exfil, persistence, recon",
    )
|
||||
|
||||
|
||||
class AnalysisResult(BaseModel):
    """LLM analysis result.

    The first group of fields is parsed from the model's JSON reply in
    _parse_analysis(); the trailing metadata fields are filled in afterwards
    by run_llm_analysis().
    """

    analysis: str = Field(..., description="Full analysis text (markdown)")
    confidence: float = Field(default=0.0, description="0-1 confidence")
    key_findings: list[str] = Field(default_factory=list)
    iocs_identified: list[dict] = Field(default_factory=list)
    recommended_actions: list[str] = Field(default_factory=list)
    mitre_techniques: list[str] = Field(default_factory=list)
    risk_score: int = Field(default=0, description="0-100 risk score")
    # Routing/latency metadata, set by run_llm_analysis after the call.
    model_used: str = ""
    node_used: str = ""
    latency_ms: int = 0
    rows_analyzed: int = 0
    dataset_summary: str = ""
|
||||
|
||||
|
||||
# ── Analysis prompts ──────────────────────────────────────────────────
|
||||
|
||||
# System prompt instructing the model to answer with a strict JSON object;
# _parse_analysis() depends on this shape.
ANALYSIS_SYSTEM = """You are an expert threat hunter and incident response analyst.
You are analyzing CSV log data from forensic tools (Velociraptor, Sysmon, etc.).

Your task: Perform deep threat analysis of the data provided and produce actionable findings.

RESPOND WITH VALID JSON ONLY:
{
  "analysis": "Detailed markdown analysis with headers and bullet points",
  "confidence": 0.85,
  "key_findings": ["Finding 1", "Finding 2"],
  "iocs_identified": [{"type": "ip", "value": "1.2.3.4", "context": "C2 traffic"}],
  "recommended_actions": ["Action 1", "Action 2"],
  "mitre_techniques": ["T1059.001 - PowerShell", "T1071 - Application Layer Protocol"],
  "risk_score": 65
}
"""

# Focus-area instructions appended to the user prompt when
# AnalysisRequest.focus is set; keys mirror the values listed in that
# field's description.
FOCUS_PROMPTS = {
    "threats": "Focus on identifying active threats, malware indicators, and attack patterns.",
    "anomalies": "Focus on statistical anomalies, outliers, and unusual behavior patterns.",
    "lateral_movement": "Focus on evidence of lateral movement: PsExec, WMI, RDP, SMB, pass-the-hash.",
    "exfil": "Focus on data exfiltration indicators: large transfers, DNS tunneling, unusual destinations.",
    "persistence": "Focus on persistence mechanisms: scheduled tasks, services, registry, startup items.",
    "recon": "Focus on reconnaissance activity: scanning, enumeration, discovery commands.",
}
|
||||
|
||||
|
||||
# ── Data summarizer ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def summarize_dataset_rows(
|
||||
rows: list[dict],
|
||||
columns: list[str] | None = None,
|
||||
max_sample: int = 20,
|
||||
max_chars: int = 6000,
|
||||
) -> str:
|
||||
"""Build a concise text summary of dataset rows for LLM consumption.
|
||||
|
||||
Includes:
|
||||
- Column headers and types
|
||||
- Statistical summary (unique values, top values per column)
|
||||
- Sample rows (first N)
|
||||
- Detected patterns of interest
|
||||
"""
|
||||
if not rows:
|
||||
return "Empty dataset — no rows to analyze."
|
||||
|
||||
cols = columns or list(rows[0].keys())
|
||||
n_rows = len(rows)
|
||||
|
||||
parts: list[str] = []
|
||||
parts.append(f"## Dataset Summary: {n_rows} rows, {len(cols)} columns")
|
||||
parts.append(f"Columns: {', '.join(cols)}")
|
||||
|
||||
# Per-column stats
|
||||
parts.append("\n### Column Statistics:")
|
||||
for col in cols[:30]: # limit to first 30 cols
|
||||
values = [str(r.get(col, "")) for r in rows if r.get(col) not in (None, "", "N/A")]
|
||||
if not values:
|
||||
continue
|
||||
unique = len(set(values))
|
||||
counter = Counter(values)
|
||||
top3 = counter.most_common(3)
|
||||
top_str = ", ".join(f"{v} ({c}x)" for v, c in top3)
|
||||
parts.append(f"- **{col}**: {len(values)} non-null, {unique} unique. Top: {top_str}")
|
||||
|
||||
# Sample rows
|
||||
sample = rows[:max_sample]
|
||||
parts.append(f"\n### Sample Rows (first {len(sample)}):")
|
||||
for i, row in enumerate(sample):
|
||||
row_str = " | ".join(f"{k}={v}" for k, v in row.items() if v not in (None, "", "N/A"))
|
||||
parts.append(f"{i+1}. {row_str}")
|
||||
|
||||
# Detect interesting patterns
|
||||
patterns: list[str] = []
|
||||
all_cmds = [str(r.get("command_line", "")).lower() for r in rows if r.get("command_line")]
|
||||
sus_cmds = [c for c in all_cmds if any(
|
||||
k in c for k in ["powershell -enc", "certutil", "bitsadmin", "mshta",
|
||||
"regsvr32", "invoke-", "mimikatz", "psexec"]
|
||||
)]
|
||||
if sus_cmds:
|
||||
patterns.append(f"⚠️ {len(sus_cmds)} suspicious command lines detected")
|
||||
|
||||
all_ips = [str(r.get("dst_ip", "")) for r in rows if r.get("dst_ip")]
|
||||
ext_ips = [ip for ip in all_ips if ip and not ip.startswith(("10.", "192.168.", "172.", "127."))]
|
||||
if ext_ips:
|
||||
unique_ext = len(set(ext_ips))
|
||||
patterns.append(f"🌐 {unique_ext} unique external destination IPs")
|
||||
|
||||
if patterns:
|
||||
parts.append("\n### Detected Patterns:")
|
||||
for p in patterns:
|
||||
parts.append(f"- {p}")
|
||||
|
||||
text = "\n".join(parts)
|
||||
if len(text) > max_chars:
|
||||
text = text[:max_chars] + "\n... (truncated)"
|
||||
return text
|
||||
|
||||
|
||||
# ── LLM analysis engine ──────────────────────────────────────────────
|
||||
|
||||
|
||||
async def run_llm_analysis(
    rows: list[dict],
    request: AnalysisRequest,
    dataset_name: str = "unknown",
) -> AnalysisResult:
    """Run LLM analysis on dataset rows.

    Summarizes the rows, routes to a model via task_router ("deep" mode →
    DEEP_ANALYSIS, anything else → QUICK_CHAT), optionally enriches the
    prompt with SANS RAG context, then calls the provider under a 5-minute
    hard timeout. Timeouts and provider errors are returned as an
    AnalysisResult whose ``analysis`` text explains the failure — this
    function does not raise.
    """
    start = time.monotonic()

    # Build summary
    summary = summarize_dataset_rows(rows)

    # Route to appropriate model (deep → heavy analysis model, else fast chat).
    task_type = TaskType.DEEP_ANALYSIS if request.mode == "deep" else TaskType.QUICK_CHAT
    decision = task_router.route(task_type)

    # Build prompt
    focus_text = FOCUS_PROMPTS.get(request.focus or "", "")
    prompt = f"""Analyze the following forensic dataset from '{dataset_name}'.

{focus_text}

Analyst question: {request.question}

{summary}
"""

    # Enrich with SANS RAG — best-effort; analysis proceeds without it on failure.
    try:
        rag_context = await sans_rag.enrich_prompt(
            request.question,
            investigation_context=f"Analyzing {len(rows)} rows from {dataset_name}",
        )
        if rag_context:
            prompt = f"{prompt}\n\n{rag_context}"
    except Exception as e:
        logger.warning(f"SANS RAG enrichment failed: {e}")

    # Call LLM
    provider = task_router.get_provider(decision)
    # NOTE(review): `messages` is built but never passed to the provider —
    # generate() below takes prompt/system directly. Possibly dead code;
    # confirm before removing.
    messages = [
        {"role": "system", "content": ANALYSIS_SYSTEM},
        {"role": "user", "content": prompt},
    ]

    try:
        raw = await asyncio.wait_for(
            provider.generate(
                prompt=prompt,
                system=ANALYSIS_SYSTEM,
                max_tokens=settings.AGENT_MAX_TOKENS * 2,  # longer for analysis
                temperature=0.3,
            ),
            timeout=300,  # 5 min hard limit
        )
    except asyncio.TimeoutError:
        logger.error("LLM analysis timed out after 300s")
        return AnalysisResult(
            analysis="Analysis timed out after 5 minutes. Try a smaller dataset or 'quick' mode.",
            model_used=decision.model,
            node_used=decision.node,
            latency_ms=int((time.monotonic() - start) * 1000),
            rows_analyzed=len(rows),
            dataset_summary=summary,
        )
    except Exception as e:
        logger.error(f"LLM analysis failed: {e}")
        return AnalysisResult(
            analysis=f"Analysis failed: {str(e)}",
            model_used=decision.model,
            node_used=decision.node,
            latency_ms=int((time.monotonic() - start) * 1000),
            rows_analyzed=len(rows),
            dataset_summary=summary,
        )

    elapsed = int((time.monotonic() - start) * 1000)

    # Parse JSON response and attach routing/latency metadata.
    result = _parse_analysis(raw)
    result.model_used = decision.model
    result.node_used = decision.node
    result.latency_ms = elapsed
    result.rows_analyzed = len(rows)
    result.dataset_summary = summary

    return result
|
||||
|
||||
|
||||
def _parse_analysis(raw) -> AnalysisResult:
    """Try to parse LLM output as JSON, fall back to plain text.

    raw may be:
    - A dict from OllamaProvider.generate() with key "response" containing LLM text
    - A plain string from other providers

    Returns an AnalysisResult; never raises on malformed model output.
    """
    # Ollama provider returns {"response": "<llm text>", "model": ..., ...}
    if isinstance(raw, dict):
        text = raw.get("response") or raw.get("analysis") or str(raw)
        logger.info("_parse_analysis: extracted text from dict, len=%d, first 200 chars: %s", len(text), text[:200])
    else:
        text = str(raw)
        logger.info("_parse_analysis: raw is str, len=%d, first 200 chars: %s", len(text), text[:200])

    text = text.strip()

    # Strip markdown code fences (```json ... ```).
    if text.startswith("```"):
        lines = text.split("\n")
        lines = [ln for ln in lines if not ln.strip().startswith("```")]
        text = "\n".join(lines).strip()

    # Try progressively more aggressive JSON extraction.
    for candidate in _extract_json_candidates(text):
        try:
            data = json.loads(candidate)
            if isinstance(data, dict):
                logger.info("_parse_analysis: parsed JSON OK, keys=%s", list(data.keys()))
                return AnalysisResult(
                    analysis=data.get("analysis", text),
                    confidence=float(data.get("confidence", 0.5)),
                    key_findings=data.get("key_findings", []),
                    iocs_identified=data.get("iocs_identified", []),
                    recommended_actions=data.get("recommended_actions", []),
                    mitre_techniques=data.get("mitre_techniques", []),
                    risk_score=int(data.get("risk_score", 0)),
                )
        # Fix: TypeError added — the model may emit JSON null for numeric
        # fields, and float(None)/int(None) raise TypeError, which previously
        # escaped this handler and crashed the caller instead of falling
        # through to the plain-text result.
        except (json.JSONDecodeError, ValueError, TypeError) as e:
            logger.warning(
                "_parse_analysis: JSON parse failed: %s, candidate len=%d, first 100: %s",
                e, len(candidate), candidate[:100],
            )
            continue

    # Fallback: plain text
    logger.warning("_parse_analysis: all JSON parse attempts failed, falling back to plain text")
    return AnalysisResult(
        analysis=text,
        confidence=0.5,
    )
|
||||
|
||||
|
||||
def _extract_json_candidates(text: str):
    """Yield JSON candidate strings from text, trying progressively more aggressive extraction."""
    import re

    # 1. The whole text as-is.
    yield text

    # 2. The outermost { ... } block, when one exists.
    brace_open = text.find("{")
    brace_close = text.rfind("}")
    if brace_open == -1 or brace_close <= brace_open:
        return
    block = text[brace_open:brace_close + 1]
    yield block

    # 3. The same block with a common LLM JSON defect repaired:
    #    trailing commas before ] or }.
    repaired = re.sub(r',\s*([}\]])', r'\1', block)
    if repaired != block:
        yield repaired
||||
484
backend/app/services/mitre.py
Normal file
484
backend/app/services/mitre.py
Normal file
@@ -0,0 +1,484 @@
|
||||
"""
|
||||
MITRE ATT&CK mapping service.
|
||||
|
||||
Maps dataset events to ATT&CK techniques using pattern-based heuristics.
|
||||
Uses the enterprise-attack matrix (embedded patterns for offline use).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── ATT&CK Technique Patterns ────────────────────────────────────────
|
||||
# Subset of enterprise-attack techniques with detection patterns.
|
||||
# Each entry: (technique_id, name, tactic, patterns_list)
|
||||
|
||||
# Patterns are regexes compiled with re.IGNORECASE at match time (see
# map_to_attack / build_knowledge_graph) and applied to a row's flattened,
# lower-cased field text.
TECHNIQUE_PATTERNS: list[tuple[str, str, str, list[str]]] = [
    # Initial Access
    ("T1566", "Phishing", "initial-access", [
        r"phish", r"\.hta\b", r"\.lnk\b", r"mshta\.exe", r"outlook.*attachment",
    ]),
    ("T1190", "Exploit Public-Facing Application", "initial-access", [
        r"exploit", r"CVE-\d{4}", r"vulnerability", r"webshell",
    ]),

    # Execution
    ("T1059.001", "PowerShell", "execution", [
        r"powershell", r"pwsh", r"-enc\b", r"-encodedcommand",
        r"invoke-expression", r"iex\b", r"bypass\b.*execution",
    ]),
    ("T1059.003", "Windows Command Shell", "execution", [
        r"cmd\.exe", r"/c\s+", r"command\.com",
    ]),
    ("T1059.005", "Visual Basic", "execution", [
        r"wscript", r"cscript", r"\.vbs\b", r"\.vbe\b",
    ]),
    ("T1047", "Windows Management Instrumentation", "execution", [
        r"wmic\b", r"winmgmt", r"wmi\b",
    ]),
    # NOTE(review): "T1053.005" appears again below under persistence; since
    # technique_meta is keyed by id, the later entry's name/tactic wins and
    # evidence from both pattern sets accumulates under one id — confirm
    # whether the duplicate is intentional.
    ("T1053.005", "Scheduled Task", "execution", [
        r"schtasks", r"at\.exe", r"taskschd",
    ]),
    ("T1204", "User Execution", "execution", [
        r"user.*click", r"open.*attachment", r"macro",
    ]),

    # Persistence
    ("T1547.001", "Registry Run Keys", "persistence", [
        r"CurrentVersion\\Run", r"HKLM\\Software\\Microsoft\\Windows\\CurrentVersion\\Run",
        r"reg\s+add.*\\Run",
    ]),
    ("T1543.003", "Windows Service", "persistence", [
        r"sc\s+create", r"new-service", r"service.*install",
    ]),
    ("T1136", "Create Account", "persistence", [
        r"net\s+user\s+/add", r"new-localuser", r"useradd",
    ]),
    # NOTE(review): duplicate id — see the execution-tactic entry above.
    ("T1053.005", "Scheduled Task/Job", "persistence", [
        r"schtasks\s+/create", r"crontab",
    ]),

    # Privilege Escalation
    ("T1548.002", "Bypass User Access Control", "privilege-escalation", [
        r"eventvwr", r"fodhelper", r"uac.*bypass", r"computerdefaults",
    ]),
    ("T1134", "Access Token Manipulation", "privilege-escalation", [
        r"token.*impersonat", r"runas", r"adjusttokenprivileges",
    ]),

    # Defense Evasion
    ("T1070.001", "Clear Windows Event Logs", "defense-evasion", [
        r"wevtutil\s+cl", r"clear-eventlog", r"clearlog",
    ]),
    ("T1562.001", "Disable or Modify Tools", "defense-evasion", [
        r"tamper.*protection", r"disable.*defender", r"set-mppreference",
        r"disable.*firewall",
    ]),
    ("T1027", "Obfuscated Files or Information", "defense-evasion", [
        r"base64", r"-enc\b", r"certutil.*-decode", r"frombase64",
    ]),
    ("T1036", "Masquerading", "defense-evasion", [
        r"rename.*\.exe", r"masquerad", r"svchost.*unusual",
    ]),
    ("T1055", "Process Injection", "defense-evasion", [
        r"inject", r"createremotethread", r"ntcreatethreadex",
        r"virtualalloc", r"writeprocessmemory",
    ]),

    # Credential Access
    ("T1003.001", "LSASS Memory", "credential-access", [
        r"mimikatz", r"sekurlsa", r"lsass", r"procdump.*lsass",
    ]),
    ("T1003.003", "NTDS", "credential-access", [
        r"ntds\.dit", r"vssadmin.*shadow", r"ntdsutil",
    ]),
    ("T1110", "Brute Force", "credential-access", [
        r"brute.*force", r"failed.*login.*\d{3,}", r"hydra", r"medusa",
    ]),
    ("T1558.003", "Kerberoasting", "credential-access", [
        r"kerberoast", r"invoke-kerberoast", r"GetUserSPNs",
    ]),

    # Discovery
    ("T1087", "Account Discovery", "discovery", [
        r"net\s+user", r"net\s+localgroup", r"get-aduser",
    ]),
    ("T1082", "System Information Discovery", "discovery", [
        r"systeminfo", r"hostname", r"ver\b",
    ]),
    ("T1083", "File and Directory Discovery", "discovery", [
        r"dir\s+/s", r"tree\s+/f", r"get-childitem.*-recurse",
    ]),
    ("T1057", "Process Discovery", "discovery", [
        r"tasklist", r"get-process", r"ps\s+aux",
    ]),
    ("T1018", "Remote System Discovery", "discovery", [
        r"net\s+view", r"ping\s+-", r"arp\s+-a", r"nslookup",
    ]),
    ("T1016", "System Network Configuration Discovery", "discovery", [
        r"ipconfig", r"ifconfig", r"netstat",
    ]),

    # Lateral Movement
    ("T1021.001", "Remote Desktop Protocol", "lateral-movement", [
        r"rdp\b", r"mstsc", r"3389", r"remote\s+desktop",
    ]),
    ("T1021.002", "SMB/Windows Admin Shares", "lateral-movement", [
        r"\\\\.*\\(c|admin)\$", r"psexec", r"smbclient", r"net\s+use",
    ]),
    ("T1021.006", "Windows Remote Management", "lateral-movement", [
        r"winrm", r"enter-pssession", r"invoke-command.*-computername",
        r"wsman", r"5985|5986",
    ]),
    ("T1570", "Lateral Tool Transfer", "lateral-movement", [
        r"copy.*\\\\", r"xcopy.*\\\\", r"robocopy",
    ]),

    # Collection
    ("T1560", "Archive Collected Data", "collection", [
        r"compress-archive", r"7z\.exe", r"rar\s+a", r"tar\s+-[cz]",
    ]),
    ("T1005", "Data from Local System", "collection", [
        r"type\s+.*password", r"findstr.*password", r"select-string.*credential",
    ]),

    # Command and Control
    ("T1071.001", "Web Protocols", "command-and-control", [
        r"http[s]?://\d+\.\d+\.\d+\.\d+", r"curl\b", r"wget\b",
        r"invoke-webrequest", r"beacon",
    ]),
    ("T1573", "Encrypted Channel", "command-and-control", [
        r"ssl\b", r"tls\b", r"encrypted.*tunnel", r"stunnel",
    ]),
    ("T1105", "Ingress Tool Transfer", "command-and-control", [
        r"certutil.*-urlcache", r"bitsadmin.*transfer",
        r"downloadfile", r"invoke-webrequest.*-outfile",
    ]),
    ("T1219", "Remote Access Software", "command-and-control", [
        r"teamviewer", r"anydesk", r"logmein", r"vnc",
    ]),

    # Exfiltration
    ("T1048", "Exfiltration Over Alternative Protocol", "exfiltration", [
        r"dns.*tunnel", r"exfil", r"icmp.*tunnel",
    ]),
    ("T1041", "Exfiltration Over C2 Channel", "exfiltration", [
        r"upload.*c2", r"exfil.*http",
    ]),
    ("T1567", "Exfiltration Over Web Service", "exfiltration", [
        r"mega\.nz", r"dropbox", r"pastebin", r"transfer\.sh",
    ]),

    # Impact
    ("T1486", "Data Encrypted for Impact", "impact", [
        r"ransomware", r"encrypt.*files", r"\.locked\b", r"ransom",
    ]),
    ("T1489", "Service Stop", "impact", [
        r"sc\s+stop", r"net\s+stop", r"stop-service",
    ]),
    ("T1529", "System Shutdown/Reboot", "impact", [
        r"shutdown\s+/[rs]", r"restart-computer",
    ]),
]
|
||||
|
||||
# Tactic display names and kill-chain order
|
||||
# Tactic ids in kill-chain order; drives the ordering of the matrix output
# built by map_to_attack.
TACTIC_ORDER = [
    "initial-access", "execution", "persistence", "privilege-escalation",
    "defense-evasion", "credential-access", "discovery", "lateral-movement",
    "collection", "command-and-control", "exfiltration", "impact",
]
# Human-readable display names keyed by the same tactic ids as TACTIC_ORDER.
TACTIC_NAMES = {
    "initial-access": "Initial Access",
    "execution": "Execution",
    "persistence": "Persistence",
    "privilege-escalation": "Privilege Escalation",
    "defense-evasion": "Defense Evasion",
    "credential-access": "Credential Access",
    "discovery": "Discovery",
    "lateral-movement": "Lateral Movement",
    "collection": "Collection",
    "command-and-control": "Command and Control",
    "exfiltration": "Exfiltration",
    "impact": "Impact",
}
|
||||
|
||||
|
||||
# ── Row fetcher ───────────────────────────────────────────────────────
|
||||
|
||||
async def _fetch_rows(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    limit: int = 5000,
) -> list[dict[str, Any]]:
    """Load up to *limit* raw row payloads, scoped to a dataset or (failing that) a hunt.

    When *dataset_id* is given it takes precedence over *hunt_id*; with
    neither, up to *limit* rows from any dataset are returned.
    """
    stmt = select(DatasetRow).join(Dataset)
    if dataset_id:
        stmt = stmt.where(DatasetRow.dataset_id == dataset_id)
    elif hunt_id:
        stmt = stmt.where(Dataset.hunt_id == hunt_id)
    scalars = (await db.execute(stmt.limit(limit))).scalars()
    return [row.data for row in scalars.all()]
|
||||
|
||||
|
||||
# ── Main functions ────────────────────────────────────────────────────
|
||||
|
||||
async def map_to_attack(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
) -> dict[str, Any]:
    """
    Map dataset rows to MITRE ATT&CK techniques.

    Pattern-matches each row's flattened text against TECHNIQUE_PATTERNS and
    returns a matrix-style structure:

        {
            "tactics":    kill-chain-ordered tactic dicts, each with its techniques,
            "techniques": flat list of every matched technique dict,
            "evidence":   flat list of evidence records (each tagged with technique_id),
            "coverage":   summary counters,
            "total_rows": number of rows scanned,
        }

    Fix: the non-empty path previously omitted the "techniques" and
    "evidence" keys that the empty path returned, giving callers an
    inconsistent payload shape. Also dropped a per-row technique-set
    structure that was computed but never read.
    """
    rows = await _fetch_rows(db, dataset_id, hunt_id)
    if not rows:
        return {"tactics": [], "techniques": [], "evidence": [], "coverage": {}, "total_rows": 0}

    # Flatten all string values per row into one lower-cased text for matching.
    row_texts: list[str] = [
        " ".join(str(v).lower() for v in row.values() if v is not None)
        for row in rows
    ]

    # Match techniques against every row.
    technique_hits: dict[str, list[dict]] = defaultdict(list)  # tech_id -> evidence rows
    technique_meta: dict[str, tuple[str, str]] = {}  # tech_id -> (name, tactic)

    for tech_id, tech_name, tactic, patterns in TECHNIQUE_PATTERNS:
        # Compile once per technique, outside the row loop.
        compiled = [re.compile(p, re.IGNORECASE) for p in patterns]
        technique_meta[tech_id] = (tech_name, tactic)
        for i, text in enumerate(row_texts):
            for pat in compiled:
                if pat.search(text):
                    if len(technique_hits[tech_id]) < 10:  # cap evidence per technique
                        # Locate which original field matched, for display.
                        matched_field = ""
                        matched_value = ""
                        for k, v in rows[i].items():
                            if v and pat.search(str(v).lower()):
                                matched_field = k
                                matched_value = str(v)[:200]
                                break
                        technique_hits[tech_id].append({
                            "row_index": i,
                            "field": matched_field,
                            "value": matched_value,
                            "pattern": pat.pattern,
                        })
                    break  # one pattern match per technique per row is enough

    # Group matched techniques under their tactic.
    tactic_techniques: dict[str, list[dict]] = defaultdict(list)
    for tech_id, evidence_list in technique_hits.items():
        name, tactic = technique_meta[tech_id]
        tactic_techniques[tactic].append({
            "id": tech_id,
            "name": name,
            "count": len(evidence_list),
            "evidence": evidence_list[:5],
        })

    # Build the kill-chain-ordered tactics list.
    tactics = []
    for tactic_key in TACTIC_ORDER:
        techs = tactic_techniques.get(tactic_key, [])
        tactics.append({
            "id": tactic_key,
            "name": TACTIC_NAMES.get(tactic_key, tactic_key),
            "techniques": sorted(techs, key=lambda t: -t["count"]),
            "total_hits": sum(t["count"] for t in techs),
        })

    # Flat views so the non-empty payload carries the same keys as the empty one.
    all_techniques = [t for tac in tactics for t in tac["techniques"]]
    all_evidence = [
        {**ev, "technique_id": tech_id}
        for tech_id, evidence_list in technique_hits.items()
        for ev in evidence_list
    ]

    # Coverage stats
    covered_tactics = sum(1 for t in tactics if t["total_hits"] > 0)
    total_technique_hits = sum(t["total_hits"] for t in tactics)

    return {
        "tactics": tactics,
        "techniques": all_techniques,
        "evidence": all_evidence,
        "coverage": {
            "tactics_covered": covered_tactics,
            "tactics_total": len(TACTIC_ORDER),
            "techniques_matched": len(technique_hits),
            "total_evidence": total_technique_hits,
        },
        "total_rows": len(rows),
    }
|
||||
|
||||
|
||||
async def build_knowledge_graph(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
) -> dict[str, Any]:
    """
    Build a knowledge graph connecting entities (hosts, users, processes, IPs)
    to MITRE techniques and tactics.
    Returns Cytoscape-compatible nodes + edges.

    Pipeline: (1) extract typed entities from row fields by field-name regex,
    (2) match rows against TECHNIQUE_PATTERNS, (3) emit one node per entity
    (capped per type) and per matched technique, (4) weight edges by
    row-level co-occurrence. Scoped by *dataset_id* (preferred) or *hunt_id*.
    """
    rows = await _fetch_rows(db, dataset_id, hunt_id)
    if not rows:
        return {"nodes": [], "edges": [], "stats": {}}

    # Extract entities
    entities: dict[str, set[str]] = defaultdict(set)  # type -> set of values
    row_entity_map: list[list[tuple[str, str]]] = []  # per-row list of (type, value)

    # Field NAME patterns for entity extraction (matched against keys, not values).
    # Order matters below: the first matching category claims the field.
    HOST_FIELDS = re.compile(r"hostname|computer|host|machine", re.I)
    USER_FIELDS = re.compile(r"user|account|logon.*name|subject.*name", re.I)
    IP_FIELDS = re.compile(r"src.*ip|dst.*ip|ip.*addr|source.*ip|dest.*ip|remote.*addr", re.I)
    PROC_FIELDS = re.compile(r"process.*name|image|parent.*image|executable|command", re.I)

    for row in rows:
        row_ents: list[tuple[str, str]] = []
        for k, v in row.items():
            # Skip empty / placeholder values.
            if not v or str(v).strip() in ('', '-', 'N/A', 'None'):
                continue
            val = str(v).strip()
            if HOST_FIELDS.search(k):
                entities["host"].add(val)
                row_ents.append(("host", val))
            elif USER_FIELDS.search(k):
                entities["user"].add(val)
                row_ents.append(("user", val))
            elif IP_FIELDS.search(k):
                entities["ip"].add(val)
                row_ents.append(("ip", val))
            elif PROC_FIELDS.search(k):
                # Clean process name: basename only, truncated to 60 chars.
                pname = val.split("\\")[-1].split("/")[-1][:60]
                entities["process"].add(pname)
                row_ents.append(("process", pname))
        row_entity_map.append(row_ents)

    # Map rows to techniques (same pattern set as map_to_attack).
    row_texts = [" ".join(str(v).lower() for v in row.values() if v) for row in rows]
    row_techniques: list[set[str]] = [set() for _ in rows]
    tech_meta: dict[str, tuple[str, str]] = {}

    for tech_id, tech_name, tactic, patterns in TECHNIQUE_PATTERNS:
        compiled = [re.compile(p, re.I) for p in patterns]
        tech_meta[tech_id] = (tech_name, tactic)
        for i, text in enumerate(row_texts):
            for pat in compiled:
                if pat.search(text):
                    row_techniques[i].add(tech_id)
                    break  # first matching pattern is enough per technique/row

    # Build graph
    nodes: list[dict] = []
    edges: list[dict] = []
    node_ids: set[str] = set()
    edge_counter: Counter = Counter()  # (src_id, tgt_id) -> co-occurrence weight

    # Per-type Cytoscape styling hints carried in node data.
    TYPE_COLORS = {
        "host": "#3b82f6",
        "user": "#10b981",
        "ip": "#8b5cf6",
        "process": "#f59e0b",
        "technique": "#ef4444",
        "tactic": "#6366f1",
    }
    TYPE_SHAPES = {
        "host": "roundrectangle",
        "user": "ellipse",
        "ip": "diamond",
        "process": "hexagon",
        "technique": "tag",
        "tactic": "round-rectangle",
    }

    # Entity nodes (capped at 50 per type; set order is arbitrary, so which
    # 50 survive is not deterministic across runs).
    for ent_type, values in entities.items():
        for val in list(values)[:50]:  # limit nodes
            nid = f"{ent_type}:{val}"
            if nid not in node_ids:
                node_ids.add(nid)
                nodes.append({
                    "data": {
                        "id": nid,
                        "label": val[:40],
                        "type": ent_type,
                        "color": TYPE_COLORS.get(ent_type, "#666"),
                        "shape": TYPE_SHAPES.get(ent_type, "ellipse"),
                    },
                })

    # Technique nodes — one per technique seen in any row.
    seen_techniques: set[str] = set()
    for tech_set in row_techniques:
        seen_techniques.update(tech_set)

    for tech_id in seen_techniques:
        name, tactic = tech_meta.get(tech_id, (tech_id, "unknown"))
        nid = f"technique:{tech_id}"
        if nid not in node_ids:
            node_ids.add(nid)
            nodes.append({
                "data": {
                    "id": nid,
                    "label": f"{tech_id}\n{name}",
                    "type": "technique",
                    "color": TYPE_COLORS["technique"],
                    "shape": TYPE_SHAPES["technique"],
                    "tactic": tactic,
                },
            })

    # Edges: entity → technique (based on co-occurrence in rows). Edges to
    # entities dropped by the 50-per-type node cap are silently skipped.
    for i, row_ents in enumerate(row_entity_map):
        for ent_type, ent_val in row_ents:
            for tech_id in row_techniques[i]:
                src = f"{ent_type}:{ent_val}"
                tgt = f"technique:{tech_id}"
                if src in node_ids and tgt in node_ids:
                    edge_key = (src, tgt)
                    edge_counter[edge_key] += 1

    # Edges: entity → entity (based on co-occurrence). Direction follows the
    # per-row field order, so (a,b) and (b,a) may both accumulate separately.
    for row_ents in row_entity_map:
        for j in range(len(row_ents)):
            for k in range(j + 1, len(row_ents)):
                src = f"{row_ents[j][0]}:{row_ents[j][1]}"
                tgt = f"{row_ents[k][0]}:{row_ents[k][1]}"
                if src in node_ids and tgt in node_ids and src != tgt:
                    edge_counter[(src, tgt)] += 1

    # Build edge list (top 500 by weight).
    for (src, tgt), weight in edge_counter.most_common(500):
        # NOTE(review): Counter values here are always >= 1, so this filter
        # never fires; raise the threshold if low-weight pruning is wanted.
        if weight < 1:
            continue
        edges.append({
            "data": {
                "source": src,
                "target": tgt,
                "weight": weight,
                # Only label edges that are "strong" (weight > 2) to reduce clutter.
                "label": str(weight) if weight > 2 else "",
            },
        })

    return {
        "nodes": nodes,
        "edges": edges,
        "stats": {
            "total_nodes": len(nodes),
            "total_edges": len(edges),
            "entity_counts": {t: len(v) for t, v in entities.items()},
            "techniques_found": len(seen_techniques),
        },
    }
|
||||
252
backend/app/services/network_inventory.py
Normal file
252
backend/app/services/network_inventory.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Network Picture — deduplicated host inventory built from dataset rows.
|
||||
|
||||
Scans all datasets in a hunt, extracts host-identifying fields from
|
||||
normalized data, and groups by hostname (or src_ip fallback) to produce
|
||||
a clean one-row-per-host inventory. Uses sets for deduplication —
|
||||
if an IP appears 900 times, it shows once.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Canonical column names we extract per row (keys as produced by the
# upstream column-normalization step).
_HOST_KEYS = ("hostname",)                     # preferred grouping key
_IP_KEYS = ("src_ip", "dst_ip", "ip_address")  # all IP-bearing columns
_USER_KEYS = ("username",)
_OS_KEYS = ("os",)
_MAC_KEYS = ("mac_address",)
_PORT_SRC_KEYS = ("src_port",)   # treated as the host's local/open ports
_PORT_DST_KEYS = ("dst_port",)   # NOTE(review): not referenced in this module's visible code
_PROTO_KEYS = ("protocol",)
_STATE_KEYS = ("connection_state",)  # NOTE(review): not referenced in this module's visible code
_TS_KEYS = ("timestamp",)

# Junk values to skip (compared case-insensitively by _clean)
_JUNK = frozenset({"", "-", "0.0.0.0", "::", "0", "127.0.0.1", "::1", "localhost", "unknown", "n/a", "none", "null"})

ROW_BATCH = 1000  # rows fetched per DB query
MAX_HOSTS = 1000  # hard cap on returned hosts
|
||||
|
||||
|
||||
def _clean(val: Any) -> str:
    """Coerce *val* to a stripped string; return "" for junk/placeholder values."""
    if val is None:
        text = ""
    elif isinstance(val, str):
        text = val
    else:
        text = str(val)
    text = text.strip()
    # Junk comparison is case-insensitive; the returned value keeps its case.
    if text.lower() in _JUNK:
        return ""
    return text
|
||||
|
||||
|
||||
def _try_parse_ts(val: str) -> datetime | None:
|
||||
"""Best-effort timestamp parse (subset of common formats)."""
|
||||
for fmt in (
|
||||
"%Y-%m-%dT%H:%M:%S.%fZ",
|
||||
"%Y-%m-%dT%H:%M:%SZ",
|
||||
"%Y-%m-%dT%H:%M:%S.%f",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%d %H:%M:%S.%f",
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
):
|
||||
try:
|
||||
return datetime.strptime(val.strip(), fmt)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
class _HostBucket:
    """Mutable per-host accumulator: dedup sets plus counters, merged row by row."""

    __slots__ = (
        "hostname", "ips", "users", "os_versions", "mac_addresses",
        "protocols", "open_ports", "remote_targets", "datasets",
        "connection_count", "first_seen", "last_seen",
    )

    def __init__(self, hostname: str):
        self.hostname = hostname
        # Deduplicating sets — a value seen 900 times is stored once.
        self.ips: set[str] = set()
        self.users: set[str] = set()
        self.os_versions: set[str] = set()
        self.mac_addresses: set[str] = set()
        self.protocols: set[str] = set()
        self.open_ports: set[str] = set()
        self.remote_targets: set[str] = set()
        self.datasets: set[str] = set()
        self.connection_count: int = 0
        self.first_seen: datetime | None = None
        self.last_seen: datetime | None = None

    def ingest(self, row: dict[str, Any], ds_name: str) -> None:
        """Merge one normalised row into this bucket."""
        self.connection_count += 1
        self.datasets.add(ds_name)

        # Plain "clean then collect" fields, driven by a key-tuple → set table.
        for keys, dest in (
            (_IP_KEYS, self.ips),
            (_USER_KEYS, self.users),
            (_OS_KEYS, self.os_versions),
            (_MAC_KEYS, self.mac_addresses),
        ):
            for key in keys:
                value = _clean(row.get(key))
                if value:
                    dest.add(value)

        # Protocols are normalised to upper case.
        for key in _PROTO_KEYS:
            value = _clean(row.get(key))
            if value:
                self.protocols.add(value.upper())

        # Open ports = local (src) ports; "0" is not a real port.
        for key in _PORT_SRC_KEYS:
            value = _clean(row.get(key))
            if value and value != "0":
                self.open_ports.add(value)

        # Remote targets = destination IPs.
        dst = _clean(row.get("dst_ip"))
        if dst:
            self.remote_targets.add(dst)

        # Track earliest/latest parsable timestamp.
        for key in _TS_KEYS:
            value = _clean(row.get(key))
            if not value:
                continue
            ts = _try_parse_ts(value)
            if ts is None:
                continue
            if self.first_seen is None or ts < self.first_seen:
                self.first_seen = ts
            if self.last_seen is None or ts > self.last_seen:
                self.last_seen = ts

    def to_dict(self) -> dict[str, Any]:
        """Render the bucket as a JSON-safe dict with sorted lists for all sets."""

        def port_key(p: str) -> int:
            # Numeric sort for numeric ports; non-numeric strings sort first.
            return int(p) if p.isdigit() else 0

        return {
            "hostname": self.hostname,
            "ips": sorted(self.ips),
            "users": sorted(self.users),
            "os": sorted(self.os_versions),
            "mac_addresses": sorted(self.mac_addresses),
            "protocols": sorted(self.protocols),
            "open_ports": sorted(self.open_ports, key=port_key),
            "remote_targets": sorted(self.remote_targets),
            "datasets": sorted(self.datasets),
            "connection_count": self.connection_count,
            "first_seen": self.first_seen.isoformat() if self.first_seen else None,
            "last_seen": self.last_seen.isoformat() if self.last_seen else None,
        }
|
||||
|
||||
|
||||
async def build_network_picture(
    db: AsyncSession,
    hunt_id: str,
) -> dict[str, Any]:
    """Build a deduplicated host inventory for all datasets in a hunt.

    Rows are streamed in ROW_BATCH-sized pages per dataset and aggregated
    into one _HostBucket per host (keyed by upper-cased hostname, falling
    back to src_ip/ip_address). At most MAX_HOSTS hosts are returned
    (highest connection count first).

    Fix: summary totals (total_connections, total_unique_ips) are now
    computed over ALL aggregated buckets before truncation — previously
    they were derived from the already-capped list, undercounting whenever
    more than MAX_HOSTS hosts existed.

    Returns:
        {
            "hosts": [ {hostname, ips[], users[], os[], ...}, ... ],
            "summary": { total_hosts, total_connections, total_unique_ips, datasets_scanned }
        }
    """
    # 1. Get all datasets in this hunt, oldest first.
    ds_result = await db.execute(
        select(Dataset)
        .where(Dataset.hunt_id == hunt_id)
        .order_by(Dataset.created_at)
    )
    ds_list: Sequence[Dataset] = ds_result.scalars().all()

    if not ds_list:
        return {
            "hosts": [],
            "summary": {
                "total_hosts": 0,
                "total_connections": 0,
                "total_unique_ips": 0,
                "datasets_scanned": 0,
            },
        }

    # 2. Stream rows and aggregate into host buckets (key = uppercase hostname or IP).
    buckets: dict[str, _HostBucket] = {}

    for ds in ds_list:
        ds_name = ds.name or ds.filename
        offset = 0
        while True:
            stmt = (
                select(DatasetRow)
                .where(DatasetRow.dataset_id == ds.id)
                .order_by(DatasetRow.row_index)
                .limit(ROW_BATCH)
                .offset(offset)
            )
            result = await db.execute(stmt)
            rows: Sequence[DatasetRow] = result.scalars().all()
            if not rows:
                break

            for dr in rows:
                norm = dr.normalized_data or dr.data or {}
                host_val = _grouping_key(norm)
                if not host_val:
                    continue  # row has no host identifier — skip

                bucket_key = host_val.upper()
                bucket = buckets.get(bucket_key)
                if bucket is None:
                    bucket = buckets[bucket_key] = _HostBucket(host_val)
                bucket.ingest(norm, ds_name)

            offset += ROW_BATCH

    # 3. Summary stats over ALL buckets — before truncation, so the totals
    #    stay accurate even when the host list is capped.
    all_ips: set[str] = set()
    total_conns = 0
    for b in buckets.values():
        all_ips.update(b.ips)
        total_conns += b.connection_count

    # 4. Sort by connection count (descending) and cap the returned list.
    hosts_raw = sorted(buckets.values(), key=lambda b: b.connection_count, reverse=True)[:MAX_HOSTS]
    hosts = [b.to_dict() for b in hosts_raw]

    return {
        "hosts": hosts,
        "summary": {
            "total_hosts": len(hosts),
            "total_connections": total_conns,
            "total_unique_ips": len(all_ips),
            "datasets_scanned": len(ds_list),
        },
    }


def _grouping_key(norm: dict[str, Any]) -> str:
    """Pick the bucket key for a row: hostname first, else src_ip / ip_address."""
    for k in _HOST_KEYS:
        v = _clean(norm.get(k))
        if v:
            return v
    for k in ("src_ip", "ip_address"):
        v = _clean(norm.get(k))
        if v:
            return v
    return ""
|
||||
@@ -23,12 +23,12 @@ COLUMN_MAPPINGS: list[tuple[str, str]] = [
|
||||
# Operating system
|
||||
(r"^(os|operating_?system|os_?version|os_?name|platform|os_?type)$", "os"),
|
||||
# Source / destination IPs
|
||||
(r"^(source_?ip|src_?ip|srcaddr|local_?address|sourceaddress)$", "src_ip"),
|
||||
(r"^(dest_?ip|dst_?ip|dstaddr|remote_?address|destinationaddress|destaddress)$", "dst_ip"),
|
||||
(r"^(source_?ip|src_?ip|srcaddr|local_?address|sourceaddress|sourceip|laddr\.?ip)$", "src_ip"),
|
||||
(r"^(dest_?ip|dst_?ip|dstaddr|remote_?address|destinationaddress|destaddress|destination_?ip|destinationip|raddr\.?ip)$", "dst_ip"),
|
||||
(r"^(ip_?address|ipaddress|ip)$", "ip_address"),
|
||||
# Ports
|
||||
(r"^(source_?port|src_?port|localport)$", "src_port"),
|
||||
(r"^(dest_?port|dst_?port|remoteport|destinationport)$", "dst_port"),
|
||||
(r"^(source_?port|src_?port|localport|laddr\.?port)$", "src_port"),
|
||||
(r"^(dest_?port|dst_?port|remoteport|destinationport|raddr\.?port)$", "dst_port"),
|
||||
# Process info
|
||||
(r"^(process_?name|name|image|exe|executable|binary)$", "process_name"),
|
||||
(r"^(pid|process_?id)$", "pid"),
|
||||
@@ -51,6 +51,10 @@ COLUMN_MAPPINGS: list[tuple[str, str]] = [
|
||||
(r"^(protocol|proto)$", "protocol"),
|
||||
(r"^(domain|dns_?name|query_?name|queriedname)$", "domain"),
|
||||
(r"^(url|uri|request_?url)$", "url"),
|
||||
# MAC address
|
||||
(r"^(mac|mac_?address|physical_?address|mac_?addr|hw_?addr|ethernet)$", "mac_address"),
|
||||
# Connection state (netstat)
|
||||
(r"^(state|status|tcp_?state|conn_?state)$", "connection_state"),
|
||||
# Event info
|
||||
(r"^(event_?id|eventid|eid)$", "event_id"),
|
||||
(r"^(event_?type|eventtype|category|action)$", "event_type"),
|
||||
@@ -120,13 +124,27 @@ def detect_ioc_columns(
|
||||
"domain": "domain",
|
||||
}
|
||||
|
||||
# Canonical names that should NEVER be treated as IOCs even if values
|
||||
# match a pattern (e.g. process_name "svchost.exe" matching domain regex).
|
||||
_non_ioc_canonicals = frozenset({
|
||||
"process_name", "file_name", "file_path", "command_line",
|
||||
"parent_command_line", "description", "event_type", "registry_key",
|
||||
"registry_value", "severity", "os",
|
||||
"title", "netmask", "gateway", "connection_state",
|
||||
})
|
||||
|
||||
for col in columns:
|
||||
canonical = column_mapping.get(col, "")
|
||||
|
||||
# Skip columns whose canonical meaning is obviously not an IOC
|
||||
if canonical in _non_ioc_canonicals:
|
||||
continue
|
||||
|
||||
col_type = column_types.get(col)
|
||||
if col_type in ioc_type_map:
|
||||
ioc_columns[col] = ioc_type_map[col_type]
|
||||
|
||||
# Also check canonical name
|
||||
canonical = column_mapping.get(col, "")
|
||||
if canonical in ("src_ip", "dst_ip", "ip_address"):
|
||||
ioc_columns[col] = "ip"
|
||||
elif canonical == "hash_md5":
|
||||
@@ -139,6 +157,8 @@ def detect_ioc_columns(
|
||||
ioc_columns[col] = "domain"
|
||||
elif canonical == "url":
|
||||
ioc_columns[col] = "url"
|
||||
elif canonical == "hostname":
|
||||
ioc_columns[col] = "hostname"
|
||||
|
||||
return ioc_columns
|
||||
|
||||
|
||||
269
backend/app/services/playbook.py
Normal file
269
backend/app/services/playbook.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""Investigation Notebook & Playbook Engine for ThreatHunt.
|
||||
|
||||
Notebooks: Analyst-facing, cell-based documents (markdown + code cells)
|
||||
stored as JSON in the database. Each cell can contain free-form
|
||||
markdown notes *or* a structured query/command that the backend
|
||||
evaluates against datasets.
|
||||
|
||||
Playbooks: Pre-defined, step-by-step investigation workflows. Each step
|
||||
defines an action (query, analyze, enrich, tag) and expected
|
||||
outcomes. Analysts can run through playbooks interactively or
|
||||
trigger them automatically on new alerts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Notebook helpers ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class NotebookCell:
    """One notebook cell; mirrors the JSON cell shape stored in the database."""

    # Stable cell identifier (e.g. client-supplied or "cell-<index>").
    id: str
    # One of: markdown | query | code
    cell_type: str
    # Cell body: markdown text, query string, or code.
    source: str
    # Rendered/executed output, if the cell has been evaluated.
    output: Optional[str] = None
    # Free-form per-cell metadata.
    metadata: dict = field(default_factory=dict)
||||
|
||||
|
||||
def validate_notebook_cells(cells: list[dict]) -> list[dict]:
    """Normalise raw notebook cell dicts, filling defaults for missing keys.

    Each returned dict always carries ``id``, ``cell_type``, ``source``,
    ``output`` and ``metadata``; a missing id becomes ``cell-<index>``.
    """
    return [
        {
            "id": cell.get("id", f"cell-{index}"),
            "cell_type": cell.get("cell_type", "markdown"),
            "source": cell.get("source", ""),
            "output": cell.get("output"),
            "metadata": cell.get("metadata", {}),
        }
        for index, cell in enumerate(cells)
    ]
|
||||
|
||||
|
||||
# ── Built-in Playbook Templates ──────────────────────────────────────
|
||||
|
||||
|
||||
# Built-in playbook templates shipped with the application. Each entry has
# a unique "name" (used as the lookup key by get_playbook_template), a
# category/tags for filtering, and ordered "steps" whose "action" values
# are interpreted by the playbook execution layer.
BUILT_IN_PLAYBOOKS: list[dict[str, Any]] = [
    # IR workflow for a single suspicious process execution (T1059).
    {
        "name": "Suspicious Process Investigation",
        "description": "Step-by-step investigation of a potentially malicious process execution.",
        "category": "incident_response",
        "tags": ["process", "malware", "T1059"],
        "steps": [
            {
                "order": 1,
                "title": "Identify the suspicious process",
                "description": "Search for the process name/PID across all datasets. Note the command line, parent, and user context.",
                "action": "search",
                "action_config": {"fields": ["process_name", "command_line", "parent_process_name", "username"]},
                "expected_outcome": "Process details, parent chain, and execution context identified.",
            },
            {
                "order": 2,
                "title": "Build process tree",
                "description": "View the full parent→child process tree to understand the execution chain.",
                "action": "process_tree",
                "action_config": {},
                "expected_outcome": "Complete process lineage showing how the suspicious process was spawned.",
            },
            {
                "order": 3,
                "title": "Check network connections",
                "description": "Search for network events associated with the same host and timeframe.",
                "action": "search",
                "action_config": {"fields": ["src_ip", "dst_ip", "dst_port", "protocol"]},
                "expected_outcome": "Network connections revealing potential C2 or data exfiltration.",
            },
            {
                "order": 4,
                "title": "Run analyzers",
                "description": "Execute the suspicious_commands and entropy analyzers against the dataset.",
                "action": "analyze",
                "action_config": {"analyzers": ["suspicious_commands", "entropy"]},
                "expected_outcome": "Automated detection of known-bad patterns.",
            },
            {
                "order": 5,
                "title": "Map to MITRE ATT&CK",
                "description": "Check which MITRE techniques the process behavior maps to.",
                "action": "mitre_map",
                "action_config": {},
                "expected_outcome": "MITRE technique mappings for the suspicious activity.",
            },
            {
                "order": 6,
                "title": "Document findings & create case",
                "description": "Summarize investigation findings, annotate key evidence, and create a case if warranted.",
                "action": "create_case",
                "action_config": {},
                "expected_outcome": "Investigation documented with annotations and optionally escalated.",
            },
        ],
    },
    # Proactive hunt for lateral movement (T1021 remote services, T1047 WMI).
    {
        "name": "Lateral Movement Hunt",
        "description": "Systematic hunt for lateral movement indicators across the environment.",
        "category": "threat_hunting",
        "tags": ["lateral_movement", "T1021", "T1047"],
        "steps": [
            {
                "order": 1,
                "title": "Search for remote access tools",
                "description": "Look for PsExec, WMI, WinRM, RDP, and SSH usage across datasets.",
                "action": "search",
                "action_config": {"query": "psexec|wmic|winrm|rdp|ssh"},
                "expected_outcome": "Identify all remote access tool usage.",
            },
            {
                "order": 2,
                "title": "Analyze authentication events",
                "description": "Run the auth anomaly analyzer to find brute force, unusual logon types.",
                "action": "analyze",
                "action_config": {"analyzers": ["auth_anomaly"]},
                "expected_outcome": "Authentication anomalies detected.",
            },
            {
                "order": 3,
                "title": "Check network anomalies",
                "description": "Run network anomaly analyzer for beaconing and suspicious connections.",
                "action": "analyze",
                "action_config": {"analyzers": ["network_anomaly"]},
                "expected_outcome": "Beaconing or unusual network patterns identified.",
            },
            {
                "order": 4,
                "title": "Build knowledge graph",
                "description": "Visualize entity relationships to identify pivot paths.",
                "action": "knowledge_graph",
                "action_config": {},
                "expected_outcome": "Entity relationship graph showing lateral movement paths.",
            },
            {
                "order": 5,
                "title": "Document and escalate",
                "description": "Create annotations for key findings and escalate to case if needed.",
                "action": "create_case",
                "action_config": {"tags": ["lateral_movement"]},
                "expected_outcome": "Findings documented and case created.",
            },
        ],
    },
    # IR workflow for suspected data exfiltration (T1048/T1567).
    {
        "name": "Data Exfiltration Check",
        "description": "Investigate potential data exfiltration activity.",
        "category": "incident_response",
        "tags": ["exfiltration", "T1048", "T1567"],
        "steps": [
            {
                "order": 1,
                "title": "Identify large transfers",
                "description": "Search for network events with unusually high byte counts.",
                "action": "analyze",
                "action_config": {"analyzers": ["network_anomaly"], "config": {"large_transfer_threshold": 5000000}},
                "expected_outcome": "Large data transfers identified.",
            },
            {
                "order": 2,
                "title": "Check DNS anomalies",
                "description": "Look for DNS tunneling or unusual DNS query patterns.",
                "action": "search",
                "action_config": {"fields": ["dns_query", "query_length"]},
                "expected_outcome": "Suspicious DNS activity identified.",
            },
            {
                "order": 3,
                "title": "Timeline analysis",
                "description": "Examine the timeline for data staging and exfiltration windows.",
                "action": "timeline",
                "action_config": {},
                "expected_outcome": "Time windows of suspicious activity identified.",
            },
            {
                "order": 4,
                "title": "Correlate with process activity",
                "description": "Match network exfiltration with process execution times.",
                "action": "search",
                "action_config": {"fields": ["process_name", "dst_ip", "bytes_sent"]},
                "expected_outcome": "Process responsible for data transfer identified.",
            },
            {
                "order": 5,
                "title": "MITRE mapping & documentation",
                "description": "Map findings to MITRE exfiltration techniques and document.",
                "action": "mitre_map",
                "action_config": {},
                "expected_outcome": "Complete exfiltration investigation documented.",
            },
        ],
    },
    # Rapid IR triage for suspected ransomware (T1486 encryption, T1490 inhibit recovery).
    {
        "name": "Ransomware Triage",
        "description": "Rapid triage of potential ransomware activity.",
        "category": "incident_response",
        "tags": ["ransomware", "T1486", "T1490"],
        "steps": [
            {
                "order": 1,
                "title": "Search for ransomware indicators",
                "description": "Look for shadow copy deletion, boot config changes, encryption activity.",
                "action": "search",
                "action_config": {"query": "vssadmin|bcdedit|cipher|.encrypted|.locked|ransom"},
                "expected_outcome": "Ransomware indicators identified.",
            },
            {
                "order": 2,
                "title": "Run all analyzers",
                "description": "Execute all analyzers to get comprehensive threat picture.",
                "action": "analyze",
                "action_config": {},
                "expected_outcome": "Full automated analysis of ransomware indicators.",
            },
            {
                "order": 3,
                "title": "Check persistence mechanisms",
                "description": "Look for persistence that may indicate pre-ransomware staging.",
                "action": "analyze",
                "action_config": {"analyzers": ["persistence"]},
                "expected_outcome": "Persistence mechanisms identified.",
            },
            {
                "order": 4,
                "title": "LLM deep analysis",
                "description": "Run deep LLM analysis for comprehensive ransomware assessment.",
                "action": "llm_analyze",
                "action_config": {"mode": "deep", "focus": "ransomware"},
                "expected_outcome": "AI-powered ransomware analysis with recommendations.",
            },
            {
                "order": 5,
                "title": "Create critical case",
                "description": "Immediately create a critical-severity case for the ransomware incident.",
                "action": "create_case",
                "action_config": {"severity": "critical", "tags": ["ransomware"]},
                "expected_outcome": "Critical case created for incident response.",
            },
        ],
    },
]
|
||||
|
||||
|
||||
def get_builtin_playbooks() -> list[dict]:
    """Return all built-in playbook templates.

    Returns a shallow copy of the registry list so callers cannot
    accidentally reorder or remove templates from the module-level
    ``BUILT_IN_PLAYBOOKS``. The template dicts themselves are still
    shared — treat them as read-only.
    """
    return list(BUILT_IN_PLAYBOOKS)
|
||||
|
||||
|
||||
def get_playbook_template(name: str) -> dict | None:
    """Look up a built-in playbook template by its exact name.

    Returns the matching template dict, or None when no template has
    that name.
    """
    return next(
        (template for template in BUILT_IN_PLAYBOOKS if template["name"] == name),
        None,
    )
|
||||
447
backend/app/services/process_tree.py
Normal file
447
backend/app/services/process_tree.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""Process tree and storyline graph builder.
|
||||
|
||||
Extracts parent→child process relationships from dataset rows and builds
|
||||
hierarchical trees. Also builds attack-storyline graphs connecting events
|
||||
by host → process → network activity → file activity chains.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any, Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
_JUNK = frozenset({"", "N/A", "n/a", "-", "—", "null", "None", "none", "unknown"})
|
||||
|
||||
|
||||
def _clean(val: Any) -> str | None:
|
||||
"""Return cleaned string or None for junk values."""
|
||||
if val is None:
|
||||
return None
|
||||
s = str(val).strip()
|
||||
return None if s in _JUNK else s
|
||||
|
||||
|
||||
# ── Process Tree ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ProcessNode:
    """One process occurrence, linked into a per-host parent/child tree."""

    __slots__ = (
        "pid", "ppid", "name", "command_line", "username", "hostname",
        "timestamp", "dataset_name", "row_index", "children", "extra",
    )

    def __init__(self, **kw: Any):
        get = kw.get
        # Identity and execution context; string fields default to "" so
        # serialisation never emits None for them.
        self.pid: str = get("pid", "")
        self.ppid: str = get("ppid", "")
        self.name: str = get("name", "")
        self.command_line: str = get("command_line", "")
        self.username: str = get("username", "")
        self.hostname: str = get("hostname", "")
        self.timestamp: str = get("timestamp", "")
        # Provenance back to the source dataset row.
        self.dataset_name: str = get("dataset_name", "")
        self.row_index: int = get("row_index", -1)
        # Child nodes are attached later by the tree builder.
        self.children: list["ProcessNode"] = []
        # Any remaining non-junk row fields, stringified.
        self.extra: dict = get("extra", {})

    def to_dict(self) -> dict:
        """Serialise this node and its whole subtree into JSON-friendly dicts."""
        return {
            "pid": self.pid,
            "ppid": self.ppid,
            "name": self.name,
            "command_line": self.command_line,
            "username": self.username,
            "hostname": self.hostname,
            "timestamp": self.timestamp,
            "dataset_name": self.dataset_name,
            "row_index": self.row_index,
            "children": [child.to_dict() for child in self.children],
            "extra": self.extra,
        }
|
||||
|
||||
|
||||
async def build_process_tree(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    hostname_filter: str | None = None,
) -> list[dict]:
    """Build process trees from dataset rows.

    Rows are linked parent→child by (hostname, ppid) → (hostname, pid);
    rows without a usable pid are skipped, and for duplicate (hostname,
    pid) keys only the first row encountered is kept.

    Returns a list of root-level process nodes (forest), serialised via
    ``ProcessNode.to_dict``.

    NOTE(review): nodes caught in a ppid cycle (e.g. A→B→A) each find a
    parent and thus never become roots, so they are silently dropped from
    the returned forest — confirm this is acceptable.
    """
    rows = await _fetch_rows(db, dataset_id=dataset_id, hunt_id=hunt_id)
    if not rows:
        return []

    # Group processes by (hostname, pid) → node
    nodes_by_key: dict[tuple[str, str], ProcessNode] = {}
    nodes_list: list[ProcessNode] = []

    for row_obj in rows:
        # Prefer normalized field names; fall back to the raw row payload.
        data = row_obj.normalized_data or row_obj.data
        pid = _clean(data.get("pid"))
        if not pid:
            # Not a process event (or pid is a junk placeholder) — skip.
            continue

        ppid = _clean(data.get("ppid")) or ""
        hostname = _clean(data.get("hostname")) or "unknown"

        # Case-insensitive host filter, applied after junk-cleaning.
        if hostname_filter and hostname.lower() != hostname_filter.lower():
            continue

        node = ProcessNode(
            pid=pid,
            ppid=ppid,
            name=_clean(data.get("process_name")) or _clean(data.get("name")) or "",
            command_line=_clean(data.get("command_line")) or "",
            username=_clean(data.get("username")) or "",
            hostname=hostname,
            timestamp=_clean(data.get("timestamp")) or "",
            dataset_name=row_obj.dataset.name if row_obj.dataset else "",
            row_index=row_obj.row_index,
            # Carry along every other non-junk field, stringified.
            extra={
                k: str(v)
                for k, v in data.items()
                if k not in {"pid", "ppid", "process_name", "name", "command_line",
                             "username", "hostname", "timestamp"}
                and v is not None and str(v).strip() not in _JUNK
            },
        )

        key = (hostname, pid)
        # Keep the first occurrence (earlier in data) or overwrite if deeper info
        if key not in nodes_by_key:
            nodes_by_key[key] = node
            nodes_list.append(node)

    # Link parent → child
    roots: list[ProcessNode] = []
    for node in nodes_list:
        parent_key = (node.hostname, node.ppid)
        parent = nodes_by_key.get(parent_key)
        # Guard against self-parenting (pid == ppid) producing a cycle.
        if parent and parent is not node:
            parent.children.append(node)
        else:
            roots.append(node)

    return [r.to_dict() for r in roots]
|
||||
|
||||
|
||||
# ── Storyline / Attack Graph ─────────────────────────────────────────
|
||||
|
||||
|
||||
async def build_storyline(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    hostname_filter: str | None = None,
) -> dict:
    """Build a CrowdStrike-style storyline graph.

    Nodes represent events (process start, network connection, file write,
    etc.); edges represent causal ("spawned", via pid/ppid) or temporal
    relationships between consecutive events on the same host.

    Returns a Cytoscape-compatible elements dict
    ``{nodes: [...], edges: [...], summary: {...}}``.
    """
    rows = await _fetch_rows(db, dataset_id=dataset_id, hunt_id=hunt_id)
    if not rows:
        return {"nodes": [], "edges": [], "summary": {}}

    nodes: list[dict] = []
    edges: list[dict] = []
    seen_ids: set[str] = set()
    # Edge ids already emitted. Set membership replaces the original
    # linear scan over all edges per insert (which was O(E^2) overall).
    edge_ids: set[str] = set()
    host_events: dict[str, list[dict]] = defaultdict(list)

    def _add_edge(edge_id: str, source: str, target: str, relationship: str) -> None:
        # O(1) duplicate guard; preserves first-wins insertion order.
        if edge_id in edge_ids:
            return
        edge_ids.add(edge_id)
        edges.append({
            "data": {
                "id": edge_id,
                "source": source,
                "target": target,
                "relationship": relationship,
            }
        })

    for row_obj in rows:
        data = row_obj.normalized_data or row_obj.data
        hostname = _clean(data.get("hostname")) or "unknown"

        if hostname_filter and hostname.lower() != hostname_filter.lower():
            continue

        event_type = _classify_event(data)
        # Node identity is the source row, so re-processed rows dedupe cleanly.
        node_id = f"{row_obj.dataset_id}_{row_obj.row_index}"
        if node_id in seen_ids:
            continue
        seen_ids.add(node_id)

        label = _build_label(data, event_type)
        severity = _estimate_severity(data, event_type)

        node = {
            "data": {
                "id": node_id,
                "label": label,
                "event_type": event_type,
                "hostname": hostname,
                "timestamp": _clean(data.get("timestamp")) or "",
                "pid": _clean(data.get("pid")) or "",
                "ppid": _clean(data.get("ppid")) or "",
                "process_name": _clean(data.get("process_name")) or "",
                "command_line": _clean(data.get("command_line")) or "",
                "username": _clean(data.get("username")) or "",
                "src_ip": _clean(data.get("src_ip")) or "",
                "dst_ip": _clean(data.get("dst_ip")) or "",
                "dst_port": _clean(data.get("dst_port")) or "",
                "file_path": _clean(data.get("file_path")) or "",
                "severity": severity,
                "dataset_id": row_obj.dataset_id,
                "row_index": row_obj.row_index,
            },
        }
        nodes.append(node)
        host_events[hostname].append(node["data"])

    # Build edges: parent→child (by pid/ppid) and temporal sequence per host
    pid_lookup: dict[tuple[str, str], str] = {}  # (host, pid) → node_id
    for node in nodes:
        d = node["data"]
        if d["pid"]:
            pid_lookup[(d["hostname"], d["pid"])] = d["id"]

    for node in nodes:
        d = node["data"]
        if d["ppid"]:
            parent_id = pid_lookup.get((d["hostname"], d["ppid"]))
            # Skip self-loops from rows where pid == ppid.
            if parent_id and parent_id != d["id"]:
                _add_edge(f"e_{parent_id}_{d['id']}", parent_id, d["id"], "spawned")

    # Temporal edges within each host (sorted by timestamp string)
    for hostname, events in host_events.items():
        sorted_events = sorted(events, key=lambda e: e.get("timestamp", ""))
        for src, tgt in zip(sorted_events, sorted_events[1:]):
            _add_edge(f"t_{src['id']}_{tgt['id']}", src["id"], tgt["id"], "temporal")

    # Summary stats
    type_counts: dict[str, int] = defaultdict(int)
    for n in nodes:
        type_counts[n["data"]["event_type"]] += 1

    summary = {
        "total_events": len(nodes),
        "total_edges": len(edges),
        "hosts": list(host_events.keys()),
        "event_types": dict(type_counts),
    }

    return {"nodes": nodes, "edges": edges, "summary": summary}
|
||||
|
||||
|
||||
# ── Risk scoring for dashboard ────────────────────────────────────────
|
||||
|
||||
async def compute_risk_scores(
    db: AsyncSession,
    hunt_id: str | None = None,
) -> dict:
    """Compute per-host risk scores from anomaly signals in datasets.

    Raw scores are accumulated from suspicious command-line patterns and
    external-destination connections, then normalised to 0-100 relative
    to the riskiest host.

    Returns ``{hosts: [{hostname, score, signals, ...}], overall_score,
    total_events, severity_breakdown}`` with hosts sorted by score desc.
    """
    rows = await _fetch_rows(db, hunt_id=hunt_id)
    if not rows:
        return {"hosts": [], "overall_score": 0, "total_events": 0,
                "severity_breakdown": {}}

    host_signals: dict[str, dict] = defaultdict(
        lambda: {"hostname": "", "score": 0, "signals": [],
                 "event_count": 0, "process_count": 0,
                 "network_count": 0, "file_count": 0}
    )

    severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0}

    # (substring, score contribution, signal label). Loop-invariant, so it is
    # built once here rather than rebuilt on every row iteration as before.
    sus_patterns = (
        ("powershell -enc", 8, "Encoded PowerShell"),
        ("invoke-expression", 7, "PowerShell IEX"),
        ("invoke-webrequest", 6, "PowerShell WebRequest"),
        ("certutil -urlcache", 8, "Certutil download"),
        ("bitsadmin /transfer", 7, "BITS transfer"),
        ("regsvr32 /s /n /u", 8, "Regsvr32 squiblydoo"),
        ("mshta ", 7, "MSHTA execution"),
        ("wmic process", 6, "WMIC process enum"),
        ("net user", 5, "User enumeration"),
        ("whoami", 4, "Whoami recon"),
        ("mimikatz", 10, "Mimikatz detected"),
        ("procdump", 7, "Process dumping"),
        ("psexec", 7, "PsExec lateral movement"),
    )

    for row_obj in rows:
        data = row_obj.normalized_data or row_obj.data
        hostname = _clean(data.get("hostname")) or "unknown"
        entry = host_signals[hostname]
        entry["hostname"] = hostname
        entry["event_count"] += 1

        event_type = _classify_event(data)
        severity = _estimate_severity(data, event_type)
        severity_counts[severity] = severity_counts.get(severity, 0) + 1

        # Count event types
        if event_type == "process":
            entry["process_count"] += 1
        elif event_type == "network":
            entry["network_count"] += 1
        elif event_type == "file":
            entry["file_count"] += 1

        # Risk signals from command line / process name (case-insensitive).
        cmd = (_clean(data.get("command_line")) or "").lower()
        proc = (_clean(data.get("process_name")) or "").lower()

        for pattern, score_add, signal_name in sus_patterns:
            if pattern in cmd or pattern in proc:
                entry["score"] += score_add
                if signal_name not in entry["signals"]:
                    entry["signals"].append(signal_name)

        # External connections score.
        # NOTE(review): this prefix heuristic treats every 172.x address as
        # internal (RFC1918 is only 172.16-172.31) and ignores loopback /
        # link-local ranges — confirm this imprecision is acceptable.
        dst_ip = _clean(data.get("dst_ip")) or ""
        if dst_ip and not dst_ip.startswith(("10.", "192.168.", "172.")):
            entry["score"] += 1
            if "External connections" not in entry["signals"]:
                entry["signals"].append("External connections")

    # Normalize scores (0-100) relative to the riskiest host.
    max_score = max((h["score"] for h in host_signals.values()), default=1)
    if max_score > 0:
        for entry in host_signals.values():
            entry["score"] = min(round((entry["score"] / max_score) * 100), 100)

    hosts = sorted(host_signals.values(), key=lambda h: h["score"], reverse=True)
    overall = round(sum(h["score"] for h in hosts) / max(len(hosts), 1))

    return {
        "hosts": hosts,
        "overall_score": overall,
        "total_events": sum(h["event_count"] for h in hosts),
        "severity_breakdown": severity_counts,
    }
|
||||
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _fetch_rows(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    limit: int = 50_000,
) -> Sequence[DatasetRow]:
    """Fetch dataset rows (with their dataset eagerly loaded), optionally
    filtered by dataset or hunt; ordered by row index and capped at *limit*."""
    stmt = (
        select(DatasetRow)
        .join(Dataset)
        .options(selectinload(DatasetRow.dataset))
    )

    if dataset_id:
        stmt = stmt.where(DatasetRow.dataset_id == dataset_id)
    elif hunt_id:
        stmt = stmt.where(Dataset.hunt_id == hunt_id)
    # With no filter, the row limit below is the only guard against OOM.

    result = await db.execute(
        stmt.order_by(DatasetRow.row_index).limit(limit)
    )
    return result.scalars().all()
|
||||
|
||||
|
||||
def _classify_event(data: dict) -> str:
    """Classify a row as process / network / file / registry / other.

    Process-context rows (pid or process_name present) are sub-classified
    by their network / file details; junk placeholder values are ignored.
    """
    if _clean(data.get("pid")) or _clean(data.get("process_name")):
        if _clean(data.get("dst_ip")) or _clean(data.get("dst_port")):
            return "network"
        return "file" if _clean(data.get("file_path")) else "process"

    if _clean(data.get("dst_ip")) or _clean(data.get("src_ip")):
        return "network"
    if _clean(data.get("file_path")):
        return "file"
    return "registry" if _clean(data.get("registry_key")) else "other"
|
||||
|
||||
|
||||
def _build_label(data: dict, event_type: str) -> str:
    """Build a concise node label for storyline display.

    Falls back to the event type name when no identifying fields survive
    junk-cleaning.
    """
    proc_name = _clean(data.get("process_name")) or ""
    proc_pid = _clean(data.get("pid")) or ""
    dest_ip = _clean(data.get("dst_ip")) or ""
    dest_port = _clean(data.get("dst_port")) or ""
    path = _clean(data.get("file_path")) or ""

    if event_type == "process":
        if proc_pid:
            return f"{proc_name} (PID {proc_pid})"
        return proc_name or "process"

    if event_type == "network":
        target = f"{dest_ip}:{dest_port}" if dest_ip and dest_port else dest_ip or dest_port
        if proc_name:
            return f"{proc_name} → {target}"
        return target or "network"

    if event_type == "file":
        # Last path component, tolerating both Windows and POSIX separators.
        fname = path.split("\\")[-1].split("/")[-1] if path else ""
        if proc_name:
            return f"{proc_name} → {fname}"
        return fname or "file"

    if event_type == "registry":
        return _clean(data.get("registry_key")) or "registry"

    return proc_name or "event"
|
||||
|
||||
|
||||
def _estimate_severity(data: dict, event_type: str) -> str:
    """Rough keyword-based severity estimate from the command line and
    process name; returns critical / high / medium / low / info."""
    cmd = (_clean(data.get("command_line")) or "").lower()
    proc = (_clean(data.get("process_name")) or "").lower()

    # Offensive tooling names match in *either* the command line or the
    # process name; all lower tiers inspect the command line only.
    for kw in ("mimikatz", "cobalt", "meterpreter", "empire", "bloodhound"):
        if kw in cmd or kw in proc:
            return "critical"

    tiers = (
        # LOLBin abuse / credential dumping / lateral movement tooling.
        ("high", ("powershell -enc", "certutil -urlcache", "regsvr32", "mshta",
                  "bitsadmin", "psexec", "procdump")),
        # Scriptable admin tooling and persistence primitives.
        ("medium", ("invoke-", "wmic", "net user", "net group", "schtasks",
                    "reg add", "sc create")),
        # Basic host reconnaissance.
        ("low", ("whoami", "ipconfig", "systeminfo", "tasklist", "netstat")),
    )
    for severity, keywords in tiers:
        if any(kw in cmd for kw in keywords):
            return severity

    return "info"
|
||||
254
backend/app/services/timeline.py
Normal file
254
backend/app/services/timeline.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""Timeline and field-statistics service.
|
||||
|
||||
Provides temporal histogram bins and per-field distribution stats
|
||||
for dataset rows — used by the TimelineScrubber and QueryBar components.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import Counter, defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Timeline bins ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def build_timeline_bins(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    bins: int = 60,
) -> dict:
    """Create histogram bins of events over time.

    Rows whose timestamp is missing or unparseable are skipped. The time
    span between the earliest and latest event is divided into *bins*
    equal-width buckets.

    Returns ``{bins: [{start, end, count, events_by_type}], total, range}``.
    """
    rows = await _fetch_rows(db, dataset_id=dataset_id, hunt_id=hunt_id)
    if not rows:
        return {"bins": [], "total": 0, "range": None}

    # Extract parseable timestamps.
    events: list[dict] = []
    for r in rows:
        data = r.normalized_data or r.data
        ts_str = data.get("timestamp", "")
        if not ts_str:
            continue
        ts = _parse_ts(str(ts_str))
        if ts:
            events.append({
                "timestamp": ts,
                "event_type": _classify_type(data),
                "hostname": data.get("hostname", ""),
            })

    if not events:
        # Keep the raw row count so callers can tell "no rows" from
        # "rows without usable timestamps".
        return {"bins": [], "total": len(rows), "range": None}

    events.sort(key=lambda e: e["timestamp"])
    ts_min = events[0]["timestamp"]
    ts_max = events[-1]["timestamp"]

    if ts_min == ts_max:
        # Degenerate range: everything lands in a single zero-width bin.
        return {
            "bins": [{"start": ts_min.isoformat(), "end": ts_max.isoformat(),
                      "count": len(events), "events_by_type": {}}],
            "total": len(events),
            "range": {"start": ts_min.isoformat(), "end": ts_max.isoformat()},
        }

    delta = (ts_max - ts_min) / bins

    # Single pass: compute each event's bin index directly, instead of
    # re-scanning the whole event list once per bin (was O(bins * events)).
    # Clamping to the last bin also guarantees the ts_max event is counted
    # exactly once, regardless of float rounding at bin boundaries.
    counts = [0] * bins
    type_counters: list[Counter] = [Counter() for _ in range(bins)]
    for e in events:
        idx = min(int((e["timestamp"] - ts_min) / delta), bins - 1)
        counts[idx] += 1
        type_counters[idx][e["event_type"]] += 1

    result_bins = [
        {
            "start": (ts_min + delta * i).isoformat(),
            "end": (ts_min + delta * (i + 1)).isoformat(),
            "count": counts[i],
            "events_by_type": dict(type_counters[i]),
        }
        for i in range(bins)
    ]

    return {
        "bins": result_bins,
        "total": len(events),
        "range": {"start": ts_min.isoformat(), "end": ts_max.isoformat()},
    }
|
||||
|
||||
|
||||
# ── Field stats ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def compute_field_stats(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    fields: list[str] | None = None,
    top_n: int = 20,
) -> dict:
    """Compute per-field value distributions.

    When *fields* is not given, analyzes the first 30 discovered fields.
    Junk placeholder values ("", "N/A", "-", "None") are excluded from
    the counts.

    Returns ``{fields: {field_name: {total, unique, top: [{value, count}]}},
    total_rows, available_fields}``.
    """
    rows = await _fetch_rows(db, dataset_id=dataset_id, hunt_id=hunt_id)
    if not rows:
        return {"fields": {}, "total_rows": 0}

    # Resolve each row's payload once (normalized preferred over raw),
    # instead of re-resolving it per field per row.
    payloads = [r.normalized_data or r.data for r in rows]

    # Discover fields across a sample of rows — not just the first row,
    # which may lack columns that appear later — keeping first-seen order.
    all_fields: list[str] = []
    seen: set[str] = set()
    for data in payloads[:200]:
        for key in data:
            if key not in seen:
                seen.add(key)
                all_fields.append(key)

    target_fields = fields if fields else all_fields[:30]

    junk = ("", "N/A", "n/a", "-", "None")
    stats: dict[str, dict] = {}
    for field in target_fields:
        counter: Counter = Counter()
        for data in payloads:
            v = data.get(field)
            if v is not None and str(v).strip() not in junk:
                counter[str(v)] += 1
        stats[field] = {
            "total": sum(counter.values()),
            "unique": len(counter),
            "top": [{"value": v, "count": c} for v, c in counter.most_common(top_n)],
        }

    return {
        "fields": stats,
        "total_rows": len(rows),
        "available_fields": all_fields,
    }
|
||||
|
||||
|
||||
# ── Row search with filters ──────────────────────────────────────────
|
||||
|
||||
|
||||
async def search_rows(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    query: str = "",
    filters: dict[str, str] | None = None,
    time_start: str | None = None,
    time_end: str | None = None,
    limit: int = 500,
    offset: int = 0,
) -> dict:
    """Search/filter dataset rows.

    Supports:
      - Free-text search across all fields (case-insensitive substring)
      - Field-specific filters ``{field: value}`` (substring match)
      - Time range filters (ISO or common US formats)
    """
    rows = await _fetch_rows(db, dataset_id=dataset_id, hunt_id=hunt_id, limit=50000)
    if not rows:
        return {"rows": [], "total": 0, "offset": offset, "limit": limit}

    window_start = _parse_ts(time_start) if time_start else None
    window_end = _parse_ts(time_end) if time_end else None
    needle = query.lower() if query else ""
    active_filters = [(name, value.lower()) for name, value in (filters or {}).items()]

    matched: list[dict] = []
    for row in rows:
        data = row.normalized_data or row.data

        # Time-range check. NOTE(review): rows whose timestamp is missing
        # or unparseable are *included* when a window is set — confirm
        # this is the intended behaviour.
        if window_start or window_end:
            ts = _parse_ts(str(data.get("timestamp", "")))
            if ts is not None:
                if window_start and ts < window_start:
                    continue
                if window_end and ts > window_end:
                    continue

        # Every field filter must match as a case-insensitive substring.
        if any(want not in str(data.get(name, "")).lower()
               for name, want in active_filters):
            continue

        # Free-text: any field value containing the query.
        if needle and not any(needle in str(v).lower() for v in data.values()):
            continue

        matched.append(data)

    return {
        "rows": matched[offset:offset + limit],
        "total": len(matched),
        "offset": offset,
        "limit": limit,
    }
|
||||
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _fetch_rows(
    db: AsyncSession,
    dataset_id: str | None = None,
    hunt_id: str | None = None,
    limit: int = 50_000,
) -> Sequence[DatasetRow]:
    """Load dataset rows, filtered by dataset or hunt, ordered by row index
    and capped at *limit* to bound memory use."""
    stmt = select(DatasetRow).join(Dataset)
    if dataset_id:
        stmt = stmt.where(DatasetRow.dataset_id == dataset_id)
    elif hunt_id:
        stmt = stmt.where(Dataset.hunt_id == hunt_id)

    result = await db.execute(
        stmt.order_by(DatasetRow.row_index).limit(limit)
    )
    return result.scalars().all()
|
||||
|
||||
|
||||
def _parse_ts(ts_str: str | None) -> datetime | None:
|
||||
"""Best-effort timestamp parsing."""
|
||||
if not ts_str:
|
||||
return None
|
||||
for fmt in (
|
||||
"%Y-%m-%dT%H:%M:%S.%fZ",
|
||||
"%Y-%m-%dT%H:%M:%SZ",
|
||||
"%Y-%m-%dT%H:%M:%S.%f",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%d %H:%M:%S.%f",
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%m/%d/%Y %H:%M:%S",
|
||||
"%m/%d/%Y %I:%M:%S %p",
|
||||
):
|
||||
try:
|
||||
return datetime.strptime(ts_str.strip(), fmt)
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def _classify_type(data: dict) -> str:
|
||||
if data.get("pid") or data.get("process_name"):
|
||||
if data.get("dst_ip") or data.get("dst_port"):
|
||||
return "network"
|
||||
return "process"
|
||||
if data.get("dst_ip") or data.get("src_ip"):
|
||||
return "network"
|
||||
if data.get("file_path"):
|
||||
return "file"
|
||||
return "other"
|
||||
Reference in New Issue
Block a user