mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 14:00:20 -05:00
feat: interactive network map, IOC highlighting, AUP hunt selector, type filters
- NetworkMap: hunt-scoped force-directed graph with click-to-inspect popover - NetworkMap: zoom/pan (wheel, drag, buttons), viewport transform - NetworkMap: clickable IP/Host/Domain/URL legend chips to filter node types - NetworkMap: brighter colors, 20% smaller nodes - DatasetViewer: IOC columns highlighted with colored headers + cell tinting - AUPScanner: hunt dropdown replacing dataset checkboxes, auto-select all - Rename 'Social Media (Personal)' theme to 'Social Media' with DB migration - Fix /api/hunts timeout: Dataset.rows lazy='noload' (was selectin cascade) - Add OS column mapping to normalizer - Full backend services, DB models, alembic migrations, new routes - New components: Dashboard, HuntManager, FileUpload, NetworkMap, etc. - Docker Compose deployment with nginx reverse proxy
This commit is contained in:
265
backend/app/api/routes/agent_v2.py
Normal file
265
backend/app/api/routes/agent_v2.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""API routes for analyst-assist agent — v2.
|
||||
|
||||
Supports quick, deep, and debate modes with streaming.
|
||||
Conversations are persisted to the database.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
from app.db.models import Conversation, Message
|
||||
from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
|
||||
from app.agents.providers_v2 import check_all_nodes
|
||||
from app.agents.registry import registry
|
||||
from app.services.sans_rag import sans_rag
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
||||
|
||||
# Global agent instance
|
||||
_agent: ThreatHuntAgent | None = None
|
||||
|
||||
|
||||
def get_agent() -> ThreatHuntAgent:
    """Return the process-wide ThreatHuntAgent, creating it on first use."""
    global _agent
    agent = _agent
    if agent is None:
        agent = ThreatHuntAgent()
        _agent = agent
    return agent
|
||||
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
|
||||
|
||||
class AssistRequest(BaseModel):
    """Payload for the /assist and /assist/stream endpoints.

    The optional context fields are passed through into AgentContext; the
    query itself is the only required field.
    """

    query: str = Field(..., max_length=4000, description="Analyst question")
    # Optional hunt context forwarded verbatim to the agent.
    dataset_name: str | None = None
    artifact_type: str | None = None
    host_identifier: str | None = None
    data_summary: str | None = None
    conversation_history: list[dict] | None = None
    active_hypotheses: list[str] | None = None
    annotations_summary: str | None = None
    enrichment_summary: str | None = None
    mode: str = Field(default="quick", description="quick | deep | debate")
    # Explicit model choice; when None the agent auto-routes.
    model_override: str | None = None
    # When set (or when hunt_id is set), the exchange is persisted.
    conversation_id: str | None = Field(None, description="Persist messages to this conversation")
    hunt_id: str | None = None
|
||||
|
||||
|
||||
class AssistResponseModel(BaseModel):
    """Serialized agent answer returned by /assist.

    Mirrors AgentResponse; `perspectives` is populated only when the
    underlying response carries per-model perspectives (presumably debate
    mode — confirm in core_v2).
    """

    guidance: str
    confidence: float
    suggested_pivots: list[str]
    suggested_filters: list[str]
    caveats: str | None = None
    reasoning: str | None = None
    sans_references: list[str] = []
    model_used: str = ""
    node_used: str = ""
    latency_ms: int = 0
    # One dict per contributing model: role/content/model_used/node_used/latency_ms.
    perspectives: list[dict] | None = None
    # Id of the (possibly newly created) conversation the exchange was saved to.
    conversation_id: str | None = None
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/assist",
    response_model=AssistResponseModel,
    summary="Get analyst-assist guidance",
    description="Request guidance with auto-routed model selection. "
    "Supports quick (fast), deep (70B), and debate (multi-model) modes.",
)
async def agent_assist(
    request: AssistRequest,
    db: AsyncSession = Depends(get_db),
) -> AssistResponseModel:
    """Run the agent for one analyst query and optionally persist the exchange.

    The exchange is saved when the caller supplies a conversation_id or a
    hunt_id; otherwise nothing is stored.

    Fixes vs. previous version: HTTPExceptions are re-raised unchanged
    instead of being re-wrapped as opaque 500s, the original exception is
    chained (`from e`), and logging uses lazy %-args.
    """
    try:
        agent = get_agent()
        context = AgentContext(
            query=request.query,
            dataset_name=request.dataset_name,
            artifact_type=request.artifact_type,
            host_identifier=request.host_identifier,
            data_summary=request.data_summary,
            conversation_history=request.conversation_history or [],
            active_hypotheses=request.active_hypotheses or [],
            annotations_summary=request.annotations_summary,
            enrichment_summary=request.enrichment_summary,
            mode=request.mode,
            model_override=request.model_override,
        )

        response = await agent.assist(context)

        # Persist conversation only when the caller asked for it.
        conv_id = request.conversation_id
        if conv_id or request.hunt_id:
            conv_id = await _persist_conversation(
                db, conv_id, request, response
            )

        return AssistResponseModel(
            guidance=response.guidance,
            confidence=response.confidence,
            suggested_pivots=response.suggested_pivots,
            suggested_filters=response.suggested_filters,
            caveats=response.caveats,
            reasoning=response.reasoning,
            sans_references=response.sans_references,
            model_used=response.model_used,
            node_used=response.node_used,
            latency_ms=response.latency_ms,
            perspectives=[
                {
                    "role": p.role,
                    "content": p.content,
                    "model_used": p.model_used,
                    "node_used": p.node_used,
                    "latency_ms": p.latency_ms,
                }
                for p in response.perspectives
            ] if response.perspectives else None,
            conversation_id=conv_id,
        )

    except HTTPException:
        # Deliberate HTTP errors must not be converted into 500s.
        raise
    except Exception as e:
        logger.exception("Agent error: %s", e)
        raise HTTPException(status_code=500, detail=f"Agent error: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post(
    "/assist/stream",
    summary="Stream agent response",
    description="Stream tokens via SSE for real-time display.",
)
async def agent_assist_stream(request: AssistRequest):
    """Stream agent tokens as Server-Sent Events (quick mode only).

    Fix vs. previous version: an exception raised mid-stream is now reported
    to the client as an SSE `error` payload, and the stream always terminates
    with the `[DONE]` sentinel instead of dropping the connection silently.
    """
    agent = get_agent()
    context = AgentContext(
        query=request.query,
        dataset_name=request.dataset_name,
        artifact_type=request.artifact_type,
        host_identifier=request.host_identifier,
        data_summary=request.data_summary,
        conversation_history=request.conversation_history or [],
        mode="quick",  # streaming only supports quick mode
    )

    async def _stream():
        try:
            async for token in agent.assist_stream(context):
                yield f"data: {json.dumps({'token': token})}\n\n"
        except Exception as e:
            logger.exception("Streaming agent error: %s", e)
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        _stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable proxy (nginx) buffering so tokens flush immediately.
            "X-Accel-Buffering": "no",
        },
    )
|
||||
|
||||
|
||||
@router.get(
    "/health",
    summary="Check agent and node health",
    description="Returns availability of all LLM nodes and the cluster.",
)
async def agent_health() -> dict:
    """Report node availability, RAG health, and configured defaults."""
    node_status = await check_all_nodes()
    rag_status = await sans_rag.health_check()

    default_models = {
        "fast": settings.DEFAULT_FAST_MODEL,
        "heavy": settings.DEFAULT_HEAVY_MODEL,
        "code": settings.DEFAULT_CODE_MODEL,
        "vision": settings.DEFAULT_VISION_MODEL,
        "embedding": settings.DEFAULT_EMBEDDING_MODEL,
    }
    agent_config = {
        "max_tokens": settings.AGENT_MAX_TOKENS,
        "temperature": settings.AGENT_TEMPERATURE,
    }

    return {
        "status": "healthy",
        "nodes": node_status,
        "rag": rag_status,
        "default_models": default_models,
        "config": agent_config,
    }
