chore: checkpoint all local changes

This commit is contained in:
2026-02-23 14:36:33 -05:00
76 changed files with 34486 additions and 738 deletions

View File

@@ -1,371 +1,296 @@
"""Analysis API routes - triage, host profiles, reports, IOC extraction,
host grouping, anomaly detection, data query (SSE), and job management."""
from __future__ import annotations
"""API routes for process trees, storyline graphs, risk scoring, LLM analysis, timeline, and field stats."""
import logging
from typing import Optional
from typing import Any, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy import select
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import HostProfile, HuntReport, TriageResult
from app.db.repositories.datasets import DatasetRepository
from app.services.process_tree import (
build_process_tree,
build_storyline,
compute_risk_scores,
_fetch_rows,
)
from app.services.llm_analysis import (
AnalysisRequest,
AnalysisResult,
run_llm_analysis,
)
from app.services.timeline import (
build_timeline_bins,
compute_field_stats,
search_rows,
)
from app.services.mitre import (
map_to_attack,
build_knowledge_graph,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/analysis", tags=["analysis"])
# --- Response models ---
class TriageResultResponse(BaseModel):
id: str
dataset_id: str
row_start: int
row_end: int
risk_score: float
verdict: str
findings: list | None = None
suspicious_indicators: list | None = None
mitre_techniques: list | None = None
model_used: str | None = None
node_used: str | None = None
class Config:
from_attributes = True
# ── Response models ───────────────────────────────────────────────────
class HostProfileResponse(BaseModel):
id: str
hunt_id: str
class ProcessTreeResponse(BaseModel):
trees: list[dict] = Field(default_factory=list)
total_processes: int = 0
class StorylineResponse(BaseModel):
nodes: list[dict] = Field(default_factory=list)
edges: list[dict] = Field(default_factory=list)
summary: dict = Field(default_factory=dict)
class RiskHostEntry(BaseModel):
hostname: str
fqdn: str | None = None
risk_score: float
risk_level: str
artifact_summary: dict | None = None
timeline_summary: str | None = None
suspicious_findings: list | None = None
mitre_techniques: list | None = None
llm_analysis: str | None = None
model_used: str | None = None
class Config:
from_attributes = True
score: int = 0
signals: list[str] = Field(default_factory=list)
event_count: int = 0
process_count: int = 0
network_count: int = 0
file_count: int = 0
class HuntReportResponse(BaseModel):
id: str
hunt_id: str
status: str
exec_summary: str | None = None
full_report: str | None = None
findings: list | None = None
recommendations: list | None = None
mitre_mapping: dict | None = None
ioc_table: list | None = None
host_risk_summary: list | None = None
models_used: list | None = None
generation_time_ms: int | None = None
class Config:
from_attributes = True
class RiskSummaryResponse(BaseModel):
hosts: list[RiskHostEntry] = Field(default_factory=list)
overall_score: int = 0
total_events: int = 0
severity_breakdown: dict[str, int] = Field(default_factory=dict)
class QueryRequest(BaseModel):
question: str
mode: str = "quick" # quick or deep
# ── Routes ────────────────────────────────────────────────────────────
# --- Triage endpoints ---
@router.get("/triage/{dataset_id}", response_model=list[TriageResultResponse])
async def get_triage_results(
dataset_id: str,
min_risk: float = Query(0.0, ge=0.0, le=10.0),
@router.get(
"/process-tree",
response_model=ProcessTreeResponse,
summary="Build process tree from dataset rows",
description=(
"Extracts parent→child process relationships from dataset rows "
"and returns a hierarchical forest of process nodes."
),
)
async def get_process_tree(
dataset_id: str | None = Query(None, description="Dataset ID"),
hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
hostname: str | None = Query(None, description="Filter by hostname"),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(TriageResult)
.where(TriageResult.dataset_id == dataset_id)
.where(TriageResult.risk_score >= min_risk)
.order_by(TriageResult.risk_score.desc())
"""Return process tree(s) for a dataset or hunt."""
if not dataset_id and not hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
trees = await build_process_tree(
db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
)
return result.scalars().all()
# Count total processes recursively
def _count(node: dict) -> int:
return 1 + sum(_count(c) for c in node.get("children", []))
total = sum(_count(t) for t in trees)
return ProcessTreeResponse(trees=trees, total_processes=total)
@router.post("/triage/{dataset_id}")
async def trigger_triage(
dataset_id: str,
background_tasks: BackgroundTasks,
):
async def _run():
from app.services.triage import triage_dataset
await triage_dataset(dataset_id)
background_tasks.add_task(_run)
return {"status": "triage_started", "dataset_id": dataset_id}
# --- Host profile endpoints ---
@router.get("/profiles/{hunt_id}", response_model=list[HostProfileResponse])
async def get_host_profiles(
hunt_id: str,
min_risk: float = Query(0.0, ge=0.0, le=10.0),
@router.get(
"/storyline",
response_model=StorylineResponse,
summary="Build CrowdStrike-style storyline attack graph",
description=(
"Creates a Cytoscape-compatible graph of events connected by "
"process lineage (spawned) and temporal sequence within each host."
),
)
async def get_storyline(
dataset_id: str | None = Query(None, description="Dataset ID"),
hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
hostname: str | None = Query(None, description="Filter by hostname"),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(HostProfile)
.where(HostProfile.hunt_id == hunt_id)
.where(HostProfile.risk_score >= min_risk)
.order_by(HostProfile.risk_score.desc())
"""Return a storyline graph for a dataset or hunt."""
if not dataset_id and not hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
result = await build_storyline(
db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
)
return result.scalars().all()
return StorylineResponse(**result)
@router.post("/profiles/{hunt_id}")
async def trigger_all_profiles(
hunt_id: str,
background_tasks: BackgroundTasks,
):
async def _run():
from app.services.host_profiler import profile_all_hosts
await profile_all_hosts(hunt_id)
background_tasks.add_task(_run)
return {"status": "profiling_started", "hunt_id": hunt_id}
@router.post("/profiles/{hunt_id}/{hostname}")
async def trigger_single_profile(
hunt_id: str,
hostname: str,
background_tasks: BackgroundTasks,
):
async def _run():
from app.services.host_profiler import profile_host
await profile_host(hunt_id, hostname)
background_tasks.add_task(_run)
return {"status": "profiling_started", "hunt_id": hunt_id, "hostname": hostname}
# --- Report endpoints ---
@router.get("/reports/{hunt_id}", response_model=list[HuntReportResponse])
async def list_reports(
hunt_id: str,
@router.get(
"/risk-summary",
response_model=RiskSummaryResponse,
summary="Compute risk scores per host",
description=(
"Analyzes dataset rows for suspicious patterns (encoded PowerShell, "
"credential dumping, lateral movement) and produces per-host risk scores."
),
)
async def get_risk_summary(
hunt_id: str | None = Query(None, description="Hunt ID"),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(HuntReport)
.where(HuntReport.hunt_id == hunt_id)
.order_by(HuntReport.created_at.desc())
"""Return risk scores for all hosts in a hunt."""
result = await compute_risk_scores(db, hunt_id=hunt_id)
return RiskSummaryResponse(**result)
# ── LLM Analysis ─────────────────────────────────────────────────────
@router.post(
"/llm-analyze",
response_model=AnalysisResult,
summary="Run LLM-powered threat analysis on dataset",
description=(
"Loads dataset rows server-side, builds a summary, and sends to "
"Wile (deep analysis) or Roadrunner (quick) for comprehensive "
"threat analysis. Returns structured findings, IOCs, MITRE techniques."
),
)
async def llm_analyze(
request: AnalysisRequest,
db: AsyncSession = Depends(get_db),
):
"""Run LLM analysis on a dataset or hunt."""
if not request.dataset_id and not request.hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
# Load rows
rows_objs = await _fetch_rows(
db,
dataset_id=request.dataset_id,
hunt_id=request.hunt_id,
limit=2000,
)
return result.scalars().all()
if not rows_objs:
raise HTTPException(status_code=404, detail="No rows found for analysis")
# Extract data dicts
rows = [r.normalized_data or r.data for r in rows_objs]
# Get dataset name
ds_name = "hunt datasets"
if request.dataset_id:
repo = DatasetRepository(db)
ds = await repo.get_dataset(request.dataset_id)
if ds:
ds_name = ds.name
result = await run_llm_analysis(rows, request, dataset_name=ds_name)
return result
@router.get("/reports/{hunt_id}/{report_id}", response_model=HuntReportResponse)
async def get_report(
hunt_id: str,
report_id: str,
# ── Timeline ──────────────────────────────────────────────────────────
@router.get(
"/timeline",
summary="Get event timeline histogram bins",
)
async def get_timeline(
dataset_id: str | None = Query(None),
hunt_id: str | None = Query(None),
bins: int = Query(60, ge=10, le=200),
db: AsyncSession = Depends(get_db),
):
result = await db.execute(
select(HuntReport)
.where(HuntReport.id == report_id)
.where(HuntReport.hunt_id == hunt_id)
)
report = result.scalar_one_or_none()
if not report:
raise HTTPException(status_code=404, detail="Report not found")
return report
if not dataset_id and not hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
return await build_timeline_bins(db, dataset_id=dataset_id, hunt_id=hunt_id, bins=bins)
@router.post("/reports/{hunt_id}/generate")
async def trigger_report(
hunt_id: str,
background_tasks: BackgroundTasks,
):
async def _run():
from app.services.report_generator import generate_report
await generate_report(hunt_id)
background_tasks.add_task(_run)
return {"status": "report_generation_started", "hunt_id": hunt_id}
# --- IOC extraction endpoints ---
@router.get("/iocs/{dataset_id}")
async def extract_iocs(
dataset_id: str,
max_rows: int = Query(5000, ge=1, le=50000),
@router.get(
"/field-stats",
summary="Get per-field value distributions",
)
async def get_field_stats(
dataset_id: str | None = Query(None),
hunt_id: str | None = Query(None),
fields: str | None = Query(None, description="Comma-separated field names"),
top_n: int = Query(20, ge=5, le=100),
db: AsyncSession = Depends(get_db),
):
"""Extract IOCs (IPs, domains, hashes, etc.) from dataset rows."""
from app.services.ioc_extractor import extract_iocs_from_dataset
iocs = await extract_iocs_from_dataset(dataset_id, db, max_rows=max_rows)
total = sum(len(v) for v in iocs.values())
return {"dataset_id": dataset_id, "iocs": iocs, "total": total}
# --- Host grouping endpoints ---
@router.get("/hosts/{hunt_id}")
async def get_host_groups(
hunt_id: str,
db: AsyncSession = Depends(get_db),
):
"""Group data by hostname across all datasets in a hunt."""
from app.services.ioc_extractor import extract_host_groups
groups = await extract_host_groups(hunt_id, db)
return {"hunt_id": hunt_id, "hosts": groups}
# --- Anomaly detection endpoints ---
@router.get("/anomalies/{dataset_id}")
async def get_anomalies(
dataset_id: str,
outliers_only: bool = Query(False),
db: AsyncSession = Depends(get_db),
):
"""Get anomaly detection results for a dataset."""
from app.db.models import AnomalyResult
stmt = select(AnomalyResult).where(AnomalyResult.dataset_id == dataset_id)
if outliers_only:
stmt = stmt.where(AnomalyResult.is_outlier == True)
stmt = stmt.order_by(AnomalyResult.anomaly_score.desc())
result = await db.execute(stmt)
rows = result.scalars().all()
return [
{
"id": r.id,
"dataset_id": r.dataset_id,
"row_id": r.row_id,
"anomaly_score": r.anomaly_score,
"distance_from_centroid": r.distance_from_centroid,
"cluster_id": r.cluster_id,
"is_outlier": r.is_outlier,
"explanation": r.explanation,
}
for r in rows
]
@router.post("/anomalies/{dataset_id}")
async def trigger_anomaly_detection(
dataset_id: str,
k: int = Query(3, ge=2, le=20),
threshold: float = Query(0.35, ge=0.1, le=0.9),
background_tasks: BackgroundTasks = None,
):
"""Trigger embedding-based anomaly detection on a dataset."""
async def _run():
from app.services.anomaly_detector import detect_anomalies
await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
if background_tasks:
background_tasks.add_task(_run)
return {"status": "anomaly_detection_started", "dataset_id": dataset_id}
else:
from app.services.anomaly_detector import detect_anomalies
results = await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
return {"status": "complete", "dataset_id": dataset_id, "count": len(results)}
# --- Natural language data query (SSE streaming) ---
@router.post("/query/{dataset_id}")
async def query_dataset_endpoint(
dataset_id: str,
body: QueryRequest,
):
"""Ask a natural language question about a dataset.
Returns an SSE stream with token-by-token LLM response.
Event types: status, metadata, token, error, done
"""
from app.services.data_query import query_dataset_stream
return StreamingResponse(
query_dataset_stream(dataset_id, body.question, body.mode),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
if not dataset_id and not hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
field_list = [f.strip() for f in fields.split(",")] if fields else None
return await compute_field_stats(
db, dataset_id=dataset_id, hunt_id=hunt_id,
fields=field_list, top_n=top_n,
)
@router.post("/query/{dataset_id}/sync")
async def query_dataset_sync(
dataset_id: str,
body: QueryRequest,
class SearchRequest(BaseModel):
dataset_id: Optional[str] = None
hunt_id: Optional[str] = None
query: str = ""
filters: dict[str, str] = Field(default_factory=dict)
time_start: Optional[str] = None
time_end: Optional[str] = None
limit: int = 500
offset: int = 0
@router.post(
"/search",
summary="Search and filter dataset rows",
)
async def search_dataset_rows(
request: SearchRequest,
db: AsyncSession = Depends(get_db),
):
"""Non-streaming version of data query."""
from app.services.data_query import query_dataset
try:
answer = await query_dataset(dataset_id, body.question, body.mode)
return {"dataset_id": dataset_id, "question": body.question, "answer": answer, "mode": body.mode}
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f"Query failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
if not request.dataset_id and not request.hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
return await search_rows(
db,
dataset_id=request.dataset_id,
hunt_id=request.hunt_id,
query=request.query,
filters=request.filters,
time_start=request.time_start,
time_end=request.time_end,
limit=request.limit,
offset=request.offset,
)
# --- Job queue endpoints ---
# ── MITRE ATT&CK ─────────────────────────────────────────────────────
@router.get("/jobs")
async def list_jobs(
status: str | None = Query(None),
job_type: str | None = Query(None),
limit: int = Query(50, ge=1, le=200),
@router.get(
"/mitre-map",
summary="Map dataset events to MITRE ATT&CK techniques",
)
async def get_mitre_map(
dataset_id: str | None = Query(None),
hunt_id: str | None = Query(None),
db: AsyncSession = Depends(get_db),
):
"""List all tracked jobs."""
from app.services.job_queue import job_queue, JobStatus, JobType
s = JobStatus(status) if status else None
t = JobType(job_type) if job_type else None
jobs = job_queue.list_jobs(status=s, job_type=t, limit=limit)
stats = job_queue.get_stats()
return {"jobs": jobs, "stats": stats}
if not dataset_id and not hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
return await map_to_attack(db, dataset_id=dataset_id, hunt_id=hunt_id)
@router.get("/jobs/{job_id}")
async def get_job(job_id: str):
"""Get status of a specific job."""
from app.services.job_queue import job_queue
job = job_queue.get_job(job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
return job.to_dict()
@router.delete("/jobs/{job_id}")
async def cancel_job(job_id: str):
"""Cancel a running or queued job."""
from app.services.job_queue import job_queue
if job_queue.cancel_job(job_id):
return {"status": "cancelled", "job_id": job_id}
raise HTTPException(status_code=400, detail="Job cannot be cancelled (already complete or not found)")
@router.post("/jobs/submit/{job_type}")
async def submit_job(
job_type: str,
params: dict = {},
@router.get(
"/knowledge-graph",
summary="Build entity-technique knowledge graph",
)
async def get_knowledge_graph(
dataset_id: str | None = Query(None),
hunt_id: str | None = Query(None),
db: AsyncSession = Depends(get_db),
):
<<<<<<< HEAD
"""Submit a new job to the queue.
Job types: triage, host_profile, report, anomaly, query
@@ -403,4 +328,9 @@ async def lb_health_check():
"""Force a health check of both nodes."""
from app.services.load_balancer import lb
await lb.check_health()
return lb.get_status()
return lb.get_status()
=======
if not dataset_id and not hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
return await build_knowledge_graph(db, dataset_id=dataset_id, hunt_id=hunt_id)
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2