mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 14:00:20 -05:00
chore: checkpoint all local changes
This commit is contained in:
404
backend/app/api/routes/alerts.py
Normal file
404
backend/app/api/routes/alerts.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""API routes for alerts — CRUD, analyze triggers, and alert rules."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Alert, AlertRule, _new_id, _utcnow
|
||||
from app.db.repositories.datasets import DatasetRepository
|
||||
from app.services.analyzers import (
|
||||
get_available_analyzers,
|
||||
get_analyzer,
|
||||
run_all_analyzers,
|
||||
AlertCandidate,
|
||||
)
|
||||
from app.services.process_tree import _fetch_rows
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/alerts", tags=["alerts"])
|
||||
|
||||
|
||||
# ── Pydantic models ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class AlertUpdate(BaseModel):
|
||||
status: Optional[str] = None
|
||||
severity: Optional[str] = None
|
||||
assignee: Optional[str] = None
|
||||
case_id: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
|
||||
|
||||
class RuleCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
analyzer: str
|
||||
config: Optional[dict] = None
|
||||
severity_override: Optional[str] = None
|
||||
enabled: bool = True
|
||||
hunt_id: Optional[str] = None
|
||||
|
||||
|
||||
class RuleUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
config: Optional[dict] = None
|
||||
severity_override: Optional[str] = None
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
|
||||
class AnalyzeRequest(BaseModel):
|
||||
dataset_id: Optional[str] = None
|
||||
hunt_id: Optional[str] = None
|
||||
analyzers: Optional[list[str]] = None # None = run all
|
||||
config: Optional[dict] = None
|
||||
auto_create: bool = True # automatically persist alerts
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _alert_to_dict(a: Alert) -> dict:
|
||||
return {
|
||||
"id": a.id,
|
||||
"title": a.title,
|
||||
"description": a.description,
|
||||
"severity": a.severity,
|
||||
"status": a.status,
|
||||
"analyzer": a.analyzer,
|
||||
"score": a.score,
|
||||
"evidence": a.evidence or [],
|
||||
"mitre_technique": a.mitre_technique,
|
||||
"tags": a.tags or [],
|
||||
"hunt_id": a.hunt_id,
|
||||
"dataset_id": a.dataset_id,
|
||||
"case_id": a.case_id,
|
||||
"assignee": a.assignee,
|
||||
"acknowledged_at": a.acknowledged_at.isoformat() if a.acknowledged_at else None,
|
||||
"resolved_at": a.resolved_at.isoformat() if a.resolved_at else None,
|
||||
"created_at": a.created_at.isoformat() if a.created_at else None,
|
||||
"updated_at": a.updated_at.isoformat() if a.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _rule_to_dict(r: AlertRule) -> dict:
|
||||
return {
|
||||
"id": r.id,
|
||||
"name": r.name,
|
||||
"description": r.description,
|
||||
"analyzer": r.analyzer,
|
||||
"config": r.config,
|
||||
"severity_override": r.severity_override,
|
||||
"enabled": r.enabled,
|
||||
"hunt_id": r.hunt_id,
|
||||
"created_at": r.created_at.isoformat() if r.created_at else None,
|
||||
"updated_at": r.updated_at.isoformat() if r.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
# ── Alert CRUD ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("", summary="List alerts")
|
||||
async def list_alerts(
|
||||
status: str | None = Query(None),
|
||||
severity: str | None = Query(None),
|
||||
analyzer: str | None = Query(None),
|
||||
hunt_id: str | None = Query(None),
|
||||
dataset_id: str | None = Query(None),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
stmt = select(Alert)
|
||||
count_stmt = select(func.count(Alert.id))
|
||||
if status:
|
||||
stmt = stmt.where(Alert.status == status)
|
||||
count_stmt = count_stmt.where(Alert.status == status)
|
||||
if severity:
|
||||
stmt = stmt.where(Alert.severity == severity)
|
||||
count_stmt = count_stmt.where(Alert.severity == severity)
|
||||
if analyzer:
|
||||
stmt = stmt.where(Alert.analyzer == analyzer)
|
||||
count_stmt = count_stmt.where(Alert.analyzer == analyzer)
|
||||
if hunt_id:
|
||||
stmt = stmt.where(Alert.hunt_id == hunt_id)
|
||||
count_stmt = count_stmt.where(Alert.hunt_id == hunt_id)
|
||||
if dataset_id:
|
||||
stmt = stmt.where(Alert.dataset_id == dataset_id)
|
||||
count_stmt = count_stmt.where(Alert.dataset_id == dataset_id)
|
||||
|
||||
total = (await db.execute(count_stmt)).scalar() or 0
|
||||
results = (await db.execute(
|
||||
stmt.order_by(desc(Alert.score), desc(Alert.created_at)).offset(offset).limit(limit)
|
||||
)).scalars().all()
|
||||
|
||||
return {"alerts": [_alert_to_dict(a) for a in results], "total": total}
|
||||
|
||||
|
||||
@router.get("/stats", summary="Alert statistics dashboard")
|
||||
async def alert_stats(
|
||||
hunt_id: str | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Return aggregated alert statistics."""
|
||||
base = select(Alert)
|
||||
if hunt_id:
|
||||
base = base.where(Alert.hunt_id == hunt_id)
|
||||
|
||||
# Severity breakdown
|
||||
sev_stmt = select(Alert.severity, func.count(Alert.id)).group_by(Alert.severity)
|
||||
if hunt_id:
|
||||
sev_stmt = sev_stmt.where(Alert.hunt_id == hunt_id)
|
||||
sev_rows = (await db.execute(sev_stmt)).all()
|
||||
severity_counts = {s: c for s, c in sev_rows}
|
||||
|
||||
# Status breakdown
|
||||
status_stmt = select(Alert.status, func.count(Alert.id)).group_by(Alert.status)
|
||||
if hunt_id:
|
||||
status_stmt = status_stmt.where(Alert.hunt_id == hunt_id)
|
||||
status_rows = (await db.execute(status_stmt)).all()
|
||||
status_counts = {s: c for s, c in status_rows}
|
||||
|
||||
# Analyzer breakdown
|
||||
analyzer_stmt = select(Alert.analyzer, func.count(Alert.id)).group_by(Alert.analyzer)
|
||||
if hunt_id:
|
||||
analyzer_stmt = analyzer_stmt.where(Alert.hunt_id == hunt_id)
|
||||
analyzer_rows = (await db.execute(analyzer_stmt)).all()
|
||||
analyzer_counts = {a: c for a, c in analyzer_rows}
|
||||
|
||||
# Top MITRE techniques
|
||||
mitre_stmt = (
|
||||
select(Alert.mitre_technique, func.count(Alert.id))
|
||||
.where(Alert.mitre_technique.isnot(None))
|
||||
.group_by(Alert.mitre_technique)
|
||||
.order_by(desc(func.count(Alert.id)))
|
||||
.limit(10)
|
||||
)
|
||||
if hunt_id:
|
||||
mitre_stmt = mitre_stmt.where(Alert.hunt_id == hunt_id)
|
||||
mitre_rows = (await db.execute(mitre_stmt)).all()
|
||||
top_mitre = [{"technique": t, "count": c} for t, c in mitre_rows]
|
||||
|
||||
total = sum(severity_counts.values())
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"severity_counts": severity_counts,
|
||||
"status_counts": status_counts,
|
||||
"analyzer_counts": analyzer_counts,
|
||||
"top_mitre": top_mitre,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{alert_id}", summary="Get alert detail")
|
||||
async def get_alert(alert_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.get(Alert, alert_id)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Alert not found")
|
||||
return _alert_to_dict(result)
|
||||
|
||||
|
||||
@router.put("/{alert_id}", summary="Update alert (status, assignee, etc.)")
|
||||
async def update_alert(
|
||||
alert_id: str, body: AlertUpdate, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
alert = await db.get(Alert, alert_id)
|
||||
if not alert:
|
||||
raise HTTPException(status_code=404, detail="Alert not found")
|
||||
|
||||
if body.status is not None:
|
||||
alert.status = body.status
|
||||
if body.status == "acknowledged" and not alert.acknowledged_at:
|
||||
alert.acknowledged_at = _utcnow()
|
||||
if body.status in ("resolved", "false-positive") and not alert.resolved_at:
|
||||
alert.resolved_at = _utcnow()
|
||||
if body.severity is not None:
|
||||
alert.severity = body.severity
|
||||
if body.assignee is not None:
|
||||
alert.assignee = body.assignee
|
||||
if body.case_id is not None:
|
||||
alert.case_id = body.case_id
|
||||
if body.tags is not None:
|
||||
alert.tags = body.tags
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(alert)
|
||||
return _alert_to_dict(alert)
|
||||
|
||||
|
||||
@router.delete("/{alert_id}", summary="Delete alert")
|
||||
async def delete_alert(alert_id: str, db: AsyncSession = Depends(get_db)):
|
||||
alert = await db.get(Alert, alert_id)
|
||||
if not alert:
|
||||
raise HTTPException(status_code=404, detail="Alert not found")
|
||||
await db.delete(alert)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── Bulk operations ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/bulk-update", summary="Bulk update alert statuses")
|
||||
async def bulk_update_alerts(
|
||||
alert_ids: list[str],
|
||||
status: str = Query(...),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
updated = 0
|
||||
for aid in alert_ids:
|
||||
alert = await db.get(Alert, aid)
|
||||
if alert:
|
||||
alert.status = status
|
||||
if status == "acknowledged" and not alert.acknowledged_at:
|
||||
alert.acknowledged_at = _utcnow()
|
||||
if status in ("resolved", "false-positive") and not alert.resolved_at:
|
||||
alert.resolved_at = _utcnow()
|
||||
updated += 1
|
||||
await db.commit()
|
||||
return {"updated": updated}
|
||||
|
||||
|
||||
# ── Run Analyzers ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/analyzers/list", summary="List available analyzers")
|
||||
async def list_analyzers():
|
||||
return {"analyzers": get_available_analyzers()}
|
||||
|
||||
|
||||
@router.post("/analyze", summary="Run analyzers on a dataset/hunt and optionally create alerts")
|
||||
async def run_analysis(
|
||||
request: AnalyzeRequest, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
if not request.dataset_id and not request.hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
|
||||
# Load rows
|
||||
rows_objs = await _fetch_rows(
|
||||
db, dataset_id=request.dataset_id, hunt_id=request.hunt_id, limit=10000,
|
||||
)
|
||||
if not rows_objs:
|
||||
raise HTTPException(status_code=404, detail="No rows found")
|
||||
|
||||
rows = [r.normalized_data or r.data for r in rows_objs]
|
||||
|
||||
# Run analyzers
|
||||
candidates = await run_all_analyzers(rows, enabled=request.analyzers, config=request.config)
|
||||
|
||||
created_alerts: list[dict] = []
|
||||
if request.auto_create and candidates:
|
||||
for c in candidates:
|
||||
alert = Alert(
|
||||
id=_new_id(),
|
||||
title=c.title,
|
||||
description=c.description,
|
||||
severity=c.severity,
|
||||
analyzer=c.analyzer,
|
||||
score=c.score,
|
||||
evidence=c.evidence,
|
||||
mitre_technique=c.mitre_technique,
|
||||
tags=c.tags,
|
||||
hunt_id=request.hunt_id,
|
||||
dataset_id=request.dataset_id,
|
||||
)
|
||||
db.add(alert)
|
||||
created_alerts.append(_alert_to_dict(alert))
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"candidates_found": len(candidates),
|
||||
"alerts_created": len(created_alerts),
|
||||
"alerts": created_alerts,
|
||||
"summary": {
|
||||
"by_severity": _count_by(candidates, "severity"),
|
||||
"by_analyzer": _count_by(candidates, "analyzer"),
|
||||
"rows_analyzed": len(rows),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _count_by(items: list[AlertCandidate], attr: str) -> dict[str, int]:
|
||||
counts: dict[str, int] = {}
|
||||
for item in items:
|
||||
key = getattr(item, attr, "unknown")
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
return counts
|
||||
|
||||
|
||||
# ── Alert Rules CRUD ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/rules/list", summary="List alert rules")
|
||||
async def list_rules(
|
||||
enabled: bool | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
stmt = select(AlertRule)
|
||||
if enabled is not None:
|
||||
stmt = stmt.where(AlertRule.enabled == enabled)
|
||||
results = (await db.execute(stmt.order_by(AlertRule.created_at))).scalars().all()
|
||||
return {"rules": [_rule_to_dict(r) for r in results]}
|
||||
|
||||
|
||||
@router.post("/rules", summary="Create alert rule")
|
||||
async def create_rule(body: RuleCreate, db: AsyncSession = Depends(get_db)):
|
||||
# Validate analyzer exists
|
||||
if not get_analyzer(body.analyzer):
|
||||
raise HTTPException(status_code=400, detail=f"Unknown analyzer: {body.analyzer}")
|
||||
|
||||
rule = AlertRule(
|
||||
id=_new_id(),
|
||||
name=body.name,
|
||||
description=body.description,
|
||||
analyzer=body.analyzer,
|
||||
config=body.config,
|
||||
severity_override=body.severity_override,
|
||||
enabled=body.enabled,
|
||||
hunt_id=body.hunt_id,
|
||||
)
|
||||
db.add(rule)
|
||||
await db.commit()
|
||||
await db.refresh(rule)
|
||||
return _rule_to_dict(rule)
|
||||
|
||||
|
||||
@router.put("/rules/{rule_id}", summary="Update alert rule")
|
||||
async def update_rule(
|
||||
rule_id: str, body: RuleUpdate, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
rule = await db.get(AlertRule, rule_id)
|
||||
if not rule:
|
||||
raise HTTPException(status_code=404, detail="Rule not found")
|
||||
|
||||
if body.name is not None:
|
||||
rule.name = body.name
|
||||
if body.description is not None:
|
||||
rule.description = body.description
|
||||
if body.config is not None:
|
||||
rule.config = body.config
|
||||
if body.severity_override is not None:
|
||||
rule.severity_override = body.severity_override
|
||||
if body.enabled is not None:
|
||||
rule.enabled = body.enabled
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(rule)
|
||||
return _rule_to_dict(rule)
|
||||
|
||||
|
||||
@router.delete("/rules/{rule_id}", summary="Delete alert rule")
|
||||
async def delete_rule(rule_id: str, db: AsyncSession = Depends(get_db)):
|
||||
rule = await db.get(AlertRule, rule_id)
|
||||
if not rule:
|
||||
raise HTTPException(status_code=404, detail="Rule not found")
|
||||
await db.delete(rule)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
@@ -1,371 +1,296 @@
|
||||
"""Analysis API routes - triage, host profiles, reports, IOC extraction,
|
||||
host grouping, anomaly detection, data query (SSE), and job management."""
|
||||
|
||||
from __future__ import annotations
|
||||
"""API routes for process trees, storyline graphs, risk scoring, LLM analysis, timeline, and field stats."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import HostProfile, HuntReport, TriageResult
|
||||
from app.db.repositories.datasets import DatasetRepository
|
||||
from app.services.process_tree import (
|
||||
build_process_tree,
|
||||
build_storyline,
|
||||
compute_risk_scores,
|
||||
_fetch_rows,
|
||||
)
|
||||
from app.services.llm_analysis import (
|
||||
AnalysisRequest,
|
||||
AnalysisResult,
|
||||
run_llm_analysis,
|
||||
)
|
||||
from app.services.timeline import (
|
||||
build_timeline_bins,
|
||||
compute_field_stats,
|
||||
search_rows,
|
||||
)
|
||||
from app.services.mitre import (
|
||||
map_to_attack,
|
||||
build_knowledge_graph,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/analysis", tags=["analysis"])
|
||||
|
||||
|
||||
# --- Response models ---
|
||||
|
||||
class TriageResultResponse(BaseModel):
|
||||
id: str
|
||||
dataset_id: str
|
||||
row_start: int
|
||||
row_end: int
|
||||
risk_score: float
|
||||
verdict: str
|
||||
findings: list | None = None
|
||||
suspicious_indicators: list | None = None
|
||||
mitre_techniques: list | None = None
|
||||
model_used: str | None = None
|
||||
node_used: str | None = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
# ── Response models ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class HostProfileResponse(BaseModel):
|
||||
id: str
|
||||
hunt_id: str
|
||||
class ProcessTreeResponse(BaseModel):
|
||||
trees: list[dict] = Field(default_factory=list)
|
||||
total_processes: int = 0
|
||||
|
||||
|
||||
class StorylineResponse(BaseModel):
|
||||
nodes: list[dict] = Field(default_factory=list)
|
||||
edges: list[dict] = Field(default_factory=list)
|
||||
summary: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class RiskHostEntry(BaseModel):
|
||||
hostname: str
|
||||
fqdn: str | None = None
|
||||
risk_score: float
|
||||
risk_level: str
|
||||
artifact_summary: dict | None = None
|
||||
timeline_summary: str | None = None
|
||||
suspicious_findings: list | None = None
|
||||
mitre_techniques: list | None = None
|
||||
llm_analysis: str | None = None
|
||||
model_used: str | None = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
score: int = 0
|
||||
signals: list[str] = Field(default_factory=list)
|
||||
event_count: int = 0
|
||||
process_count: int = 0
|
||||
network_count: int = 0
|
||||
file_count: int = 0
|
||||
|
||||
|
||||
class HuntReportResponse(BaseModel):
|
||||
id: str
|
||||
hunt_id: str
|
||||
status: str
|
||||
exec_summary: str | None = None
|
||||
full_report: str | None = None
|
||||
findings: list | None = None
|
||||
recommendations: list | None = None
|
||||
mitre_mapping: dict | None = None
|
||||
ioc_table: list | None = None
|
||||
host_risk_summary: list | None = None
|
||||
models_used: list | None = None
|
||||
generation_time_ms: int | None = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
class RiskSummaryResponse(BaseModel):
|
||||
hosts: list[RiskHostEntry] = Field(default_factory=list)
|
||||
overall_score: int = 0
|
||||
total_events: int = 0
|
||||
severity_breakdown: dict[str, int] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
question: str
|
||||
mode: str = "quick" # quick or deep
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# --- Triage endpoints ---
|
||||
|
||||
@router.get("/triage/{dataset_id}", response_model=list[TriageResultResponse])
|
||||
async def get_triage_results(
|
||||
dataset_id: str,
|
||||
min_risk: float = Query(0.0, ge=0.0, le=10.0),
|
||||
@router.get(
|
||||
"/process-tree",
|
||||
response_model=ProcessTreeResponse,
|
||||
summary="Build process tree from dataset rows",
|
||||
description=(
|
||||
"Extracts parent→child process relationships from dataset rows "
|
||||
"and returns a hierarchical forest of process nodes."
|
||||
),
|
||||
)
|
||||
async def get_process_tree(
|
||||
dataset_id: str | None = Query(None, description="Dataset ID"),
|
||||
hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
|
||||
hostname: str | None = Query(None, description="Filter by hostname"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(TriageResult)
|
||||
.where(TriageResult.dataset_id == dataset_id)
|
||||
.where(TriageResult.risk_score >= min_risk)
|
||||
.order_by(TriageResult.risk_score.desc())
|
||||
"""Return process tree(s) for a dataset or hunt."""
|
||||
if not dataset_id and not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
|
||||
trees = await build_process_tree(
|
||||
db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
# Count total processes recursively
|
||||
def _count(node: dict) -> int:
|
||||
return 1 + sum(_count(c) for c in node.get("children", []))
|
||||
|
||||
total = sum(_count(t) for t in trees)
|
||||
|
||||
return ProcessTreeResponse(trees=trees, total_processes=total)
|
||||
|
||||
|
||||
@router.post("/triage/{dataset_id}")
|
||||
async def trigger_triage(
|
||||
dataset_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
async def _run():
|
||||
from app.services.triage import triage_dataset
|
||||
await triage_dataset(dataset_id)
|
||||
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "triage_started", "dataset_id": dataset_id}
|
||||
|
||||
|
||||
# --- Host profile endpoints ---
|
||||
|
||||
@router.get("/profiles/{hunt_id}", response_model=list[HostProfileResponse])
|
||||
async def get_host_profiles(
|
||||
hunt_id: str,
|
||||
min_risk: float = Query(0.0, ge=0.0, le=10.0),
|
||||
@router.get(
|
||||
"/storyline",
|
||||
response_model=StorylineResponse,
|
||||
summary="Build CrowdStrike-style storyline attack graph",
|
||||
description=(
|
||||
"Creates a Cytoscape-compatible graph of events connected by "
|
||||
"process lineage (spawned) and temporal sequence within each host."
|
||||
),
|
||||
)
|
||||
async def get_storyline(
|
||||
dataset_id: str | None = Query(None, description="Dataset ID"),
|
||||
hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
|
||||
hostname: str | None = Query(None, description="Filter by hostname"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(HostProfile)
|
||||
.where(HostProfile.hunt_id == hunt_id)
|
||||
.where(HostProfile.risk_score >= min_risk)
|
||||
.order_by(HostProfile.risk_score.desc())
|
||||
"""Return a storyline graph for a dataset or hunt."""
|
||||
if not dataset_id and not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
|
||||
result = await build_storyline(
|
||||
db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
return StorylineResponse(**result)
|
||||
|
||||
|
||||
@router.post("/profiles/{hunt_id}")
|
||||
async def trigger_all_profiles(
|
||||
hunt_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
async def _run():
|
||||
from app.services.host_profiler import profile_all_hosts
|
||||
await profile_all_hosts(hunt_id)
|
||||
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "profiling_started", "hunt_id": hunt_id}
|
||||
|
||||
|
||||
@router.post("/profiles/{hunt_id}/{hostname}")
|
||||
async def trigger_single_profile(
|
||||
hunt_id: str,
|
||||
hostname: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
async def _run():
|
||||
from app.services.host_profiler import profile_host
|
||||
await profile_host(hunt_id, hostname)
|
||||
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "profiling_started", "hunt_id": hunt_id, "hostname": hostname}
|
||||
|
||||
|
||||
# --- Report endpoints ---
|
||||
|
||||
@router.get("/reports/{hunt_id}", response_model=list[HuntReportResponse])
|
||||
async def list_reports(
|
||||
hunt_id: str,
|
||||
@router.get(
|
||||
"/risk-summary",
|
||||
response_model=RiskSummaryResponse,
|
||||
summary="Compute risk scores per host",
|
||||
description=(
|
||||
"Analyzes dataset rows for suspicious patterns (encoded PowerShell, "
|
||||
"credential dumping, lateral movement) and produces per-host risk scores."
|
||||
),
|
||||
)
|
||||
async def get_risk_summary(
|
||||
hunt_id: str | None = Query(None, description="Hunt ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(HuntReport)
|
||||
.where(HuntReport.hunt_id == hunt_id)
|
||||
.order_by(HuntReport.created_at.desc())
|
||||
"""Return risk scores for all hosts in a hunt."""
|
||||
result = await compute_risk_scores(db, hunt_id=hunt_id)
|
||||
return RiskSummaryResponse(**result)
|
||||
|
||||
|
||||
# ── LLM Analysis ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
|
||||
"/llm-analyze",
|
||||
response_model=AnalysisResult,
|
||||
summary="Run LLM-powered threat analysis on dataset",
|
||||
description=(
|
||||
"Loads dataset rows server-side, builds a summary, and sends to "
|
||||
"Wile (deep analysis) or Roadrunner (quick) for comprehensive "
|
||||
"threat analysis. Returns structured findings, IOCs, MITRE techniques."
|
||||
),
|
||||
)
|
||||
async def llm_analyze(
|
||||
request: AnalysisRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Run LLM analysis on a dataset or hunt."""
|
||||
if not request.dataset_id and not request.hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
|
||||
# Load rows
|
||||
rows_objs = await _fetch_rows(
|
||||
db,
|
||||
dataset_id=request.dataset_id,
|
||||
hunt_id=request.hunt_id,
|
||||
limit=2000,
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
if not rows_objs:
|
||||
raise HTTPException(status_code=404, detail="No rows found for analysis")
|
||||
|
||||
# Extract data dicts
|
||||
rows = [r.normalized_data or r.data for r in rows_objs]
|
||||
|
||||
# Get dataset name
|
||||
ds_name = "hunt datasets"
|
||||
if request.dataset_id:
|
||||
repo = DatasetRepository(db)
|
||||
ds = await repo.get_dataset(request.dataset_id)
|
||||
if ds:
|
||||
ds_name = ds.name
|
||||
|
||||
result = await run_llm_analysis(rows, request, dataset_name=ds_name)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/reports/{hunt_id}/{report_id}", response_model=HuntReportResponse)
|
||||
async def get_report(
|
||||
hunt_id: str,
|
||||
report_id: str,
|
||||
# ── Timeline ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
|
||||
"/timeline",
|
||||
summary="Get event timeline histogram bins",
|
||||
)
|
||||
async def get_timeline(
|
||||
dataset_id: str | None = Query(None),
|
||||
hunt_id: str | None = Query(None),
|
||||
bins: int = Query(60, ge=10, le=200),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(HuntReport)
|
||||
.where(HuntReport.id == report_id)
|
||||
.where(HuntReport.hunt_id == hunt_id)
|
||||
)
|
||||
report = result.scalar_one_or_none()
|
||||
if not report:
|
||||
raise HTTPException(status_code=404, detail="Report not found")
|
||||
return report
|
||||
if not dataset_id and not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
return await build_timeline_bins(db, dataset_id=dataset_id, hunt_id=hunt_id, bins=bins)
|
||||
|
||||
|
||||
@router.post("/reports/{hunt_id}/generate")
|
||||
async def trigger_report(
|
||||
hunt_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
async def _run():
|
||||
from app.services.report_generator import generate_report
|
||||
await generate_report(hunt_id)
|
||||
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "report_generation_started", "hunt_id": hunt_id}
|
||||
|
||||
|
||||
# --- IOC extraction endpoints ---
|
||||
|
||||
@router.get("/iocs/{dataset_id}")
|
||||
async def extract_iocs(
|
||||
dataset_id: str,
|
||||
max_rows: int = Query(5000, ge=1, le=50000),
|
||||
@router.get(
|
||||
"/field-stats",
|
||||
summary="Get per-field value distributions",
|
||||
)
|
||||
async def get_field_stats(
|
||||
dataset_id: str | None = Query(None),
|
||||
hunt_id: str | None = Query(None),
|
||||
fields: str | None = Query(None, description="Comma-separated field names"),
|
||||
top_n: int = Query(20, ge=5, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Extract IOCs (IPs, domains, hashes, etc.) from dataset rows."""
|
||||
from app.services.ioc_extractor import extract_iocs_from_dataset
|
||||
iocs = await extract_iocs_from_dataset(dataset_id, db, max_rows=max_rows)
|
||||
total = sum(len(v) for v in iocs.values())
|
||||
return {"dataset_id": dataset_id, "iocs": iocs, "total": total}
|
||||
|
||||
|
||||
# --- Host grouping endpoints ---
|
||||
|
||||
@router.get("/hosts/{hunt_id}")
|
||||
async def get_host_groups(
|
||||
hunt_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Group data by hostname across all datasets in a hunt."""
|
||||
from app.services.ioc_extractor import extract_host_groups
|
||||
groups = await extract_host_groups(hunt_id, db)
|
||||
return {"hunt_id": hunt_id, "hosts": groups}
|
||||
|
||||
|
||||
# --- Anomaly detection endpoints ---
|
||||
|
||||
@router.get("/anomalies/{dataset_id}")
|
||||
async def get_anomalies(
|
||||
dataset_id: str,
|
||||
outliers_only: bool = Query(False),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get anomaly detection results for a dataset."""
|
||||
from app.db.models import AnomalyResult
|
||||
stmt = select(AnomalyResult).where(AnomalyResult.dataset_id == dataset_id)
|
||||
if outliers_only:
|
||||
stmt = stmt.where(AnomalyResult.is_outlier == True)
|
||||
stmt = stmt.order_by(AnomalyResult.anomaly_score.desc())
|
||||
result = await db.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": r.id,
|
||||
"dataset_id": r.dataset_id,
|
||||
"row_id": r.row_id,
|
||||
"anomaly_score": r.anomaly_score,
|
||||
"distance_from_centroid": r.distance_from_centroid,
|
||||
"cluster_id": r.cluster_id,
|
||||
"is_outlier": r.is_outlier,
|
||||
"explanation": r.explanation,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
@router.post("/anomalies/{dataset_id}")
|
||||
async def trigger_anomaly_detection(
|
||||
dataset_id: str,
|
||||
k: int = Query(3, ge=2, le=20),
|
||||
threshold: float = Query(0.35, ge=0.1, le=0.9),
|
||||
background_tasks: BackgroundTasks = None,
|
||||
):
|
||||
"""Trigger embedding-based anomaly detection on a dataset."""
|
||||
async def _run():
|
||||
from app.services.anomaly_detector import detect_anomalies
|
||||
await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
|
||||
|
||||
if background_tasks:
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "anomaly_detection_started", "dataset_id": dataset_id}
|
||||
else:
|
||||
from app.services.anomaly_detector import detect_anomalies
|
||||
results = await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
|
||||
return {"status": "complete", "dataset_id": dataset_id, "count": len(results)}
|
||||
|
||||
|
||||
# --- Natural language data query (SSE streaming) ---
|
||||
|
||||
@router.post("/query/{dataset_id}")
|
||||
async def query_dataset_endpoint(
|
||||
dataset_id: str,
|
||||
body: QueryRequest,
|
||||
):
|
||||
"""Ask a natural language question about a dataset.
|
||||
|
||||
Returns an SSE stream with token-by-token LLM response.
|
||||
Event types: status, metadata, token, error, done
|
||||
"""
|
||||
from app.services.data_query import query_dataset_stream
|
||||
|
||||
return StreamingResponse(
|
||||
query_dataset_stream(dataset_id, body.question, body.mode),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
if not dataset_id and not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
field_list = [f.strip() for f in fields.split(",")] if fields else None
|
||||
return await compute_field_stats(
|
||||
db, dataset_id=dataset_id, hunt_id=hunt_id,
|
||||
fields=field_list, top_n=top_n,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/query/{dataset_id}/sync")
|
||||
async def query_dataset_sync(
|
||||
dataset_id: str,
|
||||
body: QueryRequest,
|
||||
class SearchRequest(BaseModel):
|
||||
dataset_id: Optional[str] = None
|
||||
hunt_id: Optional[str] = None
|
||||
query: str = ""
|
||||
filters: dict[str, str] = Field(default_factory=dict)
|
||||
time_start: Optional[str] = None
|
||||
time_end: Optional[str] = None
|
||||
limit: int = 500
|
||||
offset: int = 0
|
||||
|
||||
|
||||
@router.post(
|
||||
"/search",
|
||||
summary="Search and filter dataset rows",
|
||||
)
|
||||
async def search_dataset_rows(
|
||||
request: SearchRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Non-streaming version of data query."""
|
||||
from app.services.data_query import query_dataset
|
||||
|
||||
try:
|
||||
answer = await query_dataset(dataset_id, body.question, body.mode)
|
||||
return {"dataset_id": dataset_id, "question": body.question, "answer": answer, "mode": body.mode}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Query failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
if not request.dataset_id and not request.hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
return await search_rows(
|
||||
db,
|
||||
dataset_id=request.dataset_id,
|
||||
hunt_id=request.hunt_id,
|
||||
query=request.query,
|
||||
filters=request.filters,
|
||||
time_start=request.time_start,
|
||||
time_end=request.time_end,
|
||||
limit=request.limit,
|
||||
offset=request.offset,
|
||||
)
|
||||
|
||||
|
||||
# --- Job queue endpoints ---
|
||||
# ── MITRE ATT&CK ─────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/jobs")
|
||||
async def list_jobs(
|
||||
status: str | None = Query(None),
|
||||
job_type: str | None = Query(None),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
|
||||
@router.get(
|
||||
"/mitre-map",
|
||||
summary="Map dataset events to MITRE ATT&CK techniques",
|
||||
)
|
||||
async def get_mitre_map(
|
||||
dataset_id: str | None = Query(None),
|
||||
hunt_id: str | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List all tracked jobs."""
|
||||
from app.services.job_queue import job_queue, JobStatus, JobType
|
||||
|
||||
s = JobStatus(status) if status else None
|
||||
t = JobType(job_type) if job_type else None
|
||||
jobs = job_queue.list_jobs(status=s, job_type=t, limit=limit)
|
||||
stats = job_queue.get_stats()
|
||||
return {"jobs": jobs, "stats": stats}
|
||||
if not dataset_id and not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
return await map_to_attack(db, dataset_id=dataset_id, hunt_id=hunt_id)
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}")
|
||||
async def get_job(job_id: str):
|
||||
"""Get status of a specific job."""
|
||||
from app.services.job_queue import job_queue
|
||||
|
||||
job = job_queue.get_job(job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
return job.to_dict()
|
||||
|
||||
|
||||
@router.delete("/jobs/{job_id}")
|
||||
async def cancel_job(job_id: str):
|
||||
"""Cancel a running or queued job."""
|
||||
from app.services.job_queue import job_queue
|
||||
|
||||
if job_queue.cancel_job(job_id):
|
||||
return {"status": "cancelled", "job_id": job_id}
|
||||
raise HTTPException(status_code=400, detail="Job cannot be cancelled (already complete or not found)")
|
||||
|
||||
|
||||
@router.post("/jobs/submit/{job_type}")
|
||||
async def submit_job(
|
||||
job_type: str,
|
||||
params: dict = {},
|
||||
@router.get(
|
||||
"/knowledge-graph",
|
||||
summary="Build entity-technique knowledge graph",
|
||||
)
|
||||
async def get_knowledge_graph(
|
||||
dataset_id: str | None = Query(None),
|
||||
hunt_id: str | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
<<<<<<< HEAD
|
||||
"""Submit a new job to the queue.
|
||||
|
||||
Job types: triage, host_profile, report, anomaly, query
|
||||
@@ -403,4 +328,9 @@ async def lb_health_check():
|
||||
"""Force a health check of both nodes."""
|
||||
from app.services.load_balancer import lb
|
||||
await lb.check_health()
|
||||
return lb.get_status()
|
||||
return lb.get_status()
|
||||
=======
|
||||
if not dataset_id and not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
return await build_knowledge_graph(db, dataset_id=dataset_id, hunt_id=hunt_id)
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
296
backend/app/api/routes/cases.py
Normal file
296
backend/app/api/routes/cases.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""API routes for case management — CRUD for cases, tasks, and activity logs."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Case, CaseTask, ActivityLog, _new_id, _utcnow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/cases", tags=["cases"])
|
||||
|
||||
|
||||
# ── Pydantic models ──────────────────────────────────────────────────
|
||||
|
||||
class CaseCreate(BaseModel):
|
||||
title: str
|
||||
description: Optional[str] = None
|
||||
severity: str = "medium"
|
||||
tlp: str = "amber"
|
||||
pap: str = "amber"
|
||||
priority: int = 2
|
||||
assignee: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
hunt_id: Optional[str] = None
|
||||
mitre_techniques: Optional[list[str]] = None
|
||||
iocs: Optional[list[dict]] = None
|
||||
|
||||
|
||||
class CaseUpdate(BaseModel):
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
severity: Optional[str] = None
|
||||
tlp: Optional[str] = None
|
||||
pap: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
priority: Optional[int] = None
|
||||
assignee: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
mitre_techniques: Optional[list[str]] = None
|
||||
iocs: Optional[list[dict]] = None
|
||||
|
||||
|
||||
class TaskCreate(BaseModel):
|
||||
title: str
|
||||
description: Optional[str] = None
|
||||
assignee: Optional[str] = None
|
||||
|
||||
|
||||
class TaskUpdate(BaseModel):
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
assignee: Optional[str] = None
|
||||
order: Optional[int] = None
|
||||
|
||||
|
||||
# ── Helper: log activity ─────────────────────────────────────────────
|
||||
|
||||
async def _log_activity(
|
||||
db: AsyncSession,
|
||||
entity_type: str,
|
||||
entity_id: str,
|
||||
action: str,
|
||||
details: dict | None = None,
|
||||
):
|
||||
log = ActivityLog(
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
action=action,
|
||||
details=details,
|
||||
created_at=_utcnow(),
|
||||
)
|
||||
db.add(log)
|
||||
|
||||
|
||||
# ── Case CRUD ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("", summary="Create a case")
|
||||
async def create_case(body: CaseCreate, db: AsyncSession = Depends(get_db)):
|
||||
now = _utcnow()
|
||||
case = Case(
|
||||
id=_new_id(),
|
||||
title=body.title,
|
||||
description=body.description,
|
||||
severity=body.severity,
|
||||
tlp=body.tlp,
|
||||
pap=body.pap,
|
||||
priority=body.priority,
|
||||
assignee=body.assignee,
|
||||
tags=body.tags,
|
||||
hunt_id=body.hunt_id,
|
||||
mitre_techniques=body.mitre_techniques,
|
||||
iocs=body.iocs,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(case)
|
||||
await _log_activity(db, "case", case.id, "created", {"title": body.title})
|
||||
await db.commit()
|
||||
await db.refresh(case)
|
||||
return _case_to_dict(case)
|
||||
|
||||
|
||||
@router.get("", summary="List cases")
|
||||
async def list_cases(
|
||||
status: Optional[str] = Query(None),
|
||||
hunt_id: Optional[str] = Query(None),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
q = select(Case).order_by(desc(Case.updated_at))
|
||||
if status:
|
||||
q = q.where(Case.status == status)
|
||||
if hunt_id:
|
||||
q = q.where(Case.hunt_id == hunt_id)
|
||||
q = q.offset(offset).limit(limit)
|
||||
result = await db.execute(q)
|
||||
cases = result.scalars().all()
|
||||
|
||||
count_q = select(func.count(Case.id))
|
||||
if status:
|
||||
count_q = count_q.where(Case.status == status)
|
||||
if hunt_id:
|
||||
count_q = count_q.where(Case.hunt_id == hunt_id)
|
||||
total = (await db.execute(count_q)).scalar() or 0
|
||||
|
||||
return {"cases": [_case_to_dict(c) for c in cases], "total": total}
|
||||
|
||||
|
||||
@router.get("/{case_id}", summary="Get case detail")
|
||||
async def get_case(case_id: str, db: AsyncSession = Depends(get_db)):
|
||||
case = await db.get(Case, case_id)
|
||||
if not case:
|
||||
raise HTTPException(status_code=404, detail="Case not found")
|
||||
return _case_to_dict(case)
|
||||
|
||||
|
||||
@router.put("/{case_id}", summary="Update a case")
|
||||
async def update_case(case_id: str, body: CaseUpdate, db: AsyncSession = Depends(get_db)):
|
||||
case = await db.get(Case, case_id)
|
||||
if not case:
|
||||
raise HTTPException(status_code=404, detail="Case not found")
|
||||
changes = {}
|
||||
for field in ["title", "description", "severity", "tlp", "pap", "status",
|
||||
"priority", "assignee", "tags", "mitre_techniques", "iocs"]:
|
||||
val = getattr(body, field)
|
||||
if val is not None:
|
||||
old = getattr(case, field)
|
||||
setattr(case, field, val)
|
||||
changes[field] = {"old": old, "new": val}
|
||||
if "status" in changes and changes["status"]["new"] == "in-progress" and not case.started_at:
|
||||
case.started_at = _utcnow()
|
||||
if "status" in changes and changes["status"]["new"] in ("resolved", "closed"):
|
||||
case.resolved_at = _utcnow()
|
||||
case.updated_at = _utcnow()
|
||||
await _log_activity(db, "case", case.id, "updated", changes)
|
||||
await db.commit()
|
||||
await db.refresh(case)
|
||||
return _case_to_dict(case)
|
||||
|
||||
|
||||
@router.delete("/{case_id}", summary="Delete a case")
|
||||
async def delete_case(case_id: str, db: AsyncSession = Depends(get_db)):
|
||||
case = await db.get(Case, case_id)
|
||||
if not case:
|
||||
raise HTTPException(status_code=404, detail="Case not found")
|
||||
await db.delete(case)
|
||||
await db.commit()
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
# ── Task CRUD ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/{case_id}/tasks", summary="Add task to case")
|
||||
async def create_task(case_id: str, body: TaskCreate, db: AsyncSession = Depends(get_db)):
|
||||
case = await db.get(Case, case_id)
|
||||
if not case:
|
||||
raise HTTPException(status_code=404, detail="Case not found")
|
||||
now = _utcnow()
|
||||
task = CaseTask(
|
||||
id=_new_id(),
|
||||
case_id=case_id,
|
||||
title=body.title,
|
||||
description=body.description,
|
||||
assignee=body.assignee,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(task)
|
||||
await _log_activity(db, "case", case_id, "task_created", {"title": body.title})
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
return _task_to_dict(task)
|
||||
|
||||
|
||||
@router.put("/{case_id}/tasks/{task_id}", summary="Update a task")
|
||||
async def update_task(case_id: str, task_id: str, body: TaskUpdate, db: AsyncSession = Depends(get_db)):
|
||||
task = await db.get(CaseTask, task_id)
|
||||
if not task or task.case_id != case_id:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
for field in ["title", "description", "status", "assignee", "order"]:
|
||||
val = getattr(body, field)
|
||||
if val is not None:
|
||||
setattr(task, field, val)
|
||||
task.updated_at = _utcnow()
|
||||
await _log_activity(db, "case", case_id, "task_updated", {"task_id": task_id})
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
return _task_to_dict(task)
|
||||
|
||||
|
||||
@router.delete("/{case_id}/tasks/{task_id}", summary="Delete a task")
|
||||
async def delete_task(case_id: str, task_id: str, db: AsyncSession = Depends(get_db)):
|
||||
task = await db.get(CaseTask, task_id)
|
||||
if not task or task.case_id != case_id:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
await db.delete(task)
|
||||
await db.commit()
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
# ── Activity Log ──────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{case_id}/activity", summary="Get case activity log")
|
||||
async def get_activity(
|
||||
case_id: str,
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
q = (
|
||||
select(ActivityLog)
|
||||
.where(ActivityLog.entity_type == "case", ActivityLog.entity_id == case_id)
|
||||
.order_by(desc(ActivityLog.created_at))
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(q)
|
||||
logs = result.scalars().all()
|
||||
return {
|
||||
"logs": [
|
||||
{
|
||||
"id": l.id,
|
||||
"action": l.action,
|
||||
"details": l.details,
|
||||
"user_id": l.user_id,
|
||||
"created_at": l.created_at.isoformat() if l.created_at else None,
|
||||
}
|
||||
for l in logs
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _case_to_dict(c: Case) -> dict:
|
||||
return {
|
||||
"id": c.id,
|
||||
"title": c.title,
|
||||
"description": c.description,
|
||||
"severity": c.severity,
|
||||
"tlp": c.tlp,
|
||||
"pap": c.pap,
|
||||
"status": c.status,
|
||||
"priority": c.priority,
|
||||
"assignee": c.assignee,
|
||||
"tags": c.tags or [],
|
||||
"hunt_id": c.hunt_id,
|
||||
"owner_id": c.owner_id,
|
||||
"mitre_techniques": c.mitre_techniques or [],
|
||||
"iocs": c.iocs or [],
|
||||
"started_at": c.started_at.isoformat() if c.started_at else None,
|
||||
"resolved_at": c.resolved_at.isoformat() if c.resolved_at else None,
|
||||
"created_at": c.created_at.isoformat() if c.created_at else None,
|
||||
"updated_at": c.updated_at.isoformat() if c.updated_at else None,
|
||||
"tasks": [_task_to_dict(t) for t in (c.tasks or [])],
|
||||
}
|
||||
|
||||
|
||||
def _task_to_dict(t: CaseTask) -> dict:
|
||||
return {
|
||||
"id": t.id,
|
||||
"case_id": t.case_id,
|
||||
"title": t.title,
|
||||
"description": t.description,
|
||||
"status": t.status,
|
||||
"assignee": t.assignee,
|
||||
"order": t.order,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
"updated_at": t.updated_at.isoformat() if t.updated_at else None,
|
||||
}
|
||||
@@ -398,3 +398,30 @@ async def delete_dataset(
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
keyword_scan_cache.invalidate_dataset(dataset_id)
|
||||
return {"message": "Dataset deleted", "id": dataset_id}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/rescan-ioc",
|
||||
summary="Re-scan IOC columns for all datasets",
|
||||
)
|
||||
async def rescan_ioc_columns(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Re-run detect_ioc_columns on every dataset using current detection logic."""
|
||||
repo = DatasetRepository(db)
|
||||
all_ds = await repo.list_datasets(limit=10000)
|
||||
updated = 0
|
||||
for ds in all_ds:
|
||||
columns = list((ds.column_schema or {}).keys())
|
||||
if not columns:
|
||||
continue
|
||||
new_ioc = detect_ioc_columns(
|
||||
columns,
|
||||
ds.column_schema or {},
|
||||
ds.normalized_columns or {},
|
||||
)
|
||||
if new_ioc != (ds.ioc_columns or {}):
|
||||
ds.ioc_columns = new_ioc
|
||||
updated += 1
|
||||
await db.commit()
|
||||
return {"message": f"Rescanned {len(all_ds)} datasets, updated {updated}"}
|
||||
|
||||
@@ -1,20 +1,34 @@
|
||||
<<<<<<< HEAD
|
||||
"""Network topology API - host inventory endpoint with background caching."""
|
||||
=======
|
||||
"""API routes for Network Picture — deduplicated host inventory."""
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
<<<<<<< HEAD
|
||||
from fastapi.responses import JSONResponse
|
||||
=======
|
||||
from pydantic import BaseModel, Field
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
<<<<<<< HEAD
|
||||
from app.services.host_inventory import build_host_inventory, inventory_cache
|
||||
from app.services.job_queue import job_queue, JobType
|
||||
=======
|
||||
from app.services.network_inventory import build_network_picture
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/network", tags=["network"])
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
@router.get("/host-inventory")
|
||||
async def get_host_inventory(
|
||||
hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
|
||||
@@ -173,3 +187,58 @@ async def trigger_rebuild(
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return {"job_id": job.id, "status": "queued"}
|
||||
=======
|
||||
# ── Response models ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class HostEntry(BaseModel):
|
||||
hostname: str
|
||||
ips: list[str] = Field(default_factory=list)
|
||||
users: list[str] = Field(default_factory=list)
|
||||
os: list[str] = Field(default_factory=list)
|
||||
mac_addresses: list[str] = Field(default_factory=list)
|
||||
protocols: list[str] = Field(default_factory=list)
|
||||
open_ports: list[str] = Field(default_factory=list)
|
||||
remote_targets: list[str] = Field(default_factory=list)
|
||||
datasets: list[str] = Field(default_factory=list)
|
||||
connection_count: int = 0
|
||||
first_seen: str | None = None
|
||||
last_seen: str | None = None
|
||||
|
||||
|
||||
class PictureSummary(BaseModel):
|
||||
total_hosts: int = 0
|
||||
total_connections: int = 0
|
||||
total_unique_ips: int = 0
|
||||
datasets_scanned: int = 0
|
||||
|
||||
|
||||
class NetworkPictureResponse(BaseModel):
|
||||
hosts: list[HostEntry]
|
||||
summary: PictureSummary
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
|
||||
"/picture",
|
||||
response_model=NetworkPictureResponse,
|
||||
summary="Build deduplicated host inventory for a hunt",
|
||||
description=(
|
||||
"Scans all datasets in the specified hunt, extracts host-identifying "
|
||||
"fields (hostname, IP, username, OS, MAC, ports), deduplicates by "
|
||||
"hostname, and returns a clean one-row-per-host network picture."
|
||||
),
|
||||
)
|
||||
async def get_network_picture(
|
||||
hunt_id: str = Query(..., description="Hunt ID to scan"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Return a deduplicated network picture for a hunt."""
|
||||
if not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="hunt_id is required")
|
||||
|
||||
result = await build_network_picture(db, hunt_id)
|
||||
return result
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
360
backend/app/api/routes/notebooks.py
Normal file
360
backend/app/api/routes/notebooks.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""API routes for investigation notebooks and playbooks."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Notebook, PlaybookRun, _new_id, _utcnow
|
||||
from app.services.playbook import (
|
||||
get_builtin_playbooks,
|
||||
get_playbook_template,
|
||||
validate_notebook_cells,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/notebooks", tags=["notebooks"])
|
||||
|
||||
|
||||
# ── Pydantic models ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class NotebookCreate(BaseModel):
|
||||
title: str
|
||||
description: Optional[str] = None
|
||||
cells: Optional[list[dict]] = None
|
||||
hunt_id: Optional[str] = None
|
||||
case_id: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
|
||||
|
||||
class NotebookUpdate(BaseModel):
|
||||
title: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
cells: Optional[list[dict]] = None
|
||||
tags: Optional[list[str]] = None
|
||||
|
||||
|
||||
class CellUpdate(BaseModel):
|
||||
"""Update a single cell or add a new one."""
|
||||
cell_id: str
|
||||
cell_type: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
output: Optional[str] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
class PlaybookStart(BaseModel):
|
||||
playbook_name: str
|
||||
hunt_id: Optional[str] = None
|
||||
case_id: Optional[str] = None
|
||||
started_by: Optional[str] = None
|
||||
|
||||
|
||||
class StepComplete(BaseModel):
|
||||
notes: Optional[str] = None
|
||||
status: str = "completed" # completed | skipped
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _notebook_to_dict(nb: Notebook) -> dict:
|
||||
return {
|
||||
"id": nb.id,
|
||||
"title": nb.title,
|
||||
"description": nb.description,
|
||||
"cells": nb.cells or [],
|
||||
"hunt_id": nb.hunt_id,
|
||||
"case_id": nb.case_id,
|
||||
"owner_id": nb.owner_id,
|
||||
"tags": nb.tags or [],
|
||||
"cell_count": len(nb.cells or []),
|
||||
"created_at": nb.created_at.isoformat() if nb.created_at else None,
|
||||
"updated_at": nb.updated_at.isoformat() if nb.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _run_to_dict(run: PlaybookRun) -> dict:
|
||||
return {
|
||||
"id": run.id,
|
||||
"playbook_name": run.playbook_name,
|
||||
"status": run.status,
|
||||
"current_step": run.current_step,
|
||||
"total_steps": run.total_steps,
|
||||
"step_results": run.step_results or [],
|
||||
"hunt_id": run.hunt_id,
|
||||
"case_id": run.case_id,
|
||||
"started_by": run.started_by,
|
||||
"created_at": run.created_at.isoformat() if run.created_at else None,
|
||||
"updated_at": run.updated_at.isoformat() if run.updated_at else None,
|
||||
"completed_at": run.completed_at.isoformat() if run.completed_at else None,
|
||||
}
|
||||
|
||||
|
||||
# ── Notebook CRUD ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("", summary="List notebooks")
|
||||
async def list_notebooks(
|
||||
hunt_id: str | None = Query(None),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
stmt = select(Notebook)
|
||||
count_stmt = select(func.count(Notebook.id))
|
||||
if hunt_id:
|
||||
stmt = stmt.where(Notebook.hunt_id == hunt_id)
|
||||
count_stmt = count_stmt.where(Notebook.hunt_id == hunt_id)
|
||||
|
||||
total = (await db.execute(count_stmt)).scalar() or 0
|
||||
results = (await db.execute(
|
||||
stmt.order_by(desc(Notebook.updated_at)).offset(offset).limit(limit)
|
||||
)).scalars().all()
|
||||
|
||||
return {"notebooks": [_notebook_to_dict(n) for n in results], "total": total}
|
||||
|
||||
|
||||
@router.get("/{notebook_id}", summary="Get notebook")
|
||||
async def get_notebook(notebook_id: str, db: AsyncSession = Depends(get_db)):
|
||||
nb = await db.get(Notebook, notebook_id)
|
||||
if not nb:
|
||||
raise HTTPException(status_code=404, detail="Notebook not found")
|
||||
return _notebook_to_dict(nb)
|
||||
|
||||
|
||||
@router.post("", summary="Create notebook")
|
||||
async def create_notebook(body: NotebookCreate, db: AsyncSession = Depends(get_db)):
|
||||
cells = validate_notebook_cells(body.cells or [])
|
||||
if not cells:
|
||||
# Start with a default markdown cell
|
||||
cells = [{"id": "cell-0", "cell_type": "markdown", "source": "# Investigation Notes\n\nStart documenting your findings here.", "output": None, "metadata": {}}]
|
||||
|
||||
nb = Notebook(
|
||||
id=_new_id(),
|
||||
title=body.title,
|
||||
description=body.description,
|
||||
cells=cells,
|
||||
hunt_id=body.hunt_id,
|
||||
case_id=body.case_id,
|
||||
tags=body.tags,
|
||||
)
|
||||
db.add(nb)
|
||||
await db.commit()
|
||||
await db.refresh(nb)
|
||||
return _notebook_to_dict(nb)
|
||||
|
||||
|
||||
@router.put("/{notebook_id}", summary="Update notebook")
|
||||
async def update_notebook(
|
||||
notebook_id: str, body: NotebookUpdate, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
nb = await db.get(Notebook, notebook_id)
|
||||
if not nb:
|
||||
raise HTTPException(status_code=404, detail="Notebook not found")
|
||||
|
||||
if body.title is not None:
|
||||
nb.title = body.title
|
||||
if body.description is not None:
|
||||
nb.description = body.description
|
||||
if body.cells is not None:
|
||||
nb.cells = validate_notebook_cells(body.cells)
|
||||
if body.tags is not None:
|
||||
nb.tags = body.tags
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(nb)
|
||||
return _notebook_to_dict(nb)
|
||||
|
||||
|
||||
@router.post("/{notebook_id}/cells", summary="Add or update a cell")
|
||||
async def upsert_cell(
|
||||
notebook_id: str, body: CellUpdate, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
nb = await db.get(Notebook, notebook_id)
|
||||
if not nb:
|
||||
raise HTTPException(status_code=404, detail="Notebook not found")
|
||||
|
||||
cells = list(nb.cells or [])
|
||||
found = False
|
||||
for i, c in enumerate(cells):
|
||||
if c.get("id") == body.cell_id:
|
||||
if body.cell_type is not None:
|
||||
cells[i]["cell_type"] = body.cell_type
|
||||
if body.source is not None:
|
||||
cells[i]["source"] = body.source
|
||||
if body.output is not None:
|
||||
cells[i]["output"] = body.output
|
||||
if body.metadata is not None:
|
||||
cells[i]["metadata"] = body.metadata
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
cells.append({
|
||||
"id": body.cell_id,
|
||||
"cell_type": body.cell_type or "markdown",
|
||||
"source": body.source or "",
|
||||
"output": body.output,
|
||||
"metadata": body.metadata or {},
|
||||
})
|
||||
|
||||
nb.cells = cells
|
||||
await db.commit()
|
||||
await db.refresh(nb)
|
||||
return _notebook_to_dict(nb)
|
||||
|
||||
|
||||
@router.delete("/{notebook_id}/cells/{cell_id}", summary="Delete a cell")
|
||||
async def delete_cell(
|
||||
notebook_id: str, cell_id: str, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
nb = await db.get(Notebook, notebook_id)
|
||||
if not nb:
|
||||
raise HTTPException(status_code=404, detail="Notebook not found")
|
||||
|
||||
cells = [c for c in (nb.cells or []) if c.get("id") != cell_id]
|
||||
nb.cells = cells
|
||||
await db.commit()
|
||||
return {"ok": True, "remaining_cells": len(cells)}
|
||||
|
||||
|
||||
@router.delete("/{notebook_id}", summary="Delete notebook")
|
||||
async def delete_notebook(notebook_id: str, db: AsyncSession = Depends(get_db)):
|
||||
nb = await db.get(Notebook, notebook_id)
|
||||
if not nb:
|
||||
raise HTTPException(status_code=404, detail="Notebook not found")
|
||||
await db.delete(nb)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── Playbooks ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/playbooks/templates", summary="List built-in playbook templates")
|
||||
async def list_playbook_templates():
|
||||
templates = get_builtin_playbooks()
|
||||
return {
|
||||
"templates": [
|
||||
{
|
||||
"name": t["name"],
|
||||
"description": t["description"],
|
||||
"category": t["category"],
|
||||
"tags": t["tags"],
|
||||
"step_count": len(t["steps"]),
|
||||
}
|
||||
for t in templates
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.get("/playbooks/templates/{name}", summary="Get playbook template detail")
|
||||
async def get_playbook_template_detail(name: str):
|
||||
template = get_playbook_template(name)
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Playbook template not found")
|
||||
return template
|
||||
|
||||
|
||||
@router.post("/playbooks/start", summary="Start a playbook run")
|
||||
async def start_playbook(body: PlaybookStart, db: AsyncSession = Depends(get_db)):
|
||||
template = get_playbook_template(body.playbook_name)
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="Playbook template not found")
|
||||
|
||||
run = PlaybookRun(
|
||||
id=_new_id(),
|
||||
playbook_name=body.playbook_name,
|
||||
status="in-progress",
|
||||
current_step=1,
|
||||
total_steps=len(template["steps"]),
|
||||
step_results=[],
|
||||
hunt_id=body.hunt_id,
|
||||
case_id=body.case_id,
|
||||
started_by=body.started_by,
|
||||
)
|
||||
db.add(run)
|
||||
await db.commit()
|
||||
await db.refresh(run)
|
||||
return _run_to_dict(run)
|
||||
|
||||
|
||||
@router.get("/playbooks/runs", summary="List playbook runs")
|
||||
async def list_playbook_runs(
|
||||
status: str | None = Query(None),
|
||||
hunt_id: str | None = Query(None),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
stmt = select(PlaybookRun)
|
||||
if status:
|
||||
stmt = stmt.where(PlaybookRun.status == status)
|
||||
if hunt_id:
|
||||
stmt = stmt.where(PlaybookRun.hunt_id == hunt_id)
|
||||
|
||||
results = (await db.execute(
|
||||
stmt.order_by(desc(PlaybookRun.created_at)).limit(limit)
|
||||
)).scalars().all()
|
||||
|
||||
return {"runs": [_run_to_dict(r) for r in results]}
|
||||
|
||||
|
||||
@router.get("/playbooks/runs/{run_id}", summary="Get playbook run detail")
|
||||
async def get_playbook_run(run_id: str, db: AsyncSession = Depends(get_db)):
|
||||
run = await db.get(PlaybookRun, run_id)
|
||||
if not run:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
# Also include the template steps
|
||||
template = get_playbook_template(run.playbook_name)
|
||||
result = _run_to_dict(run)
|
||||
result["steps"] = template["steps"] if template else []
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/playbooks/runs/{run_id}/complete-step", summary="Complete current playbook step")
|
||||
async def complete_step(
|
||||
run_id: str, body: StepComplete, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
run = await db.get(PlaybookRun, run_id)
|
||||
if not run:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
if run.status != "in-progress":
|
||||
raise HTTPException(status_code=400, detail="Run is not in progress")
|
||||
|
||||
step_results = list(run.step_results or [])
|
||||
step_results.append({
|
||||
"step": run.current_step,
|
||||
"status": body.status,
|
||||
"notes": body.notes,
|
||||
"completed_at": _utcnow().isoformat(),
|
||||
})
|
||||
run.step_results = step_results
|
||||
|
||||
if run.current_step >= run.total_steps:
|
||||
run.status = "completed"
|
||||
run.completed_at = _utcnow()
|
||||
else:
|
||||
run.current_step += 1
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(run)
|
||||
return _run_to_dict(run)
|
||||
|
||||
|
||||
@router.post("/playbooks/runs/{run_id}/abort", summary="Abort a playbook run")
|
||||
async def abort_run(run_id: str, db: AsyncSession = Depends(get_db)):
|
||||
run = await db.get(PlaybookRun, run_id)
|
||||
if not run:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
run.status = "aborted"
|
||||
run.completed_at = _utcnow()
|
||||
await db.commit()
|
||||
return _run_to_dict(run)
|
||||
Reference in New Issue
Block a user