|
||||
|
||||
|
||||
@router.get(
    "/models",
    summary="List all available models",
    description="Returns the full model registry with capabilities and node assignments.",
)
async def list_models():
    """Expose the model registry and its size."""
    model_map = registry.to_dict()
    return {"models": model_map, "total": len(registry.models)}
|
||||
|
||||
|
||||
# ── Conversation persistence ──────────────────────────────────────────
|
||||
|
||||
|
||||
async def _persist_conversation(
    db: AsyncSession,
    conversation_id: str | None,
    request: AssistRequest,
    response: AgentResponse,
) -> str:
    """Save user message and agent response to the database.

    Reuses the conversation identified by ``conversation_id`` when it exists,
    recreates it under that id when it does not, and otherwise starts a new
    conversation titled after the query. Returns the conversation id.
    """
    from sqlalchemy import select

    if conversation_id:
        # Look up the existing conversation; recreate it if it is missing.
        lookup = select(Conversation).where(Conversation.id == conversation_id)
        conv = (await db.execute(lookup)).scalar_one_or_none()
        if conv is None:
            conv = Conversation(id=conversation_id, hunt_id=request.hunt_id)
            db.add(conv)
    else:
        conv = Conversation(
            title=request.query[:100],
            hunt_id=request.hunt_id,
        )
        db.add(conv)
    await db.flush()

    # Analyst's question.
    db.add(
        Message(
            conversation_id=conv.id,
            role="user",
            content=request.query,
        )
    )

    # Agent's answer, with routing/latency metadata for later inspection.
    db.add(
        Message(
            conversation_id=conv.id,
            role="agent",
            content=response.guidance,
            model_used=response.model_used,
            node_used=response.node_used,
            latency_ms=response.latency_ms,
            response_meta={
                "confidence": response.confidence,
                "pivots": response.suggested_pivots,
                "filters": response.suggested_filters,
                "sans_refs": response.sans_references,
            },
        )
    )
    await db.flush()

    return conv.id
|
||||
311
backend/app/api/routes/annotations.py
Normal file
311
backend/app/api/routes/annotations.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""API routes for annotations and hypotheses."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Annotation, Hypothesis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["annotations"])
|
||||
|
||||
|
||||
# ── Annotation models ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class AnnotationCreate(BaseModel):
    """Request body for POST /api/annotations."""

    row_id: int | None = None  # dataset row the note attaches to, if any
    dataset_id: str | None = None
    text: str = Field(..., max_length=2000)
    severity: str = Field(default="info")  # info|low|medium|high|critical
    tag: str | None = None  # suspicious|benign|needs-review
    highlight_color: str | None = None
|
||||
|
||||
|
||||
class AnnotationUpdate(BaseModel):
    """Partial-update payload; fields left as None are not modified."""

    text: str | None = None
    severity: str | None = None
    tag: str | None = None
    highlight_color: str | None = None
|
||||
|
||||
|
||||
class AnnotationResponse(BaseModel):
    """Serialized annotation; timestamps are ISO-8601 strings."""

    id: str
    row_id: int | None
    dataset_id: str | None
    author_id: str | None
    text: str
    severity: str
    tag: str | None
    highlight_color: str | None
    created_at: str
    updated_at: str
|
||||
|
||||
|
||||
class AnnotationListResponse(BaseModel):
    """Paged annotation listing; `total` is the overall count, not the page size."""

    annotations: list[AnnotationResponse]
    total: int
|
||||
|
||||
|
||||
# ── Hypothesis models ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class HypothesisCreate(BaseModel):
    """Request body for POST /api/hypotheses."""

    hunt_id: str | None = None
    title: str = Field(..., max_length=256)
    description: str | None = None
    mitre_technique: str | None = None
    status: str = Field(default="draft")
|
||||
|
||||
|
||||
class HypothesisUpdate(BaseModel):
    """Partial-update payload; fields left as None are not modified."""

    title: str | None = None
    description: str | None = None
    mitre_technique: str | None = None
    status: str | None = None  # draft|active|confirmed|rejected
    evidence_row_ids: list[int] | None = None
    evidence_notes: str | None = None
|
||||
|
||||
|
||||
class HypothesisResponse(BaseModel):
    """Serialized hypothesis; timestamps are ISO-8601 strings."""

    id: str
    hunt_id: str | None
    title: str
    description: str | None
    mitre_technique: str | None
    status: str
    evidence_row_ids: list | None
    evidence_notes: str | None
    created_at: str
    updated_at: str
|
||||
|
||||
|
||||
class HypothesisListResponse(BaseModel):
    """Paged hypothesis listing; `total` is the overall count, not the page size."""

    hypotheses: list[HypothesisResponse]
    total: int
