mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 14:00:20 -05:00
feat: Add Playbook Manager, Saved Searches, and Timeline View components
- Implemented PlaybookManager for creating and managing investigation playbooks with templates. - Added SavedSearches component for managing bookmarked queries and recurring scans. - Introduced TimelineView for visualizing forensic event timelines with zoomable charts. - Enhanced backend processing with auto-queued jobs for dataset uploads and improved database concurrency. - Updated frontend components for better user experience and performance optimizations. - Documented changes in update log for future reference.
This commit is contained in:
@@ -1,170 +0,0 @@
|
||||
"""API routes for analyst-assist agent."""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agents.core import ThreatHuntAgent, AgentContext, AgentResponse
|
||||
from app.agents.config import AgentConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
||||
|
||||
# Global agent instance (lazy-loaded)
|
||||
_agent: ThreatHuntAgent | None = None
|
||||
|
||||
|
||||
def get_agent() -> ThreatHuntAgent:
|
||||
"""Get or create the agent instance."""
|
||||
global _agent
|
||||
if _agent is None:
|
||||
if not AgentConfig.is_agent_enabled():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Analyst-assist agent is not configured. "
|
||||
"Please configure an LLM provider.",
|
||||
)
|
||||
_agent = ThreatHuntAgent()
|
||||
return _agent
|
||||
|
||||
|
||||
class AssistRequest(BaseModel):
|
||||
"""Request for agent assistance."""
|
||||
|
||||
query: str = Field(
|
||||
..., description="Analyst question or request for guidance"
|
||||
)
|
||||
dataset_name: str | None = Field(
|
||||
None, description="Name of CSV dataset being analyzed"
|
||||
)
|
||||
artifact_type: str | None = Field(
|
||||
None, description="Type of artifact (e.g., FileList, ProcessList, NetworkConnections)"
|
||||
)
|
||||
host_identifier: str | None = Field(
|
||||
None, description="Host name, IP address, or identifier"
|
||||
)
|
||||
data_summary: str | None = Field(
|
||||
None, description="Brief summary or context about the uploaded data"
|
||||
)
|
||||
conversation_history: list[dict] | None = Field(
|
||||
None, description="Previous messages for context"
|
||||
)
|
||||
|
||||
|
||||
class AssistResponse(BaseModel):
|
||||
"""Response with agent guidance."""
|
||||
|
||||
guidance: str
|
||||
confidence: float
|
||||
suggested_pivots: list[str]
|
||||
suggested_filters: list[str]
|
||||
caveats: str | None = None
|
||||
reasoning: str | None = None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/assist",
|
||||
response_model=AssistResponse,
|
||||
summary="Get analyst-assist guidance",
|
||||
description="Request guidance on CSV artifact data, analytical pivots, and hypotheses. "
|
||||
"Agent provides advisory guidance only - no execution.",
|
||||
)
|
||||
async def agent_assist(request: AssistRequest) -> AssistResponse:
|
||||
"""Provide analyst-assist guidance on artifact data.
|
||||
|
||||
The agent will:
|
||||
- Explain and interpret the provided data context
|
||||
- Suggest analytical pivots the analyst might explore
|
||||
- Suggest data filters or queries that might be useful
|
||||
- Highlight assumptions, limitations, and caveats
|
||||
|
||||
The agent will NOT:
|
||||
- Execute any tools or actions
|
||||
- Escalate findings to alerts
|
||||
- Modify any data or schema
|
||||
- Make autonomous decisions
|
||||
|
||||
Args:
|
||||
request: Assistance request with query and context
|
||||
|
||||
Returns:
|
||||
Guidance response with suggestions and reasoning
|
||||
|
||||
Raises:
|
||||
HTTPException: If agent is not configured (503) or request fails
|
||||
"""
|
||||
try:
|
||||
agent = get_agent()
|
||||
|
||||
# Build context
|
||||
context = AgentContext(
|
||||
query=request.query,
|
||||
dataset_name=request.dataset_name,
|
||||
artifact_type=request.artifact_type,
|
||||
host_identifier=request.host_identifier,
|
||||
data_summary=request.data_summary,
|
||||
conversation_history=request.conversation_history or [],
|
||||
)
|
||||
|
||||
# Get guidance
|
||||
response = await agent.assist(context)
|
||||
|
||||
logger.info(
|
||||
f"Agent assisted analyst with query: {request.query[:50]}... "
|
||||
f"(host: {request.host_identifier}, artifact: {request.artifact_type})"
|
||||
)
|
||||
|
||||
return AssistResponse(
|
||||
guidance=response.guidance,
|
||||
confidence=response.confidence,
|
||||
suggested_pivots=response.suggested_pivots,
|
||||
suggested_filters=response.suggested_filters,
|
||||
caveats=response.caveats,
|
||||
reasoning=response.reasoning,
|
||||
)
|
||||
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Agent error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"Agent unavailable: {str(e)}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error in agent_assist: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Error generating guidance. Please try again.",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/health",
|
||||
summary="Check agent health",
|
||||
description="Check if agent is configured and ready to assist.",
|
||||
)
|
||||
async def agent_health() -> dict:
|
||||
"""Check agent availability and configuration.
|
||||
|
||||
Returns:
|
||||
Health status with configuration details
|
||||
"""
|
||||
try:
|
||||
agent = get_agent()
|
||||
provider_type = agent.provider.__class__.__name__ if agent.provider else "None"
|
||||
return {
|
||||
"status": "healthy",
|
||||
"provider": provider_type,
|
||||
"max_tokens": AgentConfig.MAX_RESPONSE_TOKENS,
|
||||
"reasoning_enabled": AgentConfig.ENABLE_REASONING,
|
||||
}
|
||||
except HTTPException:
|
||||
return {
|
||||
"status": "unavailable",
|
||||
"reason": "No LLM provider configured",
|
||||
"configured_providers": {
|
||||
"local": bool(AgentConfig.LOCAL_MODEL_PATH),
|
||||
"networked": bool(AgentConfig.NETWORKED_ENDPOINT),
|
||||
"online": bool(AgentConfig.ONLINE_API_KEY),
|
||||
},
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
"""API routes for analyst-assist agent — v2.
|
||||
"""API routes for analyst-assist agent v2.
|
||||
|
||||
Supports quick, deep, and debate modes with streaming.
|
||||
Conversations are persisted to the database.
|
||||
@@ -6,19 +6,25 @@ Conversations are persisted to the database.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections import Counter
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
from app.db.models import Conversation, Message
|
||||
from app.db.models import Conversation, Message, Dataset, KeywordTheme
|
||||
from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
|
||||
from app.agents.providers_v2 import check_all_nodes
|
||||
from app.agents.registry import registry
|
||||
from app.services.sans_rag import sans_rag
|
||||
from app.services.scanner import KeywordScanner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,7 +41,7 @@ def get_agent() -> ThreatHuntAgent:
|
||||
return _agent
|
||||
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
# Request / Response models
|
||||
|
||||
|
||||
class AssistRequest(BaseModel):
|
||||
@@ -52,6 +58,8 @@ class AssistRequest(BaseModel):
|
||||
model_override: str | None = None
|
||||
conversation_id: str | None = Field(None, description="Persist messages to this conversation")
|
||||
hunt_id: str | None = None
|
||||
execution_preference: str = Field(default="auto", description="auto | force | off")
|
||||
learning_mode: bool = False
|
||||
|
||||
|
||||
class AssistResponseModel(BaseModel):
|
||||
@@ -66,10 +74,170 @@ class AssistResponseModel(BaseModel):
|
||||
node_used: str = ""
|
||||
latency_ms: int = 0
|
||||
perspectives: list[dict] | None = None
|
||||
execution: dict | None = None
|
||||
conversation_id: str | None = None
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
POLICY_THEME_NAMES = {"Adult Content", "Gambling", "Downloads / Piracy"}
|
||||
POLICY_QUERY_TERMS = {
|
||||
"policy", "violating", "violation", "browser history", "web history",
|
||||
"domain", "domains", "adult", "gambling", "piracy", "aup",
|
||||
}
|
||||
WEB_DATASET_HINTS = {
|
||||
"web", "history", "browser", "url", "visited_url", "domain", "title",
|
||||
}
|
||||
|
||||
|
||||
def _is_policy_domain_query(query: str) -> bool:
|
||||
q = (query or "").lower()
|
||||
if not q:
|
||||
return False
|
||||
score = sum(1 for t in POLICY_QUERY_TERMS if t in q)
|
||||
return score >= 2 and ("domain" in q or "history" in q or "policy" in q)
|
||||
|
||||
def _should_execute_policy_scan(request: AssistRequest) -> bool:
|
||||
pref = (request.execution_preference or "auto").strip().lower()
|
||||
if pref == "off":
|
||||
return False
|
||||
if pref == "force":
|
||||
return True
|
||||
return _is_policy_domain_query(request.query)
|
||||
|
||||
|
||||
def _extract_domain(value: str | None) -> str | None:
|
||||
if not value:
|
||||
return None
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = urlparse(text)
|
||||
if parsed.netloc:
|
||||
return parsed.netloc.lower()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
m = re.search(r"([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}", text)
|
||||
return m.group(0).lower() if m else None
|
||||
|
||||
|
||||
def _dataset_score(ds: Dataset) -> int:
|
||||
score = 0
|
||||
name = (ds.name or "").lower()
|
||||
cols_l = {c.lower() for c in (ds.column_schema or {}).keys()}
|
||||
norm_vals_l = {str(v).lower() for v in (ds.normalized_columns or {}).values()}
|
||||
|
||||
for h in WEB_DATASET_HINTS:
|
||||
if h in name:
|
||||
score += 2
|
||||
if h in cols_l:
|
||||
score += 3
|
||||
if h in norm_vals_l:
|
||||
score += 3
|
||||
|
||||
if "visited_url" in cols_l or "url" in cols_l:
|
||||
score += 8
|
||||
if "user" in cols_l or "username" in cols_l:
|
||||
score += 2
|
||||
if "clientid" in cols_l or "fqdn" in cols_l:
|
||||
score += 2
|
||||
if (ds.row_count or 0) > 0:
|
||||
score += 1
|
||||
|
||||
return score
|
||||
|
||||
|
||||
async def _run_policy_domain_execution(request: AssistRequest, db: AsyncSession) -> dict:
|
||||
scanner = KeywordScanner(db)
|
||||
|
||||
theme_result = await db.execute(
|
||||
select(KeywordTheme).where(
|
||||
KeywordTheme.enabled == True, # noqa: E712
|
||||
KeywordTheme.name.in_(list(POLICY_THEME_NAMES)),
|
||||
)
|
||||
)
|
||||
themes = list(theme_result.scalars().all())
|
||||
theme_ids = [t.id for t in themes]
|
||||
theme_names = [t.name for t in themes] or sorted(POLICY_THEME_NAMES)
|
||||
|
||||
ds_query = select(Dataset).where(Dataset.processing_status.in_(["completed", "ready", "processing"]))
|
||||
if request.hunt_id:
|
||||
ds_query = ds_query.where(Dataset.hunt_id == request.hunt_id)
|
||||
ds_result = await db.execute(ds_query)
|
||||
candidates = list(ds_result.scalars().all())
|
||||
|
||||
if request.dataset_name:
|
||||
needle = request.dataset_name.lower().strip()
|
||||
candidates = [d for d in candidates if needle in (d.name or "").lower()]
|
||||
|
||||
scored = sorted(
|
||||
((d, _dataset_score(d)) for d in candidates),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)
|
||||
selected = [d for d, s in scored if s > 0][:8]
|
||||
dataset_ids = [d.id for d in selected]
|
||||
|
||||
if not dataset_ids:
|
||||
return {
|
||||
"mode": "policy_scan",
|
||||
"themes": theme_names,
|
||||
"datasets_scanned": 0,
|
||||
"dataset_names": [],
|
||||
"total_hits": 0,
|
||||
"policy_hits": 0,
|
||||
"top_user_hosts": [],
|
||||
"top_domains": [],
|
||||
"sample_hits": [],
|
||||
"note": "No suitable browser/web-history datasets found in current scope.",
|
||||
}
|
||||
|
||||
result = await scanner.scan(
|
||||
dataset_ids=dataset_ids,
|
||||
theme_ids=theme_ids or None,
|
||||
scan_hunts=False,
|
||||
scan_annotations=False,
|
||||
scan_messages=False,
|
||||
)
|
||||
hits = result.get("hits", [])
|
||||
|
||||
user_host_counter = Counter()
|
||||
domain_counter = Counter()
|
||||
|
||||
for h in hits:
|
||||
user = h.get("username") or "(unknown-user)"
|
||||
host = h.get("hostname") or "(unknown-host)"
|
||||
user_host_counter[f"{user}|{host}"] += 1
|
||||
|
||||
dom = _extract_domain(h.get("matched_value"))
|
||||
if dom:
|
||||
domain_counter[dom] += 1
|
||||
|
||||
top_user_hosts = [
|
||||
{"user_host": k, "count": v}
|
||||
for k, v in user_host_counter.most_common(10)
|
||||
]
|
||||
top_domains = [
|
||||
{"domain": k, "count": v}
|
||||
for k, v in domain_counter.most_common(10)
|
||||
]
|
||||
|
||||
return {
|
||||
"mode": "policy_scan",
|
||||
"themes": theme_names,
|
||||
"datasets_scanned": len(dataset_ids),
|
||||
"dataset_names": [d.name for d in selected],
|
||||
"total_hits": int(result.get("total_hits", 0)),
|
||||
"policy_hits": int(result.get("total_hits", 0)),
|
||||
"rows_scanned": int(result.get("rows_scanned", 0)),
|
||||
"top_user_hosts": top_user_hosts,
|
||||
"top_domains": top_domains,
|
||||
"sample_hits": hits[:20],
|
||||
}
|
||||
|
||||
|
||||
# Routes
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -84,6 +252,76 @@ async def agent_assist(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> AssistResponseModel:
|
||||
try:
|
||||
# Deterministic execution mode for policy-domain investigations.
|
||||
if _should_execute_policy_scan(request):
|
||||
t0 = time.monotonic()
|
||||
exec_payload = await _run_policy_domain_execution(request, db)
|
||||
latency_ms = int((time.monotonic() - t0) * 1000)
|
||||
|
||||
policy_hits = exec_payload.get("policy_hits", 0)
|
||||
datasets_scanned = exec_payload.get("datasets_scanned", 0)
|
||||
|
||||
if policy_hits > 0:
|
||||
guidance = (
|
||||
f"Policy-violation scan complete: {policy_hits} hits across "
|
||||
f"{datasets_scanned} dataset(s). Top user/host pairs and domains are included "
|
||||
f"in execution results for triage."
|
||||
)
|
||||
confidence = 0.95
|
||||
caveats = "Keyword-based matching can include false positives; validate with full URL context."
|
||||
else:
|
||||
guidance = (
|
||||
f"No policy-violation hits found in current scope "
|
||||
f"({datasets_scanned} dataset(s) scanned)."
|
||||
)
|
||||
confidence = 0.9
|
||||
caveats = exec_payload.get("note") or "Try expanding scope to additional hunts/datasets."
|
||||
|
||||
response = AssistResponseModel(
|
||||
guidance=guidance,
|
||||
confidence=confidence,
|
||||
suggested_pivots=["username", "hostname", "domain", "dataset_name"],
|
||||
suggested_filters=[
|
||||
"theme_name in ['Adult Content','Gambling','Downloads / Piracy']",
|
||||
"username != null",
|
||||
"hostname != null",
|
||||
],
|
||||
caveats=caveats,
|
||||
reasoning=(
|
||||
"Intent matched policy-domain investigation; executed local keyword scan pipeline."
|
||||
if _is_policy_domain_query(request.query)
|
||||
else "Execution mode was forced by user preference; ran policy-domain scan pipeline."
|
||||
),
|
||||
sans_references=["SANS FOR508", "SANS SEC504"],
|
||||
model_used="execution:keyword_scanner",
|
||||
node_used="local",
|
||||
latency_ms=latency_ms,
|
||||
execution=exec_payload,
|
||||
)
|
||||
|
||||
conv_id = request.conversation_id
|
||||
if conv_id or request.hunt_id:
|
||||
conv_id = await _persist_conversation(
|
||||
db,
|
||||
conv_id,
|
||||
request,
|
||||
AgentResponse(
|
||||
guidance=response.guidance,
|
||||
confidence=response.confidence,
|
||||
suggested_pivots=response.suggested_pivots,
|
||||
suggested_filters=response.suggested_filters,
|
||||
caveats=response.caveats,
|
||||
reasoning=response.reasoning,
|
||||
sans_references=response.sans_references,
|
||||
model_used=response.model_used,
|
||||
node_used=response.node_used,
|
||||
latency_ms=response.latency_ms,
|
||||
),
|
||||
)
|
||||
response.conversation_id = conv_id
|
||||
|
||||
return response
|
||||
|
||||
agent = get_agent()
|
||||
context = AgentContext(
|
||||
query=request.query,
|
||||
@@ -97,6 +335,7 @@ async def agent_assist(
|
||||
enrichment_summary=request.enrichment_summary,
|
||||
mode=request.mode,
|
||||
model_override=request.model_override,
|
||||
learning_mode=request.learning_mode,
|
||||
)
|
||||
|
||||
response = await agent.assist(context)
|
||||
@@ -129,6 +368,7 @@ async def agent_assist(
|
||||
}
|
||||
for p in response.perspectives
|
||||
] if response.perspectives else None,
|
||||
execution=None,
|
||||
conversation_id=conv_id,
|
||||
)
|
||||
|
||||
@@ -208,7 +448,7 @@ async def list_models():
|
||||
}
|
||||
|
||||
|
||||
# ── Conversation persistence ──────────────────────────────────────────
|
||||
# Conversation persistence
|
||||
|
||||
|
||||
async def _persist_conversation(
|
||||
@@ -263,3 +503,4 @@ async def _persist_conversation(
|
||||
await db.flush()
|
||||
|
||||
return conv.id
|
||||
|
||||
|
||||
@@ -381,6 +381,10 @@ async def submit_job(
|
||||
detail=f"Invalid job_type: {job_type}. Valid: {[t.value for t in JobType]}",
|
||||
)
|
||||
|
||||
if not job_queue.can_accept():
|
||||
raise HTTPException(status_code=429, detail="Job queue is busy. Retry shortly.")
|
||||
if not job_queue.can_accept():
|
||||
raise HTTPException(status_code=429, detail="Job queue is busy. Retry shortly.")
|
||||
job = job_queue.submit(jt, **params)
|
||||
return {"job_id": job.id, "status": job.status.value, "job_type": job_type}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""API routes for authentication — register, login, refresh, profile."""
|
||||
"""API routes for authentication — register, login, refresh, profile."""
|
||||
|
||||
import logging
|
||||
|
||||
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
@@ -57,7 +57,7 @@ class AuthResponse(BaseModel):
|
||||
tokens: TokenPair
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -86,7 +86,7 @@ async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
user = User(
|
||||
username=body.username,
|
||||
email=body.email,
|
||||
password_hash=hash_password(body.password),
|
||||
hashed_password=hash_password(body.password),
|
||||
display_name=body.display_name or body.username,
|
||||
role="analyst", # Default role
|
||||
)
|
||||
@@ -120,13 +120,13 @@ async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(User).where(User.username == body.username))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not user.password_hash:
|
||||
if not user or not user.hashed_password:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
)
|
||||
|
||||
if not verify_password(body.password, user.password_hash):
|
||||
if not verify_password(body.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
@@ -165,7 +165,7 @@ async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db)
|
||||
if token_data.type != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type — use refresh token",
|
||||
detail="Invalid token type — use refresh token",
|
||||
)
|
||||
|
||||
result = await db.execute(select(User).where(User.id == token_data.sub))
|
||||
@@ -195,3 +195,4 @@ async def get_profile(user: User = Depends(get_current_user)):
|
||||
is_active=user.is_active,
|
||||
created_at=user.created_at.isoformat() if hasattr(user.created_at, 'isoformat') else str(user.created_at),
|
||||
)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
from app.db.models import ProcessingTask
|
||||
from app.db.repositories.datasets import DatasetRepository
|
||||
from app.services.csv_parser import parse_csv_bytes, infer_column_types
|
||||
from app.services.normalizer import (
|
||||
@@ -18,15 +19,20 @@ from app.services.normalizer import (
|
||||
detect_ioc_columns,
|
||||
detect_time_range,
|
||||
)
|
||||
from app.services.artifact_classifier import classify_artifact, get_artifact_category
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from app.services.job_queue import job_queue, JobType
|
||||
from app.services.host_inventory import inventory_cache
|
||||
from app.services.scanner import keyword_scan_cache
|
||||
|
||||
router = APIRouter(prefix="/api/datasets", tags=["datasets"])
|
||||
|
||||
ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"}
|
||||
|
||||
|
||||
# ── Response models ───────────────────────────────────────────────────
|
||||
# -- Response models --
|
||||
|
||||
|
||||
class DatasetSummary(BaseModel):
|
||||
@@ -43,6 +49,8 @@ class DatasetSummary(BaseModel):
|
||||
delimiter: str | None = None
|
||||
time_range_start: str | None = None
|
||||
time_range_end: str | None = None
|
||||
artifact_type: str | None = None
|
||||
processing_status: str | None = None
|
||||
hunt_id: str | None = None
|
||||
created_at: str
|
||||
|
||||
@@ -67,10 +75,13 @@ class UploadResponse(BaseModel):
|
||||
column_types: dict
|
||||
normalized_columns: dict
|
||||
ioc_columns: dict
|
||||
artifact_type: str | None = None
|
||||
processing_status: str
|
||||
jobs_queued: list[str]
|
||||
message: str
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
# -- Routes --
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -78,7 +89,7 @@ class UploadResponse(BaseModel):
|
||||
response_model=UploadResponse,
|
||||
summary="Upload a CSV dataset",
|
||||
description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, "
|
||||
"IOCs auto-detected, and rows stored in the database.",
|
||||
"IOCs auto-detected, artifact type classified, and all processing jobs queued automatically.",
|
||||
)
|
||||
async def upload_dataset(
|
||||
file: UploadFile = File(...),
|
||||
@@ -87,7 +98,7 @@ async def upload_dataset(
|
||||
hunt_id: str | None = Query(None, description="Hunt ID to associate with"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Upload and parse a CSV dataset."""
|
||||
"""Upload and parse a CSV dataset, then trigger full processing pipeline."""
|
||||
# Validate file
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No filename provided")
|
||||
@@ -136,7 +147,12 @@ async def upload_dataset(
|
||||
# Detect time range
|
||||
time_start, time_end = detect_time_range(rows, column_mapping)
|
||||
|
||||
# Store in DB
|
||||
# Classify artifact type from column headers
|
||||
artifact_type = classify_artifact(columns)
|
||||
artifact_category = get_artifact_category(artifact_type)
|
||||
logger.info(f"Artifact classification: {artifact_type} (category: {artifact_category})")
|
||||
|
||||
# Store in DB with processing_status = "processing"
|
||||
repo = DatasetRepository(db)
|
||||
dataset = await repo.create_dataset(
|
||||
name=name or Path(file.filename).stem,
|
||||
@@ -152,6 +168,8 @@ async def upload_dataset(
|
||||
time_range_start=time_start,
|
||||
time_range_end=time_end,
|
||||
hunt_id=hunt_id,
|
||||
artifact_type=artifact_type,
|
||||
processing_status="processing",
|
||||
)
|
||||
|
||||
await repo.bulk_insert_rows(
|
||||
@@ -162,9 +180,88 @@ async def upload_dataset(
|
||||
|
||||
logger.info(
|
||||
f"Uploaded dataset '{dataset.name}': {len(rows)} rows, "
|
||||
f"{len(columns)} columns, {len(ioc_columns)} IOC columns detected"
|
||||
f"{len(columns)} columns, {len(ioc_columns)} IOC columns, "
|
||||
f"artifact={artifact_type}"
|
||||
)
|
||||
|
||||
# -- Queue full processing pipeline --
|
||||
jobs_queued = []
|
||||
|
||||
task_rows: list[ProcessingTask] = []
|
||||
|
||||
# 1. AI Triage (chains to HOST_PROFILE automatically on completion)
|
||||
triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
|
||||
jobs_queued.append("triage")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=triage_job.id,
|
||||
stage="triage",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 2. Anomaly detection (embedding-based outlier detection)
|
||||
anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
|
||||
jobs_queued.append("anomaly")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=anomaly_job.id,
|
||||
stage="anomaly",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 3. AUP keyword scan
|
||||
kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
|
||||
jobs_queued.append("keyword_scan")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=kw_job.id,
|
||||
stage="keyword_scan",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 4. IOC extraction
|
||||
ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
|
||||
jobs_queued.append("ioc_extract")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=ioc_job.id,
|
||||
stage="ioc_extract",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 5. Host inventory (network map) - requires hunt_id
|
||||
if hunt_id:
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
jobs_queued.append("host_inventory")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=inv_job.id,
|
||||
stage="host_inventory",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
if task_rows:
|
||||
db.add_all(task_rows)
|
||||
await db.flush()
|
||||
|
||||
logger.info(f"Queued {len(jobs_queued)} processing jobs for dataset {dataset.id}: {jobs_queued}")
|
||||
|
||||
return UploadResponse(
|
||||
id=dataset.id,
|
||||
name=dataset.name,
|
||||
@@ -173,7 +270,10 @@ async def upload_dataset(
|
||||
column_types=column_types,
|
||||
normalized_columns=column_mapping,
|
||||
ioc_columns=ioc_columns,
|
||||
message=f"Successfully uploaded {len(rows)} rows with {len(ioc_columns)} IOC columns detected",
|
||||
artifact_type=artifact_type,
|
||||
processing_status="processing",
|
||||
jobs_queued=jobs_queued,
|
||||
message=f"Successfully uploaded {len(rows)} rows. {len(jobs_queued)} processing jobs queued.",
|
||||
)
|
||||
|
||||
|
||||
@@ -208,6 +308,8 @@ async def list_datasets(
|
||||
delimiter=ds.delimiter,
|
||||
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
|
||||
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
|
||||
artifact_type=ds.artifact_type,
|
||||
processing_status=ds.processing_status,
|
||||
hunt_id=ds.hunt_id,
|
||||
created_at=ds.created_at.isoformat(),
|
||||
)
|
||||
@@ -244,6 +346,8 @@ async def get_dataset(
|
||||
delimiter=ds.delimiter,
|
||||
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
|
||||
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
|
||||
artifact_type=ds.artifact_type,
|
||||
processing_status=ds.processing_status,
|
||||
hunt_id=ds.hunt_id,
|
||||
created_at=ds.created_at.isoformat(),
|
||||
)
|
||||
@@ -292,4 +396,5 @@ async def delete_dataset(
|
||||
deleted = await repo.delete_dataset(dataset_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
keyword_scan_cache.invalidate_dataset(dataset_id)
|
||||
return {"message": "Dataset deleted", "id": dataset_id}
|
||||
|
||||
@@ -8,16 +8,15 @@ from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Hunt, Conversation, Message
|
||||
from app.db.models import Hunt, Dataset, ProcessingTask
|
||||
from app.services.job_queue import job_queue
|
||||
from app.services.host_inventory import inventory_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/hunts", tags=["hunts"])
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class HuntCreate(BaseModel):
|
||||
name: str = Field(..., max_length=256)
|
||||
description: str | None = None
|
||||
@@ -26,7 +25,7 @@ class HuntCreate(BaseModel):
|
||||
class HuntUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
status: str | None = None # active | closed | archived
|
||||
status: str | None = None
|
||||
|
||||
|
||||
class HuntResponse(BaseModel):
|
||||
@@ -46,7 +45,18 @@ class HuntListResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
class HuntProgressResponse(BaseModel):
|
||||
hunt_id: str
|
||||
status: str
|
||||
progress_percent: float
|
||||
dataset_total: int
|
||||
dataset_completed: int
|
||||
dataset_processing: int
|
||||
dataset_errors: int
|
||||
active_jobs: int
|
||||
queued_jobs: int
|
||||
network_status: str
|
||||
stages: dict
|
||||
|
||||
|
||||
@router.post("", response_model=HuntResponse, summary="Create a new hunt")
|
||||
@@ -122,6 +132,125 @@ async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{hunt_id}/progress", response_model=HuntProgressResponse, summary="Get hunt processing progress")
|
||||
async def get_hunt_progress(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
hunt = await db.get(Hunt, hunt_id)
|
||||
if not hunt:
|
||||
raise HTTPException(status_code=404, detail="Hunt not found")
|
||||
|
||||
ds_rows = await db.execute(
|
||||
select(Dataset.id, Dataset.processing_status)
|
||||
.where(Dataset.hunt_id == hunt_id)
|
||||
)
|
||||
datasets = ds_rows.all()
|
||||
dataset_ids = {row[0] for row in datasets}
|
||||
|
||||
dataset_total = len(datasets)
|
||||
dataset_completed = sum(1 for _, st in datasets if st == "completed")
|
||||
dataset_errors = sum(1 for _, st in datasets if st == "completed_with_errors")
|
||||
dataset_processing = max(0, dataset_total - dataset_completed - dataset_errors)
|
||||
|
||||
jobs = job_queue.list_jobs(limit=5000)
|
||||
relevant_jobs = [
|
||||
j for j in jobs
|
||||
if j.get("params", {}).get("hunt_id") == hunt_id
|
||||
or j.get("params", {}).get("dataset_id") in dataset_ids
|
||||
]
|
||||
active_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "running")
|
||||
queued_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "queued")
|
||||
|
||||
task_rows = await db.execute(
|
||||
select(ProcessingTask.stage, ProcessingTask.status, ProcessingTask.progress)
|
||||
.where(ProcessingTask.hunt_id == hunt_id)
|
||||
)
|
||||
tasks = task_rows.all()
|
||||
|
||||
task_total = len(tasks)
|
||||
task_done = sum(1 for _, st, _ in tasks if st in ("completed", "failed", "cancelled"))
|
||||
task_running = sum(1 for _, st, _ in tasks if st == "running")
|
||||
task_queued = sum(1 for _, st, _ in tasks if st == "queued")
|
||||
task_ratio = (task_done / task_total) if task_total > 0 else None
|
||||
|
||||
active_jobs = max(active_jobs_mem, task_running)
|
||||
queued_jobs = max(queued_jobs_mem, task_queued)
|
||||
|
||||
stage_rollup: dict[str, dict] = {}
|
||||
for stage, status, progress in tasks:
|
||||
bucket = stage_rollup.setdefault(stage, {"total": 0, "done": 0, "running": 0, "queued": 0, "progress_sum": 0.0})
|
||||
bucket["total"] += 1
|
||||
if status in ("completed", "failed", "cancelled"):
|
||||
bucket["done"] += 1
|
||||
elif status == "running":
|
||||
bucket["running"] += 1
|
||||
elif status == "queued":
|
||||
bucket["queued"] += 1
|
||||
bucket["progress_sum"] += float(progress or 0.0)
|
||||
|
||||
for stage_name, bucket in stage_rollup.items():
|
||||
total = max(1, bucket["total"])
|
||||
bucket["percent"] = round(bucket["progress_sum"] / total, 1)
|
||||
|
||||
if inventory_cache.get(hunt_id) is not None:
|
||||
network_status = "ready"
|
||||
network_ratio = 1.0
|
||||
elif inventory_cache.is_building(hunt_id):
|
||||
network_status = "building"
|
||||
network_ratio = 0.5
|
||||
else:
|
||||
network_status = "none"
|
||||
network_ratio = 0.0
|
||||
|
||||
dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
|
||||
if task_ratio is None:
|
||||
overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
|
||||
else:
|
||||
overall_ratio = min(1.0, (dataset_ratio * 0.50) + (task_ratio * 0.35) + (network_ratio * 0.15))
|
||||
progress_percent = round(overall_ratio * 100.0, 1)
|
||||
|
||||
status = "ready"
|
||||
if dataset_total == 0:
|
||||
status = "idle"
|
||||
elif progress_percent < 100:
|
||||
status = "processing"
|
||||
|
||||
stages = {
|
||||
"datasets": {
|
||||
"total": dataset_total,
|
||||
"completed": dataset_completed,
|
||||
"processing": dataset_processing,
|
||||
"errors": dataset_errors,
|
||||
"percent": round(dataset_ratio * 100.0, 1),
|
||||
},
|
||||
"network": {
|
||||
"status": network_status,
|
||||
"percent": round(network_ratio * 100.0, 1),
|
||||
},
|
||||
"jobs": {
|
||||
"active": active_jobs,
|
||||
"queued": queued_jobs,
|
||||
"total_seen": len(relevant_jobs),
|
||||
"task_total": task_total,
|
||||
"task_done": task_done,
|
||||
"task_percent": round((task_ratio or 0.0) * 100.0, 1) if task_total else None,
|
||||
},
|
||||
"task_stages": stage_rollup,
|
||||
}
|
||||
|
||||
return HuntProgressResponse(
|
||||
hunt_id=hunt_id,
|
||||
status=status,
|
||||
progress_percent=progress_percent,
|
||||
dataset_total=dataset_total,
|
||||
dataset_completed=dataset_completed,
|
||||
dataset_processing=dataset_processing,
|
||||
dataset_errors=dataset_errors,
|
||||
active_jobs=active_jobs,
|
||||
queued_jobs=queued_jobs,
|
||||
network_status=network_status,
|
||||
stages=stages,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
|
||||
async def update_hunt(
|
||||
hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)
|
||||
|
||||
@@ -1,25 +1,21 @@
|
||||
"""API routes for AUP keyword themes, keyword CRUD, and scanning."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import KeywordTheme, Keyword
|
||||
from app.services.scanner import KeywordScanner
|
||||
from app.services.scanner import KeywordScanner, keyword_scan_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/keywords", tags=["keywords"])
|
||||
|
||||
|
||||
# ── Pydantic schemas ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ThemeCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=128)
|
||||
color: str = Field(default="#9e9e9e", max_length=16)
|
||||
@@ -67,23 +63,27 @@ class KeywordBulkCreate(BaseModel):
|
||||
|
||||
|
||||
class ScanRequest(BaseModel):
|
||||
dataset_ids: list[str] | None = None # None → all datasets
|
||||
theme_ids: list[str] | None = None # None → all enabled themes
|
||||
scan_hunts: bool = True
|
||||
scan_annotations: bool = True
|
||||
scan_messages: bool = True
|
||||
dataset_ids: list[str] | None = None
|
||||
theme_ids: list[str] | None = None
|
||||
scan_hunts: bool = False
|
||||
scan_annotations: bool = False
|
||||
scan_messages: bool = False
|
||||
prefer_cache: bool = True
|
||||
force_rescan: bool = False
|
||||
|
||||
|
||||
class ScanHit(BaseModel):
|
||||
theme_name: str
|
||||
theme_color: str
|
||||
keyword: str
|
||||
source_type: str # dataset_row | hunt | annotation | message
|
||||
source_type: str
|
||||
source_id: str | int
|
||||
field: str
|
||||
matched_value: str
|
||||
row_index: int | None = None
|
||||
dataset_name: str | None = None
|
||||
hostname: str | None = None
|
||||
username: str | None = None
|
||||
|
||||
|
||||
class ScanResponse(BaseModel):
|
||||
@@ -92,9 +92,9 @@ class ScanResponse(BaseModel):
|
||||
themes_scanned: int
|
||||
keywords_scanned: int
|
||||
rows_scanned: int
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
cache_used: bool = False
|
||||
cache_status: str = "miss"
|
||||
cached_at: str | None = None
|
||||
|
||||
|
||||
def _theme_to_out(t: KeywordTheme) -> ThemeOut:
|
||||
@@ -119,49 +119,58 @@ def _theme_to_out(t: KeywordTheme) -> ThemeOut:
|
||||
)
|
||||
|
||||
|
||||
# ── Theme CRUD ────────────────────────────────────────────────────────
|
||||
def _merge_cached_results(entries: list[dict], allowed_theme_names: set[str] | None = None) -> dict:
|
||||
hits: list[dict] = []
|
||||
total_rows = 0
|
||||
cached_at: str | None = None
|
||||
|
||||
for entry in entries:
|
||||
result = entry["result"]
|
||||
total_rows += int(result.get("rows_scanned", 0) or 0)
|
||||
if entry.get("built_at"):
|
||||
if not cached_at or entry["built_at"] > cached_at:
|
||||
cached_at = entry["built_at"]
|
||||
for h in result.get("hits", []):
|
||||
if allowed_theme_names is not None and h.get("theme_name") not in allowed_theme_names:
|
||||
continue
|
||||
hits.append(h)
|
||||
|
||||
return {
|
||||
"total_hits": len(hits),
|
||||
"hits": hits,
|
||||
"rows_scanned": total_rows,
|
||||
"cached_at": cached_at,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/themes", response_model=ThemeListResponse)
|
||||
async def list_themes(db: AsyncSession = Depends(get_db)):
|
||||
"""List all keyword themes with their keywords."""
|
||||
result = await db.execute(
|
||||
select(KeywordTheme).order_by(KeywordTheme.name)
|
||||
)
|
||||
result = await db.execute(select(KeywordTheme).order_by(KeywordTheme.name))
|
||||
themes = result.scalars().all()
|
||||
return ThemeListResponse(
|
||||
themes=[_theme_to_out(t) for t in themes],
|
||||
total=len(themes),
|
||||
)
|
||||
return ThemeListResponse(themes=[_theme_to_out(t) for t in themes], total=len(themes))
|
||||
|
||||
|
||||
@router.post("/themes", response_model=ThemeOut, status_code=201)
|
||||
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""Create a new keyword theme."""
|
||||
exists = await db.scalar(
|
||||
select(KeywordTheme.id).where(KeywordTheme.name == body.name)
|
||||
)
|
||||
exists = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == body.name))
|
||||
if exists:
|
||||
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
||||
theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
|
||||
db.add(theme)
|
||||
await db.flush()
|
||||
await db.refresh(theme)
|
||||
keyword_scan_cache.clear()
|
||||
return _theme_to_out(theme)
|
||||
|
||||
|
||||
@router.put("/themes/{theme_id}", response_model=ThemeOut)
|
||||
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
|
||||
"""Update theme name, color, or enabled status."""
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
if body.name is not None:
|
||||
# check uniqueness
|
||||
dup = await db.scalar(
|
||||
select(KeywordTheme.id).where(
|
||||
KeywordTheme.name == body.name, KeywordTheme.id != theme_id
|
||||
)
|
||||
select(KeywordTheme.id).where(KeywordTheme.name == body.name, KeywordTheme.id != theme_id)
|
||||
)
|
||||
if dup:
|
||||
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
||||
@@ -172,24 +181,21 @@ async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depe
|
||||
theme.enabled = body.enabled
|
||||
await db.flush()
|
||||
await db.refresh(theme)
|
||||
keyword_scan_cache.clear()
|
||||
return _theme_to_out(theme)
|
||||
|
||||
|
||||
@router.delete("/themes/{theme_id}", status_code=204)
|
||||
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete a theme and all its keywords."""
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
await db.delete(theme)
|
||||
|
||||
|
||||
# ── Keyword CRUD ──────────────────────────────────────────────────────
|
||||
keyword_scan_cache.clear()
|
||||
|
||||
|
||||
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
|
||||
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""Add a single keyword to a theme."""
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
@@ -197,6 +203,7 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
|
||||
db.add(kw)
|
||||
await db.flush()
|
||||
await db.refresh(kw)
|
||||
keyword_scan_cache.clear()
|
||||
return KeywordOut(
|
||||
id=kw.id, theme_id=kw.theme_id, value=kw.value,
|
||||
is_regex=kw.is_regex, created_at=kw.created_at.isoformat(),
|
||||
@@ -205,7 +212,6 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
|
||||
|
||||
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
|
||||
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""Add multiple keywords to a theme at once."""
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
@@ -217,25 +223,88 @@ async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSes
|
||||
db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex))
|
||||
added += 1
|
||||
await db.flush()
|
||||
keyword_scan_cache.clear()
|
||||
return {"added": added, "theme_id": theme_id}
|
||||
|
||||
|
||||
@router.delete("/keywords/{keyword_id}", status_code=204)
|
||||
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete a single keyword."""
|
||||
kw = await db.get(Keyword, keyword_id)
|
||||
if not kw:
|
||||
raise HTTPException(404, "Keyword not found")
|
||||
await db.delete(kw)
|
||||
|
||||
|
||||
# ── Scan endpoints ────────────────────────────────────────────────────
|
||||
keyword_scan_cache.clear()
|
||||
|
||||
|
||||
@router.post("/scan", response_model=ScanResponse)
|
||||
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""Run AUP keyword scan across selected data sources."""
|
||||
scanner = KeywordScanner(db)
|
||||
|
||||
if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
|
||||
return {
|
||||
"total_hits": 0,
|
||||
"hits": [],
|
||||
"themes_scanned": 0,
|
||||
"keywords_scanned": 0,
|
||||
"rows_scanned": 0,
|
||||
"cache_used": False,
|
||||
"cache_status": "miss",
|
||||
"cached_at": None,
|
||||
}
|
||||
|
||||
can_use_cache = (
|
||||
body.prefer_cache
|
||||
and not body.force_rescan
|
||||
and bool(body.dataset_ids)
|
||||
and not body.scan_hunts
|
||||
and not body.scan_annotations
|
||||
and not body.scan_messages
|
||||
)
|
||||
|
||||
if can_use_cache:
|
||||
themes = await scanner._load_themes(body.theme_ids)
|
||||
allowed_theme_names = {t.name for t in themes}
|
||||
keywords_scanned = sum(len(theme.keywords) for theme in themes)
|
||||
|
||||
cached_entries: list[dict] = []
|
||||
missing: list[str] = []
|
||||
for dataset_id in (body.dataset_ids or []):
|
||||
entry = keyword_scan_cache.get(dataset_id)
|
||||
if not entry:
|
||||
missing.append(dataset_id)
|
||||
continue
|
||||
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
|
||||
|
||||
if not missing and cached_entries:
|
||||
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
|
||||
return {
|
||||
"total_hits": merged["total_hits"],
|
||||
"hits": merged["hits"],
|
||||
"themes_scanned": len(themes),
|
||||
"keywords_scanned": keywords_scanned,
|
||||
"rows_scanned": merged["rows_scanned"],
|
||||
"cache_used": True,
|
||||
"cache_status": "hit",
|
||||
"cached_at": merged["cached_at"],
|
||||
}
|
||||
|
||||
if missing:
|
||||
partial = await scanner.scan(dataset_ids=missing, theme_ids=body.theme_ids)
|
||||
merged = _merge_cached_results(
|
||||
cached_entries + [{"result": partial, "built_at": None}],
|
||||
allowed_theme_names if body.theme_ids else None,
|
||||
)
|
||||
return {
|
||||
"total_hits": merged["total_hits"],
|
||||
"hits": merged["hits"],
|
||||
"themes_scanned": len(themes),
|
||||
"keywords_scanned": keywords_scanned,
|
||||
"rows_scanned": merged["rows_scanned"],
|
||||
"cache_used": len(cached_entries) > 0,
|
||||
"cache_status": "partial" if cached_entries else "miss",
|
||||
"cached_at": merged["cached_at"],
|
||||
}
|
||||
|
||||
result = await scanner.scan(
|
||||
dataset_ids=body.dataset_ids,
|
||||
theme_ids=body.theme_ids,
|
||||
@@ -243,7 +312,13 @@ async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
||||
scan_annotations=body.scan_annotations,
|
||||
scan_messages=body.scan_messages,
|
||||
)
|
||||
return result
|
||||
|
||||
return {
|
||||
**result,
|
||||
"cache_used": False,
|
||||
"cache_status": "miss",
|
||||
"cached_at": None,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/scan/quick", response_model=ScanResponse)
|
||||
@@ -251,7 +326,22 @@ async def quick_scan(
|
||||
dataset_id: str = Query(..., description="Dataset to scan"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Quick scan a single dataset with all enabled themes."""
|
||||
entry = keyword_scan_cache.get(dataset_id)
|
||||
if entry is not None:
|
||||
result = entry.result
|
||||
return {
|
||||
**result,
|
||||
"cache_used": True,
|
||||
"cache_status": "hit",
|
||||
"cached_at": entry.built_at,
|
||||
}
|
||||
|
||||
scanner = KeywordScanner(db)
|
||||
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||
return result
|
||||
keyword_scan_cache.put(dataset_id, result)
|
||||
return {
|
||||
**result,
|
||||
"cache_used": False,
|
||||
"cache_status": "miss",
|
||||
"cached_at": None,
|
||||
}
|
||||
|
||||
146
backend/app/api/routes/mitre.py
Normal file
146
backend/app/api/routes/mitre.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""API routes for MITRE ATT&CK coverage visualization."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import (
|
||||
TriageResult, HostProfile, Hypothesis, HuntReport, Dataset, Hunt
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/mitre", tags=["mitre"])
|
||||
|
||||
# Canonical MITRE ATT&CK tactics in kill-chain order
|
||||
TACTICS = [
|
||||
"Reconnaissance", "Resource Development", "Initial Access",
|
||||
"Execution", "Persistence", "Privilege Escalation",
|
||||
"Defense Evasion", "Credential Access", "Discovery",
|
||||
"Lateral Movement", "Collection", "Command and Control",
|
||||
"Exfiltration", "Impact",
|
||||
]
|
||||
|
||||
# Simplified technique-to-tactic mapping (top techniques)
|
||||
TECHNIQUE_TACTIC: dict[str, str] = {
|
||||
"T1059": "Execution", "T1059.001": "Execution", "T1059.003": "Execution",
|
||||
"T1059.005": "Execution", "T1059.006": "Execution", "T1059.007": "Execution",
|
||||
"T1053": "Persistence", "T1053.005": "Persistence",
|
||||
"T1547": "Persistence", "T1547.001": "Persistence",
|
||||
"T1543": "Persistence", "T1543.003": "Persistence",
|
||||
"T1078": "Privilege Escalation", "T1078.001": "Privilege Escalation",
|
||||
"T1078.002": "Privilege Escalation", "T1078.003": "Privilege Escalation",
|
||||
"T1055": "Privilege Escalation", "T1055.001": "Privilege Escalation",
|
||||
"T1548": "Privilege Escalation", "T1548.002": "Privilege Escalation",
|
||||
"T1070": "Defense Evasion", "T1070.001": "Defense Evasion",
|
||||
"T1070.004": "Defense Evasion",
|
||||
"T1036": "Defense Evasion", "T1036.005": "Defense Evasion",
|
||||
"T1027": "Defense Evasion", "T1140": "Defense Evasion",
|
||||
"T1218": "Defense Evasion", "T1218.011": "Defense Evasion",
|
||||
"T1003": "Credential Access", "T1003.001": "Credential Access",
|
||||
"T1110": "Credential Access", "T1558": "Credential Access",
|
||||
"T1087": "Discovery", "T1087.001": "Discovery", "T1087.002": "Discovery",
|
||||
"T1082": "Discovery", "T1083": "Discovery", "T1057": "Discovery",
|
||||
"T1018": "Discovery", "T1049": "Discovery", "T1016": "Discovery",
|
||||
"T1021": "Lateral Movement", "T1021.001": "Lateral Movement",
|
||||
"T1021.002": "Lateral Movement", "T1021.006": "Lateral Movement",
|
||||
"T1570": "Lateral Movement",
|
||||
"T1560": "Collection", "T1074": "Collection", "T1005": "Collection",
|
||||
"T1071": "Command and Control", "T1071.001": "Command and Control",
|
||||
"T1105": "Command and Control", "T1572": "Command and Control",
|
||||
"T1095": "Command and Control",
|
||||
"T1048": "Exfiltration", "T1041": "Exfiltration",
|
||||
"T1486": "Impact", "T1490": "Impact", "T1489": "Impact",
|
||||
"T1566": "Initial Access", "T1566.001": "Initial Access",
|
||||
"T1566.002": "Initial Access",
|
||||
"T1190": "Initial Access", "T1133": "Initial Access",
|
||||
"T1195": "Initial Access", "T1195.002": "Initial Access",
|
||||
}
|
||||
|
||||
|
||||
def _get_tactic(technique_id: str) -> str:
|
||||
"""Map a technique ID to its tactic."""
|
||||
tech = technique_id.strip().upper()
|
||||
if tech in TECHNIQUE_TACTIC:
|
||||
return TECHNIQUE_TACTIC[tech]
|
||||
# Try parent technique
|
||||
if "." in tech:
|
||||
parent = tech.split(".")[0]
|
||||
if parent in TECHNIQUE_TACTIC:
|
||||
return TECHNIQUE_TACTIC[parent]
|
||||
return "Unknown"
|
||||
|
||||
|
||||
@router.get("/coverage")
|
||||
async def get_mitre_coverage(
|
||||
hunt_id: str | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Aggregate all MITRE techniques from triage, host profiles, hypotheses, and reports."""
|
||||
|
||||
techniques: dict[str, dict] = {}
|
||||
|
||||
# Collect from triage results
|
||||
triage_q = select(TriageResult)
|
||||
if hunt_id:
|
||||
triage_q = triage_q.join(Dataset).where(Dataset.hunt_id == hunt_id)
|
||||
result = await db.execute(triage_q.limit(500))
|
||||
for t in result.scalars().all():
|
||||
for tech in (t.mitre_techniques or []):
|
||||
if tech not in techniques:
|
||||
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
|
||||
techniques[tech]["count"] += 1
|
||||
techniques[tech]["sources"].append({"type": "triage", "risk_score": t.risk_score})
|
||||
|
||||
# Collect from host profiles
|
||||
profile_q = select(HostProfile)
|
||||
if hunt_id:
|
||||
profile_q = profile_q.where(HostProfile.hunt_id == hunt_id)
|
||||
result = await db.execute(profile_q.limit(200))
|
||||
for p in result.scalars().all():
|
||||
for tech in (p.mitre_techniques or []):
|
||||
if tech not in techniques:
|
||||
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
|
||||
techniques[tech]["count"] += 1
|
||||
techniques[tech]["sources"].append({"type": "host_profile", "hostname": p.hostname})
|
||||
|
||||
# Collect from hypotheses
|
||||
hyp_q = select(Hypothesis)
|
||||
if hunt_id:
|
||||
hyp_q = hyp_q.where(Hypothesis.hunt_id == hunt_id)
|
||||
result = await db.execute(hyp_q.limit(200))
|
||||
for h in result.scalars().all():
|
||||
tech = h.mitre_technique
|
||||
if tech:
|
||||
if tech not in techniques:
|
||||
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
|
||||
techniques[tech]["count"] += 1
|
||||
techniques[tech]["sources"].append({"type": "hypothesis", "title": h.title})
|
||||
|
||||
# Build tactic-grouped response
|
||||
tactic_groups: dict[str, list] = {t: [] for t in TACTICS}
|
||||
tactic_groups["Unknown"] = []
|
||||
for tech in techniques.values():
|
||||
tactic = tech["tactic"]
|
||||
if tactic not in tactic_groups:
|
||||
tactic_groups[tactic] = []
|
||||
tactic_groups[tactic].append(tech)
|
||||
|
||||
total_techniques = len(techniques)
|
||||
total_detections = sum(t["count"] for t in techniques.values())
|
||||
|
||||
return {
|
||||
"tactics": TACTICS,
|
||||
"technique_count": total_techniques,
|
||||
"detection_count": total_detections,
|
||||
"tactic_coverage": {
|
||||
t: {"techniques": techs, "count": len(techs)}
|
||||
for t, techs in tactic_groups.items()
|
||||
if techs
|
||||
},
|
||||
"all_techniques": list(techniques.values()),
|
||||
}
|
||||
@@ -1,12 +1,15 @@
|
||||
"""Network topology API - host inventory endpoint."""
|
||||
"""Network topology API - host inventory endpoint with background caching."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
from app.services.host_inventory import build_host_inventory
|
||||
from app.services.host_inventory import build_host_inventory, inventory_cache
|
||||
from app.services.job_queue import job_queue, JobType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/network", tags=["network"])
|
||||
@@ -15,14 +18,158 @@ router = APIRouter(prefix="/api/network", tags=["network"])
|
||||
@router.get("/host-inventory")
|
||||
async def get_host_inventory(
|
||||
hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
|
||||
force: bool = Query(False, description="Force rebuild, ignoring cache"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Build a deduplicated host inventory from all datasets in a hunt.
|
||||
"""Return a deduplicated host inventory for the hunt.
|
||||
|
||||
Returns unique hosts with hostname, IPs, OS, logged-in users, and
|
||||
network connections derived from netstat/connection data.
|
||||
Returns instantly from cache if available (pre-built after upload or on startup).
|
||||
If cache is cold, triggers a background build and returns 202 so the
|
||||
frontend can poll /inventory-status and re-request when ready.
|
||||
"""
|
||||
result = await build_host_inventory(hunt_id, db)
|
||||
if result["stats"]["total_hosts"] == 0:
|
||||
return result
|
||||
return result
|
||||
# Force rebuild: invalidate cache, queue background job, return 202
|
||||
if force:
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "building", "message": "Rebuild queued"},
|
||||
)
|
||||
|
||||
# Try cache first
|
||||
cached = inventory_cache.get(hunt_id)
|
||||
if cached is not None:
|
||||
logger.info(f"Serving cached host inventory for {hunt_id}")
|
||||
return cached
|
||||
|
||||
# Cache miss: trigger background build instead of blocking for 90+ seconds
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
logger.info(f"Cache miss for {hunt_id}, triggering background build")
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "building", "message": "Inventory is being built in the background"},
|
||||
)
|
||||
|
||||
|
||||
def _build_summary(inv: dict, top_n: int = 20) -> dict:
|
||||
hosts = inv.get("hosts", [])
|
||||
conns = inv.get("connections", [])
|
||||
top_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:top_n]
|
||||
top_edges = sorted(conns, key=lambda c: c.get("count", 0), reverse=True)[:top_n]
|
||||
return {
|
||||
"stats": inv.get("stats", {}),
|
||||
"top_hosts": [
|
||||
{
|
||||
"id": h.get("id"),
|
||||
"hostname": h.get("hostname"),
|
||||
"row_count": h.get("row_count", 0),
|
||||
"ip_count": len(h.get("ips", [])),
|
||||
"user_count": len(h.get("users", [])),
|
||||
}
|
||||
for h in top_hosts
|
||||
],
|
||||
"top_edges": top_edges,
|
||||
}
|
||||
|
||||
|
||||
def _build_subgraph(inv: dict, node_id: str | None, max_hosts: int, max_edges: int) -> dict:
|
||||
hosts = inv.get("hosts", [])
|
||||
conns = inv.get("connections", [])
|
||||
|
||||
max_hosts = max(1, min(max_hosts, settings.NETWORK_SUBGRAPH_MAX_HOSTS))
|
||||
max_edges = max(1, min(max_edges, settings.NETWORK_SUBGRAPH_MAX_EDGES))
|
||||
|
||||
if node_id:
|
||||
rel_edges = [c for c in conns if c.get("source") == node_id or c.get("target") == node_id]
|
||||
rel_edges = sorted(rel_edges, key=lambda c: c.get("count", 0), reverse=True)[:max_edges]
|
||||
ids = {node_id}
|
||||
for c in rel_edges:
|
||||
ids.add(c.get("source"))
|
||||
ids.add(c.get("target"))
|
||||
rel_hosts = [h for h in hosts if h.get("id") in ids][:max_hosts]
|
||||
else:
|
||||
rel_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:max_hosts]
|
||||
allowed = {h.get("id") for h in rel_hosts}
|
||||
rel_edges = [
|
||||
c for c in sorted(conns, key=lambda c: c.get("count", 0), reverse=True)
|
||||
if c.get("source") in allowed and c.get("target") in allowed
|
||||
][:max_edges]
|
||||
|
||||
return {
|
||||
"hosts": rel_hosts,
|
||||
"connections": rel_edges,
|
||||
"stats": {
|
||||
**inv.get("stats", {}),
|
||||
"subgraph_hosts": len(rel_hosts),
|
||||
"subgraph_connections": len(rel_edges),
|
||||
"truncated": len(rel_hosts) < len(hosts) or len(rel_edges) < len(conns),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_inventory_summary(
|
||||
hunt_id: str = Query(..., description="Hunt ID"),
|
||||
top_n: int = Query(20, ge=1, le=200),
|
||||
):
|
||||
"""Return a lightweight summary view for large hunts."""
|
||||
cached = inventory_cache.get(hunt_id)
|
||||
if cached is None:
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "deferred", "message": "Queue busy, retry shortly"},
|
||||
)
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return JSONResponse(status_code=202, content={"status": "building"})
|
||||
return _build_summary(cached, top_n=top_n)
|
||||
|
||||
|
||||
@router.get("/subgraph")
|
||||
async def get_inventory_subgraph(
|
||||
hunt_id: str = Query(..., description="Hunt ID"),
|
||||
node_id: str | None = Query(None, description="Optional focal node"),
|
||||
max_hosts: int = Query(200, ge=1, le=5000),
|
||||
max_edges: int = Query(1500, ge=1, le=20000),
|
||||
):
|
||||
"""Return a bounded subgraph for scale-safe rendering."""
|
||||
cached = inventory_cache.get(hunt_id)
|
||||
if cached is None:
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "deferred", "message": "Queue busy, retry shortly"},
|
||||
)
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return JSONResponse(status_code=202, content={"status": "building"})
|
||||
return _build_subgraph(cached, node_id=node_id, max_hosts=max_hosts, max_edges=max_edges)
|
||||
|
||||
|
||||
@router.get("/inventory-status")
|
||||
async def get_inventory_status(
|
||||
hunt_id: str = Query(..., description="Hunt ID to check"),
|
||||
):
|
||||
"""Check whether pre-computed host inventory is ready for a hunt.
|
||||
|
||||
Returns: { status: "ready" | "building" | "none" }
|
||||
"""
|
||||
return {"hunt_id": hunt_id, "status": inventory_cache.status(hunt_id)}
|
||||
|
||||
|
||||
@router.post("/rebuild-inventory")
|
||||
async def trigger_rebuild(
|
||||
hunt_id: str = Query(..., description="Hunt to rebuild inventory for"),
|
||||
):
|
||||
"""Trigger a background rebuild of the host inventory cache."""
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return {"job_id": job.id, "status": "queued"}
|
||||
|
||||
217
backend/app/api/routes/playbooks.py
Normal file
217
backend/app/api/routes/playbooks.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""API routes for investigation playbooks."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Playbook, PlaybookStep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/playbooks", tags=["playbooks"])
|
||||
|
||||
|
||||
# -- Request / Response schemas ---
|
||||
|
||||
class StepCreate(BaseModel):
|
||||
title: str
|
||||
description: str | None = None
|
||||
step_type: str = "manual"
|
||||
target_route: str | None = None
|
||||
|
||||
class PlaybookCreate(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
hunt_id: str | None = None
|
||||
is_template: bool = False
|
||||
steps: list[StepCreate] = []
|
||||
|
||||
class PlaybookUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
status: str | None = None
|
||||
|
||||
class StepUpdate(BaseModel):
|
||||
is_completed: bool | None = None
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
# -- Default investigation templates ---
|
||||
|
||||
DEFAULT_TEMPLATES = [
|
||||
{
|
||||
"name": "Standard Threat Hunt",
|
||||
"description": "Step-by-step investigation workflow for a typical threat hunting engagement.",
|
||||
"steps": [
|
||||
{"title": "Upload Artifacts", "description": "Import CSV exports from Velociraptor or other tools", "step_type": "upload", "target_route": "/upload"},
|
||||
{"title": "Create Hunt", "description": "Create a new hunt and associate uploaded datasets", "step_type": "action", "target_route": "/hunts"},
|
||||
{"title": "AUP Keyword Scan", "description": "Run AUP keyword scanner for policy violations", "step_type": "analysis", "target_route": "/aup"},
|
||||
{"title": "Auto-Triage", "description": "Trigger AI triage on all datasets", "step_type": "analysis", "target_route": "/analysis"},
|
||||
{"title": "Review Triage Results", "description": "Review flagged rows and risk scores", "step_type": "review", "target_route": "/analysis"},
|
||||
{"title": "Enrich IOCs", "description": "Enrich flagged IPs, hashes, and domains via external sources", "step_type": "analysis", "target_route": "/enrichment"},
|
||||
{"title": "Host Profiling", "description": "Generate deep host profiles for suspicious hosts", "step_type": "analysis", "target_route": "/analysis"},
|
||||
{"title": "Cross-Hunt Correlation", "description": "Identify shared IOCs and patterns across hunts", "step_type": "analysis", "target_route": "/correlation"},
|
||||
{"title": "Document Hypotheses", "description": "Record investigation hypotheses with MITRE mappings", "step_type": "action", "target_route": "/hypotheses"},
|
||||
{"title": "Generate Report", "description": "Generate final AI-assisted hunt report", "step_type": "action", "target_route": "/analysis"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Incident Response Triage",
|
||||
"description": "Fast-track workflow for active incident response.",
|
||||
"steps": [
|
||||
{"title": "Upload Artifacts", "description": "Import forensic data from affected hosts", "step_type": "upload", "target_route": "/upload"},
|
||||
{"title": "Auto-Triage", "description": "Immediate AI triage for threat indicators", "step_type": "analysis", "target_route": "/analysis"},
|
||||
{"title": "IOC Extraction", "description": "Extract all IOCs from flagged data", "step_type": "analysis", "target_route": "/analysis"},
|
||||
{"title": "Enrich Critical IOCs", "description": "Priority enrichment of high-risk indicators", "step_type": "analysis", "target_route": "/enrichment"},
|
||||
{"title": "Network Map", "description": "Visualize host connections and lateral movement", "step_type": "review", "target_route": "/network"},
|
||||
{"title": "Generate Situation Report", "description": "Create executive summary for incident command", "step_type": "action", "target_route": "/analysis"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# -- Routes ---
|
||||
|
||||
@router.get("")
|
||||
async def list_playbooks(
|
||||
include_templates: bool = True,
|
||||
hunt_id: str | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
q = select(Playbook)
|
||||
if hunt_id:
|
||||
q = q.where(Playbook.hunt_id == hunt_id)
|
||||
if not include_templates:
|
||||
q = q.where(Playbook.is_template == False)
|
||||
q = q.order_by(Playbook.created_at.desc())
|
||||
result = await db.execute(q.limit(100))
|
||||
playbooks = result.scalars().all()
|
||||
|
||||
return {"playbooks": [
|
||||
{
|
||||
"id": p.id, "name": p.name, "description": p.description,
|
||||
"is_template": p.is_template, "hunt_id": p.hunt_id,
|
||||
"status": p.status,
|
||||
"total_steps": len(p.steps),
|
||||
"completed_steps": sum(1 for s in p.steps if s.is_completed),
|
||||
"created_at": p.created_at.isoformat() if p.created_at else None,
|
||||
}
|
||||
for p in playbooks
|
||||
]}
|
||||
|
||||
|
||||
@router.get("/templates")
|
||||
async def get_templates():
|
||||
"""Return built-in investigation templates."""
|
||||
return {"templates": DEFAULT_TEMPLATES}
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
async def create_playbook(body: PlaybookCreate, db: AsyncSession = Depends(get_db)):
|
||||
pb = Playbook(
|
||||
name=body.name,
|
||||
description=body.description,
|
||||
hunt_id=body.hunt_id,
|
||||
is_template=body.is_template,
|
||||
)
|
||||
db.add(pb)
|
||||
await db.flush()
|
||||
|
||||
created_steps = []
|
||||
for i, step in enumerate(body.steps):
|
||||
s = PlaybookStep(
|
||||
playbook_id=pb.id,
|
||||
order_index=i,
|
||||
title=step.title,
|
||||
description=step.description,
|
||||
step_type=step.step_type,
|
||||
target_route=step.target_route,
|
||||
)
|
||||
db.add(s)
|
||||
created_steps.append(s)
|
||||
|
||||
await db.flush()
|
||||
|
||||
return {
|
||||
"id": pb.id, "name": pb.name, "description": pb.description,
|
||||
"hunt_id": pb.hunt_id, "is_template": pb.is_template,
|
||||
"steps": [
|
||||
{"id": s.id, "order_index": s.order_index, "title": s.title,
|
||||
"description": s.description, "step_type": s.step_type,
|
||||
"target_route": s.target_route, "is_completed": False}
|
||||
for s in created_steps
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{playbook_id}")
|
||||
async def get_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
|
||||
pb = result.scalar_one_or_none()
|
||||
if not pb:
|
||||
raise HTTPException(status_code=404, detail="Playbook not found")
|
||||
|
||||
return {
|
||||
"id": pb.id, "name": pb.name, "description": pb.description,
|
||||
"is_template": pb.is_template, "hunt_id": pb.hunt_id,
|
||||
"status": pb.status,
|
||||
"created_at": pb.created_at.isoformat() if pb.created_at else None,
|
||||
"steps": [
|
||||
{
|
||||
"id": s.id, "order_index": s.order_index, "title": s.title,
|
||||
"description": s.description, "step_type": s.step_type,
|
||||
"target_route": s.target_route,
|
||||
"is_completed": s.is_completed,
|
||||
"completed_at": s.completed_at.isoformat() if s.completed_at else None,
|
||||
"notes": s.notes,
|
||||
}
|
||||
for s in sorted(pb.steps, key=lambda x: x.order_index)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{playbook_id}")
|
||||
async def update_playbook(playbook_id: str, body: PlaybookUpdate, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
|
||||
pb = result.scalar_one_or_none()
|
||||
if not pb:
|
||||
raise HTTPException(status_code=404, detail="Playbook not found")
|
||||
|
||||
if body.name is not None:
|
||||
pb.name = body.name
|
||||
if body.description is not None:
|
||||
pb.description = body.description
|
||||
if body.status is not None:
|
||||
pb.status = body.status
|
||||
return {"status": "updated"}
|
||||
|
||||
|
||||
@router.delete("/{playbook_id}")
|
||||
async def delete_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
|
||||
pb = result.scalar_one_or_none()
|
||||
if not pb:
|
||||
raise HTTPException(status_code=404, detail="Playbook not found")
|
||||
await db.delete(pb)
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.put("/steps/{step_id}")
|
||||
async def update_step(step_id: int, body: StepUpdate, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(PlaybookStep).where(PlaybookStep.id == step_id))
|
||||
step = result.scalar_one_or_none()
|
||||
if not step:
|
||||
raise HTTPException(status_code=404, detail="Step not found")
|
||||
|
||||
if body.is_completed is not None:
|
||||
step.is_completed = body.is_completed
|
||||
step.completed_at = datetime.now(timezone.utc) if body.is_completed else None
|
||||
if body.notes is not None:
|
||||
step.notes = body.notes
|
||||
return {"status": "updated", "is_completed": step.is_completed}
|
||||
|
||||
164
backend/app/api/routes/saved_searches.py
Normal file
164
backend/app/api/routes/saved_searches.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""API routes for saved searches and bookmarked queries."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import SavedSearch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/searches", tags=["saved-searches"])
|
||||
|
||||
|
||||
class SearchCreate(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
search_type: str # "nlp_query", "ioc_search", "keyword_scan", "correlation"
|
||||
query_params: dict
|
||||
threshold: float | None = None
|
||||
|
||||
|
||||
class SearchUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
query_params: dict | None = None
|
||||
threshold: float | None = None
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_searches(
|
||||
search_type: str | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
q = select(SavedSearch).order_by(SavedSearch.created_at.desc())
|
||||
if search_type:
|
||||
q = q.where(SavedSearch.search_type == search_type)
|
||||
result = await db.execute(q.limit(100))
|
||||
searches = result.scalars().all()
|
||||
return {"searches": [
|
||||
{
|
||||
"id": s.id, "name": s.name, "description": s.description,
|
||||
"search_type": s.search_type, "query_params": s.query_params,
|
||||
"threshold": s.threshold,
|
||||
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
|
||||
"last_result_count": s.last_result_count,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
for s in searches
|
||||
]}
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
async def create_search(body: SearchCreate, db: AsyncSession = Depends(get_db)):
|
||||
s = SavedSearch(
|
||||
name=body.name,
|
||||
description=body.description,
|
||||
search_type=body.search_type,
|
||||
query_params=body.query_params,
|
||||
threshold=body.threshold,
|
||||
)
|
||||
db.add(s)
|
||||
await db.flush()
|
||||
return {
|
||||
"id": s.id, "name": s.name, "search_type": s.search_type,
|
||||
"query_params": s.query_params, "threshold": s.threshold,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{search_id}")
|
||||
async def get_search(search_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
|
||||
s = result.scalar_one_or_none()
|
||||
if not s:
|
||||
raise HTTPException(status_code=404, detail="Saved search not found")
|
||||
return {
|
||||
"id": s.id, "name": s.name, "description": s.description,
|
||||
"search_type": s.search_type, "query_params": s.query_params,
|
||||
"threshold": s.threshold,
|
||||
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
|
||||
"last_result_count": s.last_result_count,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{search_id}")
|
||||
async def update_search(search_id: str, body: SearchUpdate, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
|
||||
s = result.scalar_one_or_none()
|
||||
if not s:
|
||||
raise HTTPException(status_code=404, detail="Saved search not found")
|
||||
if body.name is not None:
|
||||
s.name = body.name
|
||||
if body.description is not None:
|
||||
s.description = body.description
|
||||
if body.query_params is not None:
|
||||
s.query_params = body.query_params
|
||||
if body.threshold is not None:
|
||||
s.threshold = body.threshold
|
||||
return {"status": "updated"}
|
||||
|
||||
|
||||
@router.delete("/{search_id}")
|
||||
async def delete_search(search_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
|
||||
s = result.scalar_one_or_none()
|
||||
if not s:
|
||||
raise HTTPException(status_code=404, detail="Saved search not found")
|
||||
await db.delete(s)
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.post("/{search_id}/run")
|
||||
async def run_saved_search(search_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Execute a saved search and return results with delta from last run."""
|
||||
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
|
||||
s = result.scalar_one_or_none()
|
||||
if not s:
|
||||
raise HTTPException(status_code=404, detail="Saved search not found")
|
||||
|
||||
previous_count = s.last_result_count or 0
|
||||
results = []
|
||||
count = 0
|
||||
|
||||
if s.search_type == "ioc_search":
|
||||
from app.db.models import EnrichmentResult
|
||||
ioc_value = s.query_params.get("ioc_value", "")
|
||||
if ioc_value:
|
||||
q = select(EnrichmentResult).where(
|
||||
EnrichmentResult.ioc_value.contains(ioc_value)
|
||||
)
|
||||
res = await db.execute(q.limit(100))
|
||||
for er in res.scalars().all():
|
||||
results.append({
|
||||
"ioc_value": er.ioc_value, "ioc_type": er.ioc_type,
|
||||
"source": er.source, "verdict": er.verdict,
|
||||
})
|
||||
count = len(results)
|
||||
|
||||
elif s.search_type == "keyword_scan":
|
||||
from app.db.models import KeywordTheme
|
||||
res = await db.execute(select(KeywordTheme).where(KeywordTheme.enabled == True))
|
||||
themes = res.scalars().all()
|
||||
count = sum(len(t.keywords) for t in themes)
|
||||
results = [{"theme": t.name, "keyword_count": len(t.keywords)} for t in themes]
|
||||
|
||||
# Update last run metadata
|
||||
s.last_run_at = datetime.now(timezone.utc)
|
||||
s.last_result_count = count
|
||||
|
||||
delta = count - previous_count
|
||||
|
||||
return {
|
||||
"search_id": s.id, "search_name": s.name,
|
||||
"search_type": s.search_type,
|
||||
"result_count": count,
|
||||
"previous_count": previous_count,
|
||||
"delta": delta,
|
||||
"results": results[:50],
|
||||
}
|
||||
184
backend/app/api/routes/stix_export.py
Normal file
184
backend/app/api/routes/stix_export.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""STIX 2.1 export endpoint.
|
||||
|
||||
Aggregates hunt data (IOCs, techniques, host profiles, hypotheses) into a
|
||||
STIX 2.1 Bundle JSON download. No external dependencies required we
|
||||
build the JSON directly following the OASIS STIX 2.1 spec.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import Response
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import (
|
||||
Hunt, Dataset, Hypothesis, TriageResult, HostProfile,
|
||||
EnrichmentResult, HuntReport,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/export", tags=["export"])
|
||||
|
||||
STIX_SPEC_VERSION = "2.1"
|
||||
|
||||
|
||||
def _stix_id(stype: str) -> str:
|
||||
return f"{stype}--{uuid.uuid4()}"
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
|
||||
|
||||
def _build_identity(hunt_name: str) -> dict:
|
||||
return {
|
||||
"type": "identity",
|
||||
"spec_version": STIX_SPEC_VERSION,
|
||||
"id": _stix_id("identity"),
|
||||
"created": _now_iso(),
|
||||
"modified": _now_iso(),
|
||||
"name": f"ThreatHunt - {hunt_name}",
|
||||
"identity_class": "system",
|
||||
}
|
||||
|
||||
|
||||
def _ioc_to_indicator(ioc_value: str, ioc_type: str, identity_id: str, verdict: str = None) -> dict:
|
||||
pattern_map = {
|
||||
"ipv4": f"[ipv4-addr:value = '{ioc_value}']",
|
||||
"ipv6": f"[ipv6-addr:value = '{ioc_value}']",
|
||||
"domain": f"[domain-name:value = '{ioc_value}']",
|
||||
"url": f"[url:value = '{ioc_value}']",
|
||||
"hash_md5": f"[file:hashes.'MD5' = '{ioc_value}']",
|
||||
"hash_sha1": f"[file:hashes.'SHA-1' = '{ioc_value}']",
|
||||
"hash_sha256": f"[file:hashes.'SHA-256' = '{ioc_value}']",
|
||||
"email": f"[email-addr:value = '{ioc_value}']",
|
||||
}
|
||||
pattern = pattern_map.get(ioc_type, f"[artifact:payload_bin = '{ioc_value}']")
|
||||
now = _now_iso()
|
||||
return {
|
||||
"type": "indicator",
|
||||
"spec_version": STIX_SPEC_VERSION,
|
||||
"id": _stix_id("indicator"),
|
||||
"created": now,
|
||||
"modified": now,
|
||||
"name": f"{ioc_type}: {ioc_value}",
|
||||
"pattern": pattern,
|
||||
"pattern_type": "stix",
|
||||
"valid_from": now,
|
||||
"created_by_ref": identity_id,
|
||||
"labels": [verdict or "suspicious"],
|
||||
}
|
||||
|
||||
|
||||
def _technique_to_attack_pattern(technique_id: str, identity_id: str) -> dict:
|
||||
now = _now_iso()
|
||||
return {
|
||||
"type": "attack-pattern",
|
||||
"spec_version": STIX_SPEC_VERSION,
|
||||
"id": _stix_id("attack-pattern"),
|
||||
"created": now,
|
||||
"modified": now,
|
||||
"name": technique_id,
|
||||
"created_by_ref": identity_id,
|
||||
"external_references": [
|
||||
{
|
||||
"source_name": "mitre-attack",
|
||||
"external_id": technique_id,
|
||||
"url": f"https://attack.mitre.org/techniques/{technique_id.replace('.', '/')}/",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _hypothesis_to_report(hyp, identity_id: str) -> dict:
|
||||
now = _now_iso()
|
||||
return {
|
||||
"type": "report",
|
||||
"spec_version": STIX_SPEC_VERSION,
|
||||
"id": _stix_id("report"),
|
||||
"created": now,
|
||||
"modified": now,
|
||||
"name": hyp.title,
|
||||
"description": hyp.description or "",
|
||||
"published": now,
|
||||
"created_by_ref": identity_id,
|
||||
"labels": ["threat-hunt-hypothesis"],
|
||||
"object_refs": [],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/stix/{hunt_id}")
|
||||
async def export_stix(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Export hunt data as a STIX 2.1 Bundle JSON file."""
|
||||
# Fetch hunt
|
||||
hunt = (await db.execute(select(Hunt).where(Hunt.id == hunt_id))).scalar_one_or_none()
|
||||
if not hunt:
|
||||
raise HTTPException(404, "Hunt not found")
|
||||
|
||||
identity = _build_identity(hunt.name)
|
||||
objects: list[dict] = [identity]
|
||||
seen_techniques: set[str] = set()
|
||||
seen_iocs: set[str] = set()
|
||||
|
||||
# Gather IOCs from enrichment results for hunt's datasets
|
||||
datasets_q = await db.execute(select(Dataset.id).where(Dataset.hunt_id == hunt_id))
|
||||
ds_ids = [r[0] for r in datasets_q.all()]
|
||||
|
||||
if ds_ids:
|
||||
enrichments = (await db.execute(
|
||||
select(EnrichmentResult).where(EnrichmentResult.dataset_id.in_(ds_ids))
|
||||
)).scalars().all()
|
||||
for e in enrichments:
|
||||
key = f"{e.ioc_type}:{e.ioc_value}"
|
||||
if key not in seen_iocs:
|
||||
seen_iocs.add(key)
|
||||
objects.append(_ioc_to_indicator(e.ioc_value, e.ioc_type, identity["id"], e.verdict))
|
||||
|
||||
# Gather techniques from triage results
|
||||
triages = (await db.execute(
|
||||
select(TriageResult).where(TriageResult.dataset_id.in_(ds_ids))
|
||||
)).scalars().all()
|
||||
for t in triages:
|
||||
for tech in (t.mitre_techniques or []):
|
||||
tid = tech if isinstance(tech, str) else tech.get("technique_id", str(tech))
|
||||
if tid not in seen_techniques:
|
||||
seen_techniques.add(tid)
|
||||
objects.append(_technique_to_attack_pattern(tid, identity["id"]))
|
||||
|
||||
# Gather techniques from host profiles
|
||||
profiles = (await db.execute(
|
||||
select(HostProfile).where(HostProfile.hunt_id == hunt_id)
|
||||
)).scalars().all()
|
||||
for p in profiles:
|
||||
for tech in (p.mitre_techniques or []):
|
||||
tid = tech if isinstance(tech, str) else tech.get("technique_id", str(tech))
|
||||
if tid not in seen_techniques:
|
||||
seen_techniques.add(tid)
|
||||
objects.append(_technique_to_attack_pattern(tid, identity["id"]))
|
||||
|
||||
# Gather hypotheses
|
||||
hypos = (await db.execute(
|
||||
select(Hypothesis).where(Hypothesis.hunt_id == hunt_id)
|
||||
)).scalars().all()
|
||||
for h in hypos:
|
||||
objects.append(_hypothesis_to_report(h, identity["id"]))
|
||||
if h.mitre_technique and h.mitre_technique not in seen_techniques:
|
||||
seen_techniques.add(h.mitre_technique)
|
||||
objects.append(_technique_to_attack_pattern(h.mitre_technique, identity["id"]))
|
||||
|
||||
bundle = {
|
||||
"type": "bundle",
|
||||
"id": _stix_id("bundle"),
|
||||
"objects": objects,
|
||||
}
|
||||
|
||||
filename = f"threathunt-{hunt.name.replace(' ', '_')}-stix.json"
|
||||
return Response(
|
||||
content=json.dumps(bundle, indent=2),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
128
backend/app/api/routes/timeline.py
Normal file
128
backend/app/api/routes/timeline.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""API routes for forensic timeline visualization."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Dataset, DatasetRow, Hunt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/timeline", tags=["timeline"])
|
||||
|
||||
|
||||
def _parse_timestamp(val: str | None) -> str | None:
|
||||
"""Try to parse a timestamp string, return ISO format or None."""
|
||||
if not val:
|
||||
return None
|
||||
val = str(val).strip()
|
||||
if not val:
|
||||
return None
|
||||
# Try common formats
|
||||
for fmt in [
|
||||
"%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ",
|
||||
"%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M:%S", "%m/%d/%Y %H:%M:%S",
|
||||
]:
|
||||
try:
|
||||
return datetime.strptime(val, fmt).isoformat() + "Z"
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
# Columns likely to contain timestamps
|
||||
TIME_COLUMNS = {
|
||||
"timestamp", "time", "datetime", "date", "created", "modified",
|
||||
"eventtime", "event_time", "start_time", "end_time",
|
||||
"lastmodified", "last_modified", "created_at", "updated_at",
|
||||
"mtime", "atime", "ctime", "btime",
|
||||
"timecreated", "timegenerated", "sourcetime",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/hunt/{hunt_id}")
|
||||
async def get_hunt_timeline(
|
||||
hunt_id: str,
|
||||
limit: int = 2000,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Build a timeline of events across all datasets in a hunt."""
|
||||
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
|
||||
hunt = result.scalar_one_or_none()
|
||||
if not hunt:
|
||||
raise HTTPException(status_code=404, detail="Hunt not found")
|
||||
|
||||
result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
|
||||
datasets = result.scalars().all()
|
||||
if not datasets:
|
||||
return {"hunt_id": hunt_id, "events": [], "datasets": []}
|
||||
|
||||
events = []
|
||||
dataset_info = []
|
||||
|
||||
for ds in datasets:
|
||||
artifact_type = getattr(ds, "artifact_type", None) or "Unknown"
|
||||
dataset_info.append({
|
||||
"id": ds.id, "name": ds.name, "artifact_type": artifact_type,
|
||||
"row_count": ds.row_count,
|
||||
})
|
||||
|
||||
# Find time columns for this dataset
|
||||
schema = ds.column_schema or {}
|
||||
time_cols = []
|
||||
for col in (ds.normalized_columns or {}).values():
|
||||
if col.lower() in TIME_COLUMNS:
|
||||
time_cols.append(col)
|
||||
if not time_cols:
|
||||
for col in schema:
|
||||
if col.lower() in TIME_COLUMNS or "time" in col.lower() or "date" in col.lower():
|
||||
time_cols.append(col)
|
||||
if not time_cols:
|
||||
continue
|
||||
|
||||
# Fetch rows
|
||||
rows_result = await db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == ds.id)
|
||||
.order_by(DatasetRow.row_index)
|
||||
.limit(limit // max(len(datasets), 1))
|
||||
)
|
||||
for r in rows_result.scalars().all():
|
||||
data = r.normalized_data or r.data
|
||||
ts = None
|
||||
for tc in time_cols:
|
||||
ts = _parse_timestamp(data.get(tc))
|
||||
if ts:
|
||||
break
|
||||
if ts:
|
||||
hostname = data.get("hostname") or data.get("Hostname") or data.get("Fqdn") or ""
|
||||
process = data.get("process_name") or data.get("Name") or data.get("ProcessName") or ""
|
||||
summary = data.get("command_line") or data.get("CommandLine") or data.get("Details") or ""
|
||||
events.append({
|
||||
"timestamp": ts,
|
||||
"dataset_id": ds.id,
|
||||
"dataset_name": ds.name,
|
||||
"artifact_type": artifact_type,
|
||||
"row_index": r.row_index,
|
||||
"hostname": str(hostname)[:128],
|
||||
"process": str(process)[:128],
|
||||
"summary": str(summary)[:256],
|
||||
"data": {k: str(v)[:100] for k, v in list(data.items())[:8]},
|
||||
})
|
||||
|
||||
# Sort by timestamp
|
||||
events.sort(key=lambda e: e["timestamp"])
|
||||
|
||||
return {
|
||||
"hunt_id": hunt_id,
|
||||
"hunt_name": hunt.name,
|
||||
"event_count": len(events),
|
||||
"datasets": dataset_info,
|
||||
"events": events[:limit],
|
||||
}
|
||||
Reference in New Issue
Block a user