|
||||
|
||||
|
||||
# ── Annotation routes ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
ann_router = APIRouter(prefix="/api/annotations")
|
||||
|
||||
|
||||
@ann_router.post("", response_model=AnnotationResponse, summary="Create annotation")
async def create_annotation(body: AnnotationCreate, db: AsyncSession = Depends(get_db)):
    """Persist a new annotation and return its stored representation."""
    record = Annotation(
        row_id=body.row_id,
        dataset_id=body.dataset_id,
        text=body.text,
        severity=body.severity,
        tag=body.tag,
        highlight_color=body.highlight_color,
    )
    db.add(record)
    await db.flush()  # populate id and server-side timestamps
    return AnnotationResponse(
        id=record.id,
        row_id=record.row_id,
        dataset_id=record.dataset_id,
        author_id=record.author_id,
        text=record.text,
        severity=record.severity,
        tag=record.tag,
        highlight_color=record.highlight_color,
        created_at=record.created_at.isoformat(),
        updated_at=record.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@ann_router.get("", response_model=AnnotationListResponse, summary="List annotations")
async def list_annotations(
    dataset_id: str | None = Query(None),
    row_id: int | None = Query(None),
    tag: str | None = Query(None),
    severity: str | None = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List annotations, newest first, with optional filters and paging.

    Fixes vs. previous version: the `total` count now applies *all* active
    filters (it previously honored only dataset_id, so the total disagreed
    with the filtered page), and row_id=0 is treated as a real filter value.
    """
    conditions = []
    if dataset_id:
        conditions.append(Annotation.dataset_id == dataset_id)
    if row_id is not None:  # 0 is a valid row id
        conditions.append(Annotation.row_id == row_id)
    if tag:
        conditions.append(Annotation.tag == tag)
    if severity:
        conditions.append(Annotation.severity == severity)

    stmt = (
        select(Annotation)
        .where(*conditions)
        .order_by(Annotation.created_at.desc())
        .limit(limit)
        .offset(offset)
    )
    annotations = (await db.execute(stmt)).scalars().all()

    # Count with the same filters so `total` matches the filtered result set.
    count_stmt = select(func.count(Annotation.id)).where(*conditions)
    total = (await db.execute(count_stmt)).scalar_one()

    return AnnotationListResponse(
        annotations=[
            AnnotationResponse(
                id=a.id, row_id=a.row_id, dataset_id=a.dataset_id,
                author_id=a.author_id, text=a.text, severity=a.severity,
                tag=a.tag, highlight_color=a.highlight_color,
                created_at=a.created_at.isoformat(), updated_at=a.updated_at.isoformat(),
            )
            for a in annotations
        ],
        total=total,
    )
|
||||
|
||||
|
||||
@ann_router.put("/{annotation_id}", response_model=AnnotationResponse, summary="Update annotation")
async def update_annotation(
    annotation_id: str, body: AnnotationUpdate, db: AsyncSession = Depends(get_db)
):
    """Partially update an annotation; None fields are left unchanged."""
    lookup = select(Annotation).where(Annotation.id == annotation_id)
    ann = (await db.execute(lookup)).scalar_one_or_none()
    if ann is None:
        raise HTTPException(status_code=404, detail="Annotation not found")

    # Copy over only the fields the caller actually supplied.
    for field_name in ("text", "severity", "tag", "highlight_color"):
        value = getattr(body, field_name)
        if value is not None:
            setattr(ann, field_name, value)

    await db.flush()
    return AnnotationResponse(
        id=ann.id, row_id=ann.row_id, dataset_id=ann.dataset_id,
        author_id=ann.author_id, text=ann.text, severity=ann.severity,
        tag=ann.tag, highlight_color=ann.highlight_color,
        created_at=ann.created_at.isoformat(), updated_at=ann.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@ann_router.delete("/{annotation_id}", summary="Delete annotation")
async def delete_annotation(annotation_id: str, db: AsyncSession = Depends(get_db)):
    """Delete an annotation by id; 404 when it does not exist."""
    lookup = select(Annotation).where(Annotation.id == annotation_id)
    ann = (await db.execute(lookup)).scalar_one_or_none()
    if ann is None:
        raise HTTPException(status_code=404, detail="Annotation not found")
    await db.delete(ann)
    return {"message": "Annotation deleted", "id": annotation_id}
|
||||
|
||||
|
||||
# ── Hypothesis routes ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
hyp_router = APIRouter(prefix="/api/hypotheses")
|
||||
|
||||
|
||||
@hyp_router.post("", response_model=HypothesisResponse, summary="Create hypothesis")
async def create_hypothesis(body: HypothesisCreate, db: AsyncSession = Depends(get_db)):
    """Persist a new hypothesis and return its stored representation."""
    record = Hypothesis(
        hunt_id=body.hunt_id,
        title=body.title,
        description=body.description,
        mitre_technique=body.mitre_technique,
        status=body.status,
    )
    db.add(record)
    await db.flush()  # populate id and server-side timestamps
    return HypothesisResponse(
        id=record.id, hunt_id=record.hunt_id, title=record.title,
        description=record.description, mitre_technique=record.mitre_technique,
        status=record.status, evidence_row_ids=record.evidence_row_ids,
        evidence_notes=record.evidence_notes,
        created_at=record.created_at.isoformat(), updated_at=record.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@hyp_router.get("", response_model=HypothesisListResponse, summary="List hypotheses")
async def list_hypotheses(
    hunt_id: str | None = Query(None),
    status: str | None = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List hypotheses, most recently updated first, with optional filters.

    Fix vs. previous version: the `total` count now honors *both* filters;
    previously only hunt_id was applied to the count, so `total` disagreed
    with the page contents when filtering by status.
    """
    conditions = []
    if hunt_id:
        conditions.append(Hypothesis.hunt_id == hunt_id)
    if status:
        conditions.append(Hypothesis.status == status)

    stmt = (
        select(Hypothesis)
        .where(*conditions)
        .order_by(Hypothesis.updated_at.desc())
        .limit(limit)
        .offset(offset)
    )
    hyps = (await db.execute(stmt)).scalars().all()

    # Same filters on the count so `total` matches the filtered result set.
    count_stmt = select(func.count(Hypothesis.id)).where(*conditions)
    total = (await db.execute(count_stmt)).scalar_one()

    return HypothesisListResponse(
        hypotheses=[
            HypothesisResponse(
                id=h.id, hunt_id=h.hunt_id, title=h.title,
                description=h.description, mitre_technique=h.mitre_technique,
                status=h.status, evidence_row_ids=h.evidence_row_ids,
                evidence_notes=h.evidence_notes,
                created_at=h.created_at.isoformat(), updated_at=h.updated_at.isoformat(),
            )
            for h in hyps
        ],
        total=total,
    )
|
||||
|
||||
|
||||
@hyp_router.get("/{hypothesis_id}", response_model=HypothesisResponse, summary="Get hypothesis")
async def get_hypothesis(hypothesis_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single hypothesis by id; 404 when it does not exist."""
    lookup = select(Hypothesis).where(Hypothesis.id == hypothesis_id)
    hyp = (await db.execute(lookup)).scalar_one_or_none()
    if hyp is None:
        raise HTTPException(status_code=404, detail="Hypothesis not found")
    return HypothesisResponse(
        id=hyp.id, hunt_id=hyp.hunt_id, title=hyp.title,
        description=hyp.description, mitre_technique=hyp.mitre_technique,
        status=hyp.status, evidence_row_ids=hyp.evidence_row_ids,
        evidence_notes=hyp.evidence_notes,
        created_at=hyp.created_at.isoformat(), updated_at=hyp.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@hyp_router.put("/{hypothesis_id}", response_model=HypothesisResponse, summary="Update hypothesis")
async def update_hypothesis(
    hypothesis_id: str, body: HypothesisUpdate, db: AsyncSession = Depends(get_db)
):
    """Partially update a hypothesis; None fields are left unchanged."""
    lookup = select(Hypothesis).where(Hypothesis.id == hypothesis_id)
    hyp = (await db.execute(lookup)).scalar_one_or_none()
    if hyp is None:
        raise HTTPException(status_code=404, detail="Hypothesis not found")

    # Copy over only the fields the caller actually supplied.
    updatable = (
        "title",
        "description",
        "mitre_technique",
        "status",
        "evidence_row_ids",
        "evidence_notes",
    )
    for field_name in updatable:
        value = getattr(body, field_name)
        if value is not None:
            setattr(hyp, field_name, value)

    await db.flush()
    return HypothesisResponse(
        id=hyp.id, hunt_id=hyp.hunt_id, title=hyp.title,
        description=hyp.description, mitre_technique=hyp.mitre_technique,
        status=hyp.status, evidence_row_ids=hyp.evidence_row_ids,
        evidence_notes=hyp.evidence_notes,
        created_at=hyp.created_at.isoformat(), updated_at=hyp.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@hyp_router.delete("/{hypothesis_id}", summary="Delete hypothesis")
async def delete_hypothesis(hypothesis_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a hypothesis by id; 404 when it does not exist."""
    lookup = select(Hypothesis).where(Hypothesis.id == hypothesis_id)
    hyp = (await db.execute(lookup)).scalar_one_or_none()
    if hyp is None:
        raise HTTPException(status_code=404, detail="Hypothesis not found")
    await db.delete(hyp)
    return {"message": "Hypothesis deleted", "id": hypothesis_id}
|
||||
197
backend/app/api/routes/auth.py
Normal file
197
backend/app/api/routes/auth.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""API routes for authentication — register, login, refresh, profile."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field, EmailStr
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import User
|
||||
from app.services.auth import (
|
||||
hash_password,
|
||||
verify_password,
|
||||
create_token_pair,
|
||||
decode_token,
|
||||
get_current_user,
|
||||
TokenPair,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
    """Request body for POST /api/auth/register."""

    username: str = Field(..., min_length=3, max_length=64)
    # NOTE(review): plain str although EmailStr is imported at file top —
    # confirm whether strict email validation was intended.
    email: str = Field(..., max_length=256)
    password: str = Field(..., min_length=8, max_length=128)
    display_name: str | None = None  # falls back to username when omitted
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
    """Request body for POST /api/auth/login."""

    username: str
    password: str
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
    """Request body for POST /api/auth/refresh."""

    refresh_token: str
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
    """Public view of a user account; created_at is an ISO-8601 string."""

    id: str
    username: str
    email: str
    display_name: str | None
    role: str
    is_active: bool
    created_at: str
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
    """Profile plus token pair, returned by register and login."""

    user: UserResponse
    tokens: TokenPair
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/register",
    response_model=AuthResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Register a new user",
)
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
    """Create a user account and return the profile with a fresh token pair.

    Improvements vs. previous version: uniqueness is checked in a single
    query instead of two round trips (username conflicts still take
    precedence over email conflicts), and logging uses lazy %-args.

    NOTE(review): this check-then-insert is still racy without DB-level
    unique constraints on username/email — confirm the User model declares
    them.
    """
    clash_stmt = select(User).where(
        (User.username == body.username) | (User.email == body.email)
    )
    clashes = (await db.execute(clash_stmt)).scalars().all()

    if any(u.username == body.username for u in clashes):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Username already taken",
        )
    if any(u.email == body.email for u in clashes):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Email already registered",
        )

    user = User(
        username=body.username,
        email=body.email,
        password_hash=hash_password(body.password),
        display_name=body.display_name or body.username,
        role="analyst",  # Default role
    )
    db.add(user)
    await db.flush()

    tokens = create_token_pair(user.id, user.role)

    logger.info("New user registered: %s (%s)", user.username, user.id)

    return AuthResponse(
        user=UserResponse(
            id=user.id,
            username=user.username,
            email=user.email,
            display_name=user.display_name,
            role=user.role,
            is_active=user.is_active,
            created_at=user.created_at.isoformat(),
        ),
        tokens=tokens,
    )
|
||||
|
||||
|
||||
@router.post(
    "/login",
    response_model=AuthResponse,
    summary="Login with username and password",
)
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
    """Authenticate by username/password and issue a token pair."""
    user = (
        await db.execute(select(User).where(User.username == body.username))
    ).scalar_one_or_none()

    # Same message for unknown user and wrong password (no user enumeration).
    if user is None or not user.password_hash:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid username or password",
        )
    if not verify_password(body.password, user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid username or password",
        )
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Account is disabled",
        )

    tokens = create_token_pair(user.id, user.role)
    profile = UserResponse(
        id=user.id,
        username=user.username,
        email=user.email,
        display_name=user.display_name,
        role=user.role,
        is_active=user.is_active,
        created_at=user.created_at.isoformat(),
    )
    return AuthResponse(user=profile, tokens=tokens)
|
||||
|
||||
|
||||
@router.post(
    "/refresh",
    response_model=TokenPair,
    summary="Refresh access token",
)
async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
    """Exchange a valid refresh token for a new token pair."""
    token_data = decode_token(body.refresh_token)

    # Reject access tokens presented as refresh tokens.
    if token_data.type != "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type — use refresh token",
        )

    user = (
        await db.execute(select(User).where(User.id == token_data.sub))
    ).scalar_one_or_none()
    if user is None or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid user",
        )

    return create_token_pair(user.id, user.role)
|
||||
|
||||
|
||||
@router.get(
    "/me",
    response_model=UserResponse,
    summary="Get current user profile",
)
async def get_profile(user: User = Depends(get_current_user)):
    """Return the authenticated user's own profile."""
    created = user.created_at
    # Defensive: created_at may already be a plain string depending on how
    # the row was loaded.
    created_str = created.isoformat() if hasattr(created, 'isoformat') else str(created)
    return UserResponse(
        id=user.id,
        username=user.username,
        email=user.email,
        display_name=user.display_name,
        role=user.role,
        is_active=user.is_active,
        created_at=created_str,
    )
|
||||
83
backend/app/api/routes/correlation.py
Normal file
83
backend/app/api/routes/correlation.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""API routes for cross-hunt correlation analysis."""
|
||||
|
||||
import logging
|
||||
from dataclasses import asdict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.services.correlation import correlation_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/correlation", tags=["correlation"])
|
||||
|
||||
|
||||
class CorrelateRequest(BaseModel):
    """Request body for POST /api/correlation/analyze."""

    # Between 2 and 20 hunts — correlating a single hunt is meaningless.
    hunt_ids: list[str] = Field(
        ...,
        min_length=2,
        max_length=20,
        description="List of hunt IDs to correlate",
    )
|
||||
|
||||
|
||||
@router.post(
    "/analyze",
    summary="Run correlation analysis across hunts",
    description="Find shared IOCs, overlapping time windows, common MITRE techniques, "
    "and host patterns across the specified hunts.",
)
async def correlate_hunts(
    body: CorrelateRequest,
    db: AsyncSession = Depends(get_db),
):
    """Correlate the requested hunts and serialize the overlap dataclasses."""
    analysis = await correlation_engine.correlate_hunts(body.hunt_ids, db)

    return {
        "hunt_ids": analysis.hunt_ids,
        "summary": analysis.summary,
        "total_correlations": analysis.total_correlations,
        "ioc_overlaps": [asdict(item) for item in analysis.ioc_overlaps],
        "time_overlaps": [asdict(item) for item in analysis.time_overlaps],
        "technique_overlaps": [asdict(item) for item in analysis.technique_overlaps],
        "host_overlaps": analysis.host_overlaps,
    }
|
||||
|
||||
|
||||
@router.get(
    "/all",
    summary="Correlate all hunts",
    description="Run correlation across all hunts in the system.",
)
async def correlate_all(db: AsyncSession = Depends(get_db)):
    """Correlate every hunt, truncating each overlap list for the response."""
    analysis = await correlation_engine.correlate_all(db)
    return {
        "hunt_ids": analysis.hunt_ids,
        "summary": analysis.summary,
        "total_correlations": analysis.total_correlations,
        # Caps keep the system-wide payload bounded.
        "ioc_overlaps": [asdict(item) for item in analysis.ioc_overlaps[:20]],
        "time_overlaps": [asdict(item) for item in analysis.time_overlaps[:10]],
        "technique_overlaps": [asdict(item) for item in analysis.technique_overlaps[:10]],
        "host_overlaps": analysis.host_overlaps[:10],
    }
|
||||
|
||||
|
||||
@router.get(
    "/ioc/{ioc_value}",
    summary="Find IOC across all hunts",
    description="Search for a specific IOC value across all datasets and hunts.",
)
async def find_ioc(
    ioc_value: str,
    db: AsyncSession = Depends(get_db),
):
    """Locate every occurrence of an IOC and summarize hunt coverage."""
    hits = await correlation_engine.find_ioc_across_hunts(ioc_value, db)
    hunts_seen = {h["hunt_id"] for h in hits if h.get("hunt_id")}
    return {
        "ioc_value": ioc_value,
        "occurrences": hits,
        "total": len(hits),
        "unique_hunts": len(hunts_seen),
    }
|
||||
295
backend/app/api/routes/datasets.py
Normal file
295
backend/app/api/routes/datasets.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""API routes for dataset upload, listing, and management."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
from app.db.repositories.datasets import DatasetRepository
|
||||
from app.services.csv_parser import parse_csv_bytes, infer_column_types
|
||||
from app.services.normalizer import (
|
||||
normalize_columns,
|
||||
normalize_rows,
|
||||
detect_ioc_columns,
|
||||
detect_time_range,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/datasets", tags=["datasets"])
|
||||
|
||||
ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"}
|
||||
|
||||
|
||||
# ── Response models ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class DatasetSummary(BaseModel):
    # Full metadata for one stored dataset, returned by list/get endpoints.
    id: str
    name: str
    filename: str
    source_tool: str | None = None  # e.g. "velociraptor" (see upload endpoint)
    row_count: int
    column_schema: dict | None = None  # column name → inferred type
    normalized_columns: dict | None = None  # raw column name → canonical name
    ioc_columns: dict | None = None  # IOC-bearing columns auto-detected at upload
    file_size_bytes: int
    encoding: str | None = None  # text encoding detected by the CSV parser
    delimiter: str | None = None  # delimiter detected by the CSV parser
    time_range_start: str | None = None  # ISO-8601; earliest event detected
    time_range_end: str | None = None  # ISO-8601; latest event detected
    hunt_id: str | None = None  # owning hunt, when associated
    created_at: str  # ISO-8601 timestamp


class DatasetListResponse(BaseModel):
    # Paged dataset listing.
    datasets: list[DatasetSummary]
    total: int


class RowsResponse(BaseModel):
    # One page of raw (or normalized) data rows.
    rows: list[dict]
    total: int
    offset: int
    limit: int


class UploadResponse(BaseModel):
    # Result of a successful dataset upload.
    id: str
    name: str
    row_count: int
    columns: list[str]
    column_types: dict
    normalized_columns: dict
    ioc_columns: dict
    message: str
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/upload",
    response_model=UploadResponse,
    summary="Upload a CSV dataset",
    description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, "
    "IOCs auto-detected, and rows stored in the database.",
)
async def upload_dataset(
    file: UploadFile = File(...),
    name: str | None = Query(None, description="Display name for the dataset"),
    source_tool: str | None = Query(None, description="Source tool (e.g., velociraptor)"),
    hunt_id: str | None = Query(None, description="Hunt ID to associate with"),
    db: AsyncSession = Depends(get_db),
):
    """Upload and parse a CSV dataset.

    Validates extension and size, parses the raw bytes, normalizes column
    names, auto-detects IOC columns and the covered time range, then persists
    the dataset metadata and bulk-inserts all rows.

    Raises:
        HTTPException: 400 (missing filename / bad extension / empty file),
            413 (over the size limit), 422 (parse failure or no data rows).
    """
    # Validate the file name and extension before reading any bytes.
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    ext = Path(file.filename).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"File type '{ext}' not allowed. Accepted: {', '.join(ALLOWED_EXTENSIONS)}",
        )

    # Read the whole upload into memory (bounded by the size check below).
    raw_bytes = await file.read()
    if len(raw_bytes) == 0:
        raise HTTPException(status_code=400, detail="File is empty")

    # NOTE(review): the limit check reads settings.max_upload_bytes but the
    # error message reads settings.MAX_UPLOAD_SIZE_MB — confirm the two
    # settings attributes agree.
    if len(raw_bytes) > settings.max_upload_bytes:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Max size: {settings.MAX_UPLOAD_SIZE_MB} MB",
        )

    # Parse CSV; surface failures as 422 with the cause chained for debugging.
    try:
        rows, metadata = parse_csv_bytes(raw_bytes)
    except Exception as e:
        logger.error("CSV parse error: %s", e)
        raise HTTPException(
            status_code=422,
            detail=f"Failed to parse CSV: {str(e)}. Check encoding and format.",
        ) from e

    if not rows:
        raise HTTPException(status_code=422, detail="CSV file contains no data rows")

    columns: list[str] = metadata["columns"]
    column_types: dict = metadata["column_types"]

    # Map raw headers to canonical names and rewrite rows accordingly.
    column_mapping = normalize_columns(columns)
    normalized = normalize_rows(rows, column_mapping)

    # Auto-detect IOC-bearing columns and the dataset's time coverage.
    ioc_columns = detect_ioc_columns(columns, column_types, column_mapping)
    time_start, time_end = detect_time_range(rows, column_mapping)

    # Persist dataset metadata first, then the rows themselves.
    repo = DatasetRepository(db)
    dataset = await repo.create_dataset(
        name=name or Path(file.filename).stem,
        filename=file.filename,
        source_tool=source_tool,
        row_count=len(rows),
        column_schema=column_types,
        normalized_columns=column_mapping,
        ioc_columns=ioc_columns,
        file_size_bytes=len(raw_bytes),
        encoding=metadata["encoding"],
        delimiter=metadata["delimiter"],
        time_range_start=time_start,
        time_range_end=time_end,
        hunt_id=hunt_id,
    )

    await repo.bulk_insert_rows(
        dataset_id=dataset.id,
        rows=rows,
        normalized_rows=normalized,
    )

    # Lazy %-args: the message is only formatted if this level is enabled.
    logger.info(
        "Uploaded dataset '%s': %d rows, %d columns, %d IOC columns detected",
        dataset.name, len(rows), len(columns), len(ioc_columns),
    )

    return UploadResponse(
        id=dataset.id,
        name=dataset.name,
        row_count=len(rows),
        columns=columns,
        column_types=column_types,
        normalized_columns=column_mapping,
        ioc_columns=ioc_columns,
        message=f"Successfully uploaded {len(rows)} rows with {len(ioc_columns)} IOC columns detected",
    )
|
||||
|
||||
|
||||
@router.get(
    "",
    response_model=DatasetListResponse,
    summary="List datasets",
)
async def list_datasets(
    hunt_id: str | None = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """Return one page of dataset summaries, optionally filtered by hunt."""
    repo = DatasetRepository(db)
    page = await repo.list_datasets(hunt_id=hunt_id, limit=limit, offset=offset)
    total = await repo.count_datasets(hunt_id=hunt_id)

    def summarize(ds) -> DatasetSummary:
        # Optional timestamps serialize to ISO-8601 strings or None.
        return DatasetSummary(
            id=ds.id,
            name=ds.name,
            filename=ds.filename,
            source_tool=ds.source_tool,
            row_count=ds.row_count,
            column_schema=ds.column_schema,
            normalized_columns=ds.normalized_columns,
            ioc_columns=ds.ioc_columns,
            file_size_bytes=ds.file_size_bytes,
            encoding=ds.encoding,
            delimiter=ds.delimiter,
            time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
            time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
            hunt_id=ds.hunt_id,
            created_at=ds.created_at.isoformat(),
        )

    return DatasetListResponse(
        datasets=[summarize(ds) for ds in page],
        total=total,
    )
|
||||
|
||||
|
||||
@router.get(
    "/{dataset_id}",
    response_model=DatasetSummary,
    summary="Get dataset details",
)
async def get_dataset(
    dataset_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Fetch the full metadata record for a single dataset."""
    ds = await DatasetRepository(db).get_dataset(dataset_id)
    if not ds:
        raise HTTPException(status_code=404, detail="Dataset not found")

    def iso_or_none(ts):
        # Optional timestamps serialize to ISO-8601 strings or None.
        return ts.isoformat() if ts else None

    return DatasetSummary(
        id=ds.id,
        name=ds.name,
        filename=ds.filename,
        source_tool=ds.source_tool,
        row_count=ds.row_count,
        column_schema=ds.column_schema,
        normalized_columns=ds.normalized_columns,
        ioc_columns=ds.ioc_columns,
        file_size_bytes=ds.file_size_bytes,
        encoding=ds.encoding,
        delimiter=ds.delimiter,
        time_range_start=iso_or_none(ds.time_range_start),
        time_range_end=iso_or_none(ds.time_range_end),
        hunt_id=ds.hunt_id,
        created_at=ds.created_at.isoformat(),
    )
|
||||
|
||||
|
||||
@router.get(
    "/{dataset_id}/rows",
    response_model=RowsResponse,
    summary="Get dataset rows",
)
async def get_dataset_rows(
    dataset_id: str,
    limit: int = Query(1000, ge=1, le=10000),
    offset: int = Query(0, ge=0),
    normalized: bool = Query(False, description="Return normalized column names"),
    db: AsyncSession = Depends(get_db),
):
    """Page through a dataset's stored rows, raw or normalized."""
    repo = DatasetRepository(db)
    ds = await repo.get_dataset(dataset_id)
    if not ds:
        raise HTTPException(status_code=404, detail="Dataset not found")

    stored = await repo.get_rows(dataset_id, limit=limit, offset=offset)
    total = await repo.count_rows(dataset_id)

    payload: list[dict] = []
    for record in stored:
        # Fall back to the raw data when no normalized form is stored.
        if normalized and record.normalized_data:
            payload.append(record.normalized_data)
        else:
            payload.append(record.data)

    return RowsResponse(rows=payload, total=total, offset=offset, limit=limit)
|
||||
|
||||
|
||||
@router.delete(
    "/{dataset_id}",
    summary="Delete a dataset",
)
async def delete_dataset(
    dataset_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Remove a dataset by id; 404 when it does not exist."""
    was_deleted = await DatasetRepository(db).delete_dataset(dataset_id)
    if not was_deleted:
        raise HTTPException(status_code=404, detail="Dataset not found")
    return {"message": "Dataset deleted", "id": dataset_id}
|
||||
220
backend/app/api/routes/enrichment.py
Normal file
220
backend/app/api/routes/enrichment.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""API routes for IOC enrichment."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.services.enrichment import (
|
||||
enrichment_engine,
|
||||
IOCType,
|
||||
Verdict,
|
||||
EnrichmentResultData,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/enrichment", tags=["enrichment"])
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class EnrichIOCRequest(BaseModel):
    # Request body for single-IOC enrichment.
    ioc_value: str = Field(..., max_length=2048, description="IOC value to enrich")
    ioc_type: str = Field(..., description="IOC type: ip, domain, hash_md5, hash_sha1, hash_sha256, url")
    # When True, bypass any cached result and re-query the providers.
    skip_cache: bool = False


class EnrichBatchRequest(BaseModel):
    # Batch request: at most 50 {value, type} pairs per call.
    iocs: list[dict] = Field(
        ...,
        description="List of {value, type} pairs",
        max_length=50,
    )


class EnrichmentResultResponse(BaseModel):
    # One provider's answer for one IOC (API mirror of EnrichmentResultData).
    ioc_value: str
    ioc_type: str
    source: str  # provider name
    verdict: str  # Verdict enum value
    score: float
    tags: list[str] = []
    country: str = ""
    asn: str = ""
    org: str = ""
    last_seen: str = ""
    raw_data: dict = {}
    error: str = ""  # non-empty when the provider lookup failed
    latency_ms: int = 0


class EnrichIOCResponse(BaseModel):
    # Aggregated enrichment for one IOC across all providers.
    ioc_value: str
    ioc_type: str
    results: list[EnrichmentResultResponse]
    overall_verdict: str  # worst-case verdict across providers
    overall_score: float


class EnrichBatchResponse(BaseModel):
    # Per-IOC provider results, keyed by IOC value.
    results: dict[str, list[EnrichmentResultResponse]]
    total_enriched: int
|
||||
|
||||
|
||||
def _to_response(r: EnrichmentResultData) -> EnrichmentResultResponse:
    """Convert an internal enrichment result into its API-facing schema."""
    payload = {
        "ioc_value": r.ioc_value,
        "ioc_type": r.ioc_type.value,  # enums serialize as their string value
        "source": r.source,
        "verdict": r.verdict.value,
        "score": r.score,
        "tags": r.tags,
        "country": r.country,
        "asn": r.asn,
        "org": r.org,
        "last_seen": r.last_seen,
        "raw_data": r.raw_data,
        "error": r.error,
        "latency_ms": r.latency_ms,
    }
    return EnrichmentResultResponse(**payload)
|
||||
|
||||
|
||||
def _compute_overall(results: list[EnrichmentResultData]) -> tuple[str, float]:
    """Compute overall verdict from multiple provider results."""
    if not results:
        return Verdict.UNKNOWN.value, 0.0

    # Errors do not vote; if every provider errored, report ERROR.
    usable = [r.verdict for r in results if r.verdict != Verdict.ERROR]
    if not usable:
        return Verdict.ERROR.value, 0.0

    # Severity order: MALICIOUS beats SUSPICIOUS beats CLEAN beats UNKNOWN.
    worst_score = max(r.score for r in results)
    for severe in (Verdict.MALICIOUS, Verdict.SUSPICIOUS):
        if severe in usable:
            return severe.value, worst_score
    if Verdict.CLEAN in usable:
        return Verdict.CLEAN.value, 0.0
    return Verdict.UNKNOWN.value, 0.0
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/ioc",
    response_model=EnrichIOCResponse,
    summary="Enrich a single IOC",
    description="Query all configured providers for an IOC (IP, hash, domain, URL).",
)
async def enrich_ioc(
    body: EnrichIOCRequest,
    db: AsyncSession = Depends(get_db),
):
    """Enrich one IOC across all providers and aggregate an overall verdict.

    Raises:
        HTTPException: 400 when `body.ioc_type` is not a valid IOCType.
    """
    try:
        ioc_type = IOCType(body.ioc_type)
    except ValueError:
        # Expected validation failure: suppress exception chaining so the
        # client-facing 400 is not reported with a misleading traceback.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid IOC type: {body.ioc_type}. Valid: {[t.value for t in IOCType]}",
        ) from None

    results = await enrichment_engine.enrich_ioc(
        body.ioc_value, ioc_type, db=db, skip_cache=body.skip_cache,
    )

    # Collapse per-provider verdicts into a single worst-case verdict/score.
    overall_verdict, overall_score = _compute_overall(results)

    return EnrichIOCResponse(
        ioc_value=body.ioc_value,
        ioc_type=body.ioc_type,
        results=[_to_response(r) for r in results],
        overall_verdict=overall_verdict,
        overall_score=overall_score,
    )
|
||||
|
||||
|
||||
@router.post(
    "/batch",
    response_model=EnrichBatchResponse,
    summary="Enrich a batch of IOCs",
    description="Enrich up to 50 IOCs at once across all providers.",
)
async def enrich_batch(
    body: EnrichBatchRequest,
    db: AsyncSession = Depends(get_db),
):
    """Enrich a batch of IOCs; malformed entries are silently skipped."""
    parsed: list[tuple[str, IOCType]] = []
    for entry in body.iocs:
        try:
            parsed.append((entry["value"], IOCType(entry["type"])))
        except (KeyError, ValueError):
            # Missing keys or an unknown IOC type: drop this entry.
            continue

    if not parsed:
        raise HTTPException(status_code=400, detail="No valid IOCs provided")

    all_results = await enrichment_engine.enrich_batch(parsed, db=db)

    return EnrichBatchResponse(
        results={
            ioc: [_to_response(item) for item in hits]
            for ioc, hits in all_results.items()
        },
        total_enriched=len(all_results),
    )
|
||||
|
||||
|
||||
@router.post(
    "/dataset/{dataset_id}",
    summary="Auto-enrich IOCs in a dataset",
    description="Automatically extract and enrich IOCs from a dataset's IOC columns.",
)
async def enrich_dataset(
    dataset_id: str,
    max_iocs: int = Query(50, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
):
    """Extract IOCs from a dataset's detected IOC columns and enrich them."""
    # Local import mirrors the original (avoids a module-level cycle).
    from app.db.repositories.datasets import DatasetRepository

    repo = DatasetRepository(db)
    dataset = await repo.get_dataset(dataset_id)
    if not dataset:
        raise HTTPException(status_code=404, detail="Dataset not found")

    if not dataset.ioc_columns:
        # Nothing to do — the upload step found no IOC-bearing columns.
        return {"message": "No IOC columns detected in this dataset", "results": {}}

    # Cap the scan to the first 1000 stored rows.
    stored_rows = await repo.get_rows(dataset_id, limit=1000)

    all_results = await enrichment_engine.enrich_dataset_iocs(
        rows=[record.data for record in stored_rows],
        ioc_columns=dataset.ioc_columns,
        db=db,
        max_iocs=max_iocs,
    )

    serialized = {
        ioc: [_to_response(item) for item in hits]
        for ioc, hits in all_results.items()
    }
    return {
        "dataset_id": dataset_id,
        "dataset_name": dataset.name,
        "ioc_columns": dataset.ioc_columns,
        "results": serialized,
        "total_enriched": len(all_results),
    }
|
||||
|
||||
|
||||
@router.get(
    "/status",
    summary="Enrichment engine status",
    description="Check which providers are configured and available.",
)
async def enrichment_status():
    """Report provider availability straight from the enrichment engine."""
    return enrichment_engine.status()
|
||||
158
backend/app/api/routes/hunts.py
Normal file
158
backend/app/api/routes/hunts.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""API routes for hunt management."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Hunt, Conversation, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/hunts", tags=["hunts"])
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class HuntCreate(BaseModel):
    # Payload for creating a hunt.
    name: str = Field(..., max_length=256)
    description: str | None = None


class HuntUpdate(BaseModel):
    # Partial update: only non-None fields are applied.
    name: str | None = None
    description: str | None = None
    status: str | None = None  # active | closed | archived


class HuntResponse(BaseModel):
    # API representation of a hunt, with aggregate counts.
    id: str
    name: str
    description: str | None
    status: str
    owner_id: str | None
    created_at: str  # ISO-8601 timestamp
    updated_at: str  # ISO-8601 timestamp
    dataset_count: int = 0
    hypothesis_count: int = 0


class HuntListResponse(BaseModel):
    # Paged hunt listing.
    hunts: list[HuntResponse]
    total: int
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("", response_model=HuntResponse, summary="Create a new hunt")
async def create_hunt(body: HuntCreate, db: AsyncSession = Depends(get_db)):
    """Create a hunt; counts default to zero for a brand-new hunt."""
    new_hunt = Hunt(name=body.name, description=body.description)
    db.add(new_hunt)
    await db.flush()  # populate generated fields (id, timestamps) before reading them
    return HuntResponse(
        id=new_hunt.id,
        name=new_hunt.name,
        description=new_hunt.description,
        status=new_hunt.status,
        owner_id=new_hunt.owner_id,
        created_at=new_hunt.created_at.isoformat(),
        updated_at=new_hunt.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@router.get("", response_model=HuntListResponse, summary="List hunts")
async def list_hunts(
    status: str | None = Query(None),
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List hunts newest-first, optionally filtered by status."""
    query = select(Hunt).order_by(Hunt.updated_at.desc())
    count_query = select(func.count(Hunt.id))
    if status:
        query = query.where(Hunt.status == status)
        count_query = count_query.where(Hunt.status == status)

    page = (await db.execute(query.limit(limit).offset(offset))).scalars().all()
    total = (await db.execute(count_query)).scalar_one()

    def to_response(h: Hunt) -> HuntResponse:
        # NOTE(review): assumes the datasets/hypotheses relationships are
        # either loaded or configured noload (→ empty) — confirm against the
        # model's lazy settings, since unloaded async relationships would raise.
        return HuntResponse(
            id=h.id,
            name=h.name,
            description=h.description,
            status=h.status,
            owner_id=h.owner_id,
            created_at=h.created_at.isoformat(),
            updated_at=h.updated_at.isoformat(),
            dataset_count=len(h.datasets) if h.datasets else 0,
            hypothesis_count=len(h.hypotheses) if h.hypotheses else 0,
        )

    return HuntListResponse(hunts=[to_response(h) for h in page], total=total)
|
||||
|
||||
|
||||
@router.get("/{hunt_id}", response_model=HuntResponse, summary="Get hunt details")
async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single hunt by id, including dataset/hypothesis counts."""
    hunt = (
        await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    ).scalar_one_or_none()
    if hunt is None:
        raise HTTPException(status_code=404, detail="Hunt not found")
    return HuntResponse(
        id=hunt.id,
        name=hunt.name,
        description=hunt.description,
        status=hunt.status,
        owner_id=hunt.owner_id,
        created_at=hunt.created_at.isoformat(),
        updated_at=hunt.updated_at.isoformat(),
        dataset_count=len(hunt.datasets) if hunt.datasets else 0,
        hypothesis_count=len(hunt.hypotheses) if hunt.hypotheses else 0,
    )
|
||||
|
||||
|
||||
@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
async def update_hunt(
    hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)
):
    """Apply partial updates (name / description / status) to a hunt."""
    hunt = (
        await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    ).scalar_one_or_none()
    if hunt is None:
        raise HTTPException(status_code=404, detail="Hunt not found")

    # Only overwrite fields the client explicitly supplied.
    for attr in ("name", "description", "status"):
        new_value = getattr(body, attr)
        if new_value is not None:
            setattr(hunt, attr, new_value)
    await db.flush()

    return HuntResponse(
        id=hunt.id,
        name=hunt.name,
        description=hunt.description,
        status=hunt.status,
        owner_id=hunt.owner_id,
        created_at=hunt.created_at.isoformat(),
        updated_at=hunt.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@router.delete("/{hunt_id}", summary="Delete a hunt")
async def delete_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a hunt by id; 404 when it does not exist."""
    target = (
        await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    ).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Hunt not found")
    await db.delete(target)
    return {"message": "Hunt deleted", "id": hunt_id}
|
||||
257
backend/app/api/routes/keywords.py
Normal file
257
backend/app/api/routes/keywords.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""API routes for AUP keyword themes, keyword CRUD, and scanning."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import KeywordTheme, Keyword
|
||||
from app.services.scanner import KeywordScanner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/keywords", tags=["keywords"])
|
||||
|
||||
|
||||
# ── Pydantic schemas ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ThemeCreate(BaseModel):
    # Payload for creating a keyword theme.
    name: str = Field(..., min_length=1, max_length=128)
    color: str = Field(default="#9e9e9e", max_length=16)  # display color (hex)
    enabled: bool = True


class ThemeUpdate(BaseModel):
    # Partial update: only non-None fields are applied.
    name: str | None = None
    color: str | None = None
    enabled: bool | None = None


class KeywordOut(BaseModel):
    # API representation of one keyword.
    id: int
    theme_id: str
    value: str
    is_regex: bool  # True when `value` is a regular expression pattern
    created_at: str  # ISO-8601 timestamp


class ThemeOut(BaseModel):
    # API representation of a theme, including all of its keywords.
    id: str
    name: str
    color: str
    enabled: bool
    is_builtin: bool  # presumably marks themes seeded by the app — confirm
    created_at: str  # ISO-8601 timestamp
    keyword_count: int
    keywords: list[KeywordOut]


class ThemeListResponse(BaseModel):
    # Complete theme listing.
    themes: list[ThemeOut]
    total: int


class KeywordCreate(BaseModel):
    # Payload for adding a single keyword to a theme.
    value: str = Field(..., min_length=1, max_length=256)
    is_regex: bool = False
|
||||
|
||||
class KeywordBulkCreate(BaseModel):
    # Non-empty list of keyword strings to add in a single call.
    # `min_length` is the Pydantic v2 size constraint for lists
    # (`min_items` is the deprecated v1 spelling); this matches the
    # `max_length` list constraint used by EnrichBatchRequest elsewhere
    # in this codebase. Constraint semantics are unchanged.
    values: list[str] = Field(..., min_length=1)
    # When True, every value is treated as a regular expression.
    is_regex: bool = False
|
||||
|
||||
class ScanRequest(BaseModel):
    # Scan scope selectors; None means "everything".
    dataset_ids: list[str] | None = None  # None → all datasets
    theme_ids: list[str] | None = None  # None → all enabled themes
    scan_hunts: bool = True
    scan_annotations: bool = True
    scan_messages: bool = True


class ScanHit(BaseModel):
    # One keyword match found during a scan.
    theme_name: str
    theme_color: str
    keyword: str
    source_type: str  # dataset_row | hunt | annotation | message
    source_id: str | int
    field: str  # name of the field the keyword matched in
    matched_value: str
    row_index: int | None = None  # presumably set only for dataset_row hits — confirm
    dataset_name: str | None = None  # presumably set only for dataset_row hits — confirm


class ScanResponse(BaseModel):
    # Scan totals plus every individual hit.
    total_hits: int
    hits: list[ScanHit]
    themes_scanned: int
    keywords_scanned: int
    rows_scanned: int
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _theme_to_out(t: KeywordTheme) -> ThemeOut:
    """Serialize a theme ORM object, with its keywords, to the API schema."""
    keyword_items = [
        KeywordOut(
            id=k.id,
            theme_id=k.theme_id,
            value=k.value,
            is_regex=k.is_regex,
            created_at=k.created_at.isoformat(),
        )
        for k in t.keywords
    ]
    return ThemeOut(
        id=t.id,
        name=t.name,
        color=t.color,
        enabled=t.enabled,
        is_builtin=t.is_builtin,
        created_at=t.created_at.isoformat(),
        keyword_count=len(keyword_items),
        keywords=keyword_items,
    )
|
||||
|
||||
|
||||
# ── Theme CRUD ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/themes", response_model=ThemeListResponse)
async def list_themes(db: AsyncSession = Depends(get_db)):
    """List all keyword themes with their keywords."""
    themes = (
        await db.execute(select(KeywordTheme).order_by(KeywordTheme.name))
    ).scalars().all()
    return ThemeListResponse(
        themes=[_theme_to_out(t) for t in themes],
        total=len(themes),
    )
|
||||
|
||||
|
||||
@router.post("/themes", response_model=ThemeOut, status_code=201)
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
    """Create a new keyword theme."""
    # Enforce unique theme names up front for a clean 409.
    duplicate = await db.scalar(
        select(KeywordTheme.id).where(KeywordTheme.name == body.name)
    )
    if duplicate:
        raise HTTPException(409, f"Theme '{body.name}' already exists")
    new_theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
    db.add(new_theme)
    await db.flush()
    await db.refresh(new_theme)  # load generated id/timestamp before serializing
    return _theme_to_out(new_theme)
|
||||
|
||||
|
||||
@router.put("/themes/{theme_id}", response_model=ThemeOut)
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
    """Update theme name, color, or enabled status."""
    theme = await db.get(KeywordTheme, theme_id)
    if theme is None:
        raise HTTPException(404, "Theme not found")

    if body.name is not None:
        # A rename must not collide with another theme's name.
        clash = await db.scalar(
            select(KeywordTheme.id).where(
                KeywordTheme.name == body.name, KeywordTheme.id != theme_id
            )
        )
        if clash:
            raise HTTPException(409, f"Theme '{body.name}' already exists")
        theme.name = body.name

    # Remaining fields: overwrite only when explicitly supplied.
    for attr in ("color", "enabled"):
        new_value = getattr(body, attr)
        if new_value is not None:
            setattr(theme, attr, new_value)

    await db.flush()
    await db.refresh(theme)
    return _theme_to_out(theme)
|
||||
|
||||
|
||||
@router.delete("/themes/{theme_id}", status_code=204)
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a theme and all its keywords."""
    theme = await db.get(KeywordTheme, theme_id)
    if theme is None:
        raise HTTPException(404, "Theme not found")
    await db.delete(theme)
|
||||
|
||||
|
||||
# ── Keyword CRUD ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
    """Add a single keyword to a theme."""
    if await db.get(KeywordTheme, theme_id) is None:
        raise HTTPException(404, "Theme not found")

    new_kw = Keyword(theme_id=theme_id, value=body.value, is_regex=body.is_regex)
    db.add(new_kw)
    await db.flush()
    await db.refresh(new_kw)  # load generated id/timestamp before serializing
    return KeywordOut(
        id=new_kw.id,
        theme_id=new_kw.theme_id,
        value=new_kw.value,
        is_regex=new_kw.is_regex,
        created_at=new_kw.created_at.isoformat(),
    )
|
||||
|
||||
|
||||
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
    """Add multiple keywords to a theme at once."""
    if await db.get(KeywordTheme, theme_id) is None:
        raise HTTPException(404, "Theme not found")

    # Blank / whitespace-only entries are silently dropped.
    cleaned = [raw.strip() for raw in body.values if raw.strip()]
    for value in cleaned:
        db.add(Keyword(theme_id=theme_id, value=value, is_regex=body.is_regex))
    await db.flush()
    return {"added": len(cleaned), "theme_id": theme_id}
|
||||
|
||||
|
||||
@router.delete("/keywords/{keyword_id}", status_code=204)
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
    """Delete a single keyword."""
    keyword = await db.get(Keyword, keyword_id)
    if keyword is None:
        raise HTTPException(404, "Keyword not found")
    await db.delete(keyword)
|
||||
|
||||
|
||||
# ── Scan endpoints ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/scan", response_model=ScanResponse)
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
    """Run AUP keyword scan across selected data sources."""
    return await KeywordScanner(db).scan(
        dataset_ids=body.dataset_ids,
        theme_ids=body.theme_ids,
        scan_hunts=body.scan_hunts,
        scan_annotations=body.scan_annotations,
        scan_messages=body.scan_messages,
    )
|
||||
|
||||
|
||||
@router.get("/scan/quick", response_model=ScanResponse)
async def quick_scan(
    dataset_id: str = Query(..., description="Dataset to scan"),
    db: AsyncSession = Depends(get_db),
):
    """Quick scan a single dataset with all enabled themes."""
    # Default scanner settings cover all enabled themes and other sources.
    return await KeywordScanner(db).scan(dataset_ids=[dataset_id])
|
||||
67
backend/app/api/routes/reports.py
Normal file
67
backend/app/api/routes/reports.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""API routes for report generation and export."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import HTMLResponse, PlainTextResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.services.reports import report_generator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/reports", tags=["reports"])
|
||||
|
||||
|
||||
@router.get(
    "/hunt/{hunt_id}",
    summary="Generate hunt investigation report",
    description="Generate a comprehensive report for a hunt. Supports JSON, HTML, and CSV formats.",
)
async def generate_hunt_report(
    hunt_id: str,
    format: str = Query("json", description="Report format: json, html, csv"),
    include_rows: bool = Query(False, description="Include raw data rows"),
    max_rows: int = Query(500, ge=0, le=5000, description="Max rows to include"),
    db: AsyncSession = Depends(get_db),
):
    """Build a hunt report and wrap it in the response type for `format`."""
    report = await report_generator.generate_hunt_report(
        hunt_id, db, format=format,
        include_rows=include_rows, max_rows=max_rows,
    )

    # The generator signals "hunt not found" via an error key on a dict result.
    if isinstance(report, dict) and report.get("error"):
        raise HTTPException(status_code=404, detail=report["error"])

    if format == "html":
        headers = {
            "Content-Disposition": f"inline; filename=threathunt_report_{hunt_id}.html",
        }
        return HTMLResponse(content=report, headers=headers)
    if format == "csv":
        headers = {
            "Content-Disposition": f"attachment; filename=threathunt_report_{hunt_id}.csv",
        }
        return PlainTextResponse(content=report, media_type="text/csv", headers=headers)
    # JSON (or any other value) returns the raw dict.
    return report
|
||||
|
||||
|
||||
@router.get(
    "/hunt/{hunt_id}/summary",
    summary="Quick hunt summary",
    description="Get a lightweight summary of the hunt for dashboard display.",
)
async def hunt_summary(
    hunt_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Return only the hunt header and summary sections of the full report."""
    report = await report_generator.generate_hunt_report(
        hunt_id, db, format="json", include_rows=False,
    )
    if isinstance(report, dict) and report.get("error"):
        raise HTTPException(status_code=404, detail=report["error"])

    return {
        "hunt": report.get("hunt"),
        "summary": report.get("summary"),
    }
|
||||
Reference in New Issue
Block a user