mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 14:00:20 -05:00
chore: checkpoint all local changes
This commit is contained in:
@@ -1,371 +1,296 @@
|
||||
"""Analysis API routes - triage, host profiles, reports, IOC extraction,
|
||||
host grouping, anomaly detection, data query (SSE), and job management."""
|
||||
|
||||
from __future__ import annotations
|
||||
"""API routes for process trees, storyline graphs, risk scoring, LLM analysis, timeline, and field stats."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import HostProfile, HuntReport, TriageResult
|
||||
from app.db.repositories.datasets import DatasetRepository
|
||||
from app.services.process_tree import (
|
||||
build_process_tree,
|
||||
build_storyline,
|
||||
compute_risk_scores,
|
||||
_fetch_rows,
|
||||
)
|
||||
from app.services.llm_analysis import (
|
||||
AnalysisRequest,
|
||||
AnalysisResult,
|
||||
run_llm_analysis,
|
||||
)
|
||||
from app.services.timeline import (
|
||||
build_timeline_bins,
|
||||
compute_field_stats,
|
||||
search_rows,
|
||||
)
|
||||
from app.services.mitre import (
|
||||
map_to_attack,
|
||||
build_knowledge_graph,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/analysis", tags=["analysis"])
|
||||
|
||||
|
||||
# --- Response models ---
|
||||
|
||||
class TriageResultResponse(BaseModel):
|
||||
id: str
|
||||
dataset_id: str
|
||||
row_start: int
|
||||
row_end: int
|
||||
risk_score: float
|
||||
verdict: str
|
||||
findings: list | None = None
|
||||
suspicious_indicators: list | None = None
|
||||
mitre_techniques: list | None = None
|
||||
model_used: str | None = None
|
||||
node_used: str | None = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
# ── Response models ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class HostProfileResponse(BaseModel):
|
||||
id: str
|
||||
hunt_id: str
|
||||
class ProcessTreeResponse(BaseModel):
|
||||
trees: list[dict] = Field(default_factory=list)
|
||||
total_processes: int = 0
|
||||
|
||||
|
||||
class StorylineResponse(BaseModel):
|
||||
nodes: list[dict] = Field(default_factory=list)
|
||||
edges: list[dict] = Field(default_factory=list)
|
||||
summary: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class RiskHostEntry(BaseModel):
|
||||
hostname: str
|
||||
fqdn: str | None = None
|
||||
risk_score: float
|
||||
risk_level: str
|
||||
artifact_summary: dict | None = None
|
||||
timeline_summary: str | None = None
|
||||
suspicious_findings: list | None = None
|
||||
mitre_techniques: list | None = None
|
||||
llm_analysis: str | None = None
|
||||
model_used: str | None = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
score: int = 0
|
||||
signals: list[str] = Field(default_factory=list)
|
||||
event_count: int = 0
|
||||
process_count: int = 0
|
||||
network_count: int = 0
|
||||
file_count: int = 0
|
||||
|
||||
|
||||
class HuntReportResponse(BaseModel):
|
||||
id: str
|
||||
hunt_id: str
|
||||
status: str
|
||||
exec_summary: str | None = None
|
||||
full_report: str | None = None
|
||||
findings: list | None = None
|
||||
recommendations: list | None = None
|
||||
mitre_mapping: dict | None = None
|
||||
ioc_table: list | None = None
|
||||
host_risk_summary: list | None = None
|
||||
models_used: list | None = None
|
||||
generation_time_ms: int | None = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
class RiskSummaryResponse(BaseModel):
|
||||
hosts: list[RiskHostEntry] = Field(default_factory=list)
|
||||
overall_score: int = 0
|
||||
total_events: int = 0
|
||||
severity_breakdown: dict[str, int] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
question: str
|
||||
mode: str = "quick" # quick or deep
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# --- Triage endpoints ---
|
||||
|
||||
@router.get("/triage/{dataset_id}", response_model=list[TriageResultResponse])
|
||||
async def get_triage_results(
|
||||
dataset_id: str,
|
||||
min_risk: float = Query(0.0, ge=0.0, le=10.0),
|
||||
@router.get(
|
||||
"/process-tree",
|
||||
response_model=ProcessTreeResponse,
|
||||
summary="Build process tree from dataset rows",
|
||||
description=(
|
||||
"Extracts parent→child process relationships from dataset rows "
|
||||
"and returns a hierarchical forest of process nodes."
|
||||
),
|
||||
)
|
||||
async def get_process_tree(
|
||||
dataset_id: str | None = Query(None, description="Dataset ID"),
|
||||
hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
|
||||
hostname: str | None = Query(None, description="Filter by hostname"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(TriageResult)
|
||||
.where(TriageResult.dataset_id == dataset_id)
|
||||
.where(TriageResult.risk_score >= min_risk)
|
||||
.order_by(TriageResult.risk_score.desc())
|
||||
"""Return process tree(s) for a dataset or hunt."""
|
||||
if not dataset_id and not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
|
||||
trees = await build_process_tree(
|
||||
db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
# Count total processes recursively
|
||||
def _count(node: dict) -> int:
|
||||
return 1 + sum(_count(c) for c in node.get("children", []))
|
||||
|
||||
total = sum(_count(t) for t in trees)
|
||||
|
||||
return ProcessTreeResponse(trees=trees, total_processes=total)
|
||||
|
||||
|
||||
@router.post("/triage/{dataset_id}")
|
||||
async def trigger_triage(
|
||||
dataset_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
async def _run():
|
||||
from app.services.triage import triage_dataset
|
||||
await triage_dataset(dataset_id)
|
||||
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "triage_started", "dataset_id": dataset_id}
|
||||
|
||||
|
||||
# --- Host profile endpoints ---
|
||||
|
||||
@router.get("/profiles/{hunt_id}", response_model=list[HostProfileResponse])
|
||||
async def get_host_profiles(
|
||||
hunt_id: str,
|
||||
min_risk: float = Query(0.0, ge=0.0, le=10.0),
|
||||
@router.get(
|
||||
"/storyline",
|
||||
response_model=StorylineResponse,
|
||||
summary="Build CrowdStrike-style storyline attack graph",
|
||||
description=(
|
||||
"Creates a Cytoscape-compatible graph of events connected by "
|
||||
"process lineage (spawned) and temporal sequence within each host."
|
||||
),
|
||||
)
|
||||
async def get_storyline(
|
||||
dataset_id: str | None = Query(None, description="Dataset ID"),
|
||||
hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
|
||||
hostname: str | None = Query(None, description="Filter by hostname"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(HostProfile)
|
||||
.where(HostProfile.hunt_id == hunt_id)
|
||||
.where(HostProfile.risk_score >= min_risk)
|
||||
.order_by(HostProfile.risk_score.desc())
|
||||
"""Return a storyline graph for a dataset or hunt."""
|
||||
if not dataset_id and not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
|
||||
result = await build_storyline(
|
||||
db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
return StorylineResponse(**result)
|
||||
|
||||
|
||||
@router.post("/profiles/{hunt_id}")
|
||||
async def trigger_all_profiles(
|
||||
hunt_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
async def _run():
|
||||
from app.services.host_profiler import profile_all_hosts
|
||||
await profile_all_hosts(hunt_id)
|
||||
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "profiling_started", "hunt_id": hunt_id}
|
||||
|
||||
|
||||
@router.post("/profiles/{hunt_id}/{hostname}")
|
||||
async def trigger_single_profile(
|
||||
hunt_id: str,
|
||||
hostname: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
async def _run():
|
||||
from app.services.host_profiler import profile_host
|
||||
await profile_host(hunt_id, hostname)
|
||||
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "profiling_started", "hunt_id": hunt_id, "hostname": hostname}
|
||||
|
||||
|
||||
# --- Report endpoints ---
|
||||
|
||||
@router.get("/reports/{hunt_id}", response_model=list[HuntReportResponse])
|
||||
async def list_reports(
|
||||
hunt_id: str,
|
||||
@router.get(
|
||||
"/risk-summary",
|
||||
response_model=RiskSummaryResponse,
|
||||
summary="Compute risk scores per host",
|
||||
description=(
|
||||
"Analyzes dataset rows for suspicious patterns (encoded PowerShell, "
|
||||
"credential dumping, lateral movement) and produces per-host risk scores."
|
||||
),
|
||||
)
|
||||
async def get_risk_summary(
|
||||
hunt_id: str | None = Query(None, description="Hunt ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(HuntReport)
|
||||
.where(HuntReport.hunt_id == hunt_id)
|
||||
.order_by(HuntReport.created_at.desc())
|
||||
"""Return risk scores for all hosts in a hunt."""
|
||||
result = await compute_risk_scores(db, hunt_id=hunt_id)
|
||||
return RiskSummaryResponse(**result)
|
||||
|
||||
|
||||
# ── LLM Analysis ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
|
||||
"/llm-analyze",
|
||||
response_model=AnalysisResult,
|
||||
summary="Run LLM-powered threat analysis on dataset",
|
||||
description=(
|
||||
"Loads dataset rows server-side, builds a summary, and sends to "
|
||||
"Wile (deep analysis) or Roadrunner (quick) for comprehensive "
|
||||
"threat analysis. Returns structured findings, IOCs, MITRE techniques."
|
||||
),
|
||||
)
|
||||
async def llm_analyze(
|
||||
request: AnalysisRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Run LLM analysis on a dataset or hunt."""
|
||||
if not request.dataset_id and not request.hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
|
||||
# Load rows
|
||||
rows_objs = await _fetch_rows(
|
||||
db,
|
||||
dataset_id=request.dataset_id,
|
||||
hunt_id=request.hunt_id,
|
||||
limit=2000,
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
if not rows_objs:
|
||||
raise HTTPException(status_code=404, detail="No rows found for analysis")
|
||||
|
||||
# Extract data dicts
|
||||
rows = [r.normalized_data or r.data for r in rows_objs]
|
||||
|
||||
# Get dataset name
|
||||
ds_name = "hunt datasets"
|
||||
if request.dataset_id:
|
||||
repo = DatasetRepository(db)
|
||||
ds = await repo.get_dataset(request.dataset_id)
|
||||
if ds:
|
||||
ds_name = ds.name
|
||||
|
||||
result = await run_llm_analysis(rows, request, dataset_name=ds_name)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/reports/{hunt_id}/{report_id}", response_model=HuntReportResponse)
|
||||
async def get_report(
|
||||
hunt_id: str,
|
||||
report_id: str,
|
||||
# ── Timeline ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
|
||||
"/timeline",
|
||||
summary="Get event timeline histogram bins",
|
||||
)
|
||||
async def get_timeline(
|
||||
dataset_id: str | None = Query(None),
|
||||
hunt_id: str | None = Query(None),
|
||||
bins: int = Query(60, ge=10, le=200),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(HuntReport)
|
||||
.where(HuntReport.id == report_id)
|
||||
.where(HuntReport.hunt_id == hunt_id)
|
||||
)
|
||||
report = result.scalar_one_or_none()
|
||||
if not report:
|
||||
raise HTTPException(status_code=404, detail="Report not found")
|
||||
return report
|
||||
if not dataset_id and not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
return await build_timeline_bins(db, dataset_id=dataset_id, hunt_id=hunt_id, bins=bins)
|
||||
|
||||
|
||||
@router.post("/reports/{hunt_id}/generate")
|
||||
async def trigger_report(
|
||||
hunt_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
async def _run():
|
||||
from app.services.report_generator import generate_report
|
||||
await generate_report(hunt_id)
|
||||
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "report_generation_started", "hunt_id": hunt_id}
|
||||
|
||||
|
||||
# --- IOC extraction endpoints ---
|
||||
|
||||
@router.get("/iocs/{dataset_id}")
|
||||
async def extract_iocs(
|
||||
dataset_id: str,
|
||||
max_rows: int = Query(5000, ge=1, le=50000),
|
||||
@router.get(
|
||||
"/field-stats",
|
||||
summary="Get per-field value distributions",
|
||||
)
|
||||
async def get_field_stats(
|
||||
dataset_id: str | None = Query(None),
|
||||
hunt_id: str | None = Query(None),
|
||||
fields: str | None = Query(None, description="Comma-separated field names"),
|
||||
top_n: int = Query(20, ge=5, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Extract IOCs (IPs, domains, hashes, etc.) from dataset rows."""
|
||||
from app.services.ioc_extractor import extract_iocs_from_dataset
|
||||
iocs = await extract_iocs_from_dataset(dataset_id, db, max_rows=max_rows)
|
||||
total = sum(len(v) for v in iocs.values())
|
||||
return {"dataset_id": dataset_id, "iocs": iocs, "total": total}
|
||||
|
||||
|
||||
# --- Host grouping endpoints ---
|
||||
|
||||
@router.get("/hosts/{hunt_id}")
|
||||
async def get_host_groups(
|
||||
hunt_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Group data by hostname across all datasets in a hunt."""
|
||||
from app.services.ioc_extractor import extract_host_groups
|
||||
groups = await extract_host_groups(hunt_id, db)
|
||||
return {"hunt_id": hunt_id, "hosts": groups}
|
||||
|
||||
|
||||
# --- Anomaly detection endpoints ---
|
||||
|
||||
@router.get("/anomalies/{dataset_id}")
|
||||
async def get_anomalies(
|
||||
dataset_id: str,
|
||||
outliers_only: bool = Query(False),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get anomaly detection results for a dataset."""
|
||||
from app.db.models import AnomalyResult
|
||||
stmt = select(AnomalyResult).where(AnomalyResult.dataset_id == dataset_id)
|
||||
if outliers_only:
|
||||
stmt = stmt.where(AnomalyResult.is_outlier == True)
|
||||
stmt = stmt.order_by(AnomalyResult.anomaly_score.desc())
|
||||
result = await db.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": r.id,
|
||||
"dataset_id": r.dataset_id,
|
||||
"row_id": r.row_id,
|
||||
"anomaly_score": r.anomaly_score,
|
||||
"distance_from_centroid": r.distance_from_centroid,
|
||||
"cluster_id": r.cluster_id,
|
||||
"is_outlier": r.is_outlier,
|
||||
"explanation": r.explanation,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
@router.post("/anomalies/{dataset_id}")
|
||||
async def trigger_anomaly_detection(
|
||||
dataset_id: str,
|
||||
k: int = Query(3, ge=2, le=20),
|
||||
threshold: float = Query(0.35, ge=0.1, le=0.9),
|
||||
background_tasks: BackgroundTasks = None,
|
||||
):
|
||||
"""Trigger embedding-based anomaly detection on a dataset."""
|
||||
async def _run():
|
||||
from app.services.anomaly_detector import detect_anomalies
|
||||
await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
|
||||
|
||||
if background_tasks:
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "anomaly_detection_started", "dataset_id": dataset_id}
|
||||
else:
|
||||
from app.services.anomaly_detector import detect_anomalies
|
||||
results = await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
|
||||
return {"status": "complete", "dataset_id": dataset_id, "count": len(results)}
|
||||
|
||||
|
||||
# --- Natural language data query (SSE streaming) ---
|
||||
|
||||
@router.post("/query/{dataset_id}")
|
||||
async def query_dataset_endpoint(
|
||||
dataset_id: str,
|
||||
body: QueryRequest,
|
||||
):
|
||||
"""Ask a natural language question about a dataset.
|
||||
|
||||
Returns an SSE stream with token-by-token LLM response.
|
||||
Event types: status, metadata, token, error, done
|
||||
"""
|
||||
from app.services.data_query import query_dataset_stream
|
||||
|
||||
return StreamingResponse(
|
||||
query_dataset_stream(dataset_id, body.question, body.mode),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
if not dataset_id and not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
field_list = [f.strip() for f in fields.split(",")] if fields else None
|
||||
return await compute_field_stats(
|
||||
db, dataset_id=dataset_id, hunt_id=hunt_id,
|
||||
fields=field_list, top_n=top_n,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/query/{dataset_id}/sync")
|
||||
async def query_dataset_sync(
|
||||
dataset_id: str,
|
||||
body: QueryRequest,
|
||||
class SearchRequest(BaseModel):
|
||||
dataset_id: Optional[str] = None
|
||||
hunt_id: Optional[str] = None
|
||||
query: str = ""
|
||||
filters: dict[str, str] = Field(default_factory=dict)
|
||||
time_start: Optional[str] = None
|
||||
time_end: Optional[str] = None
|
||||
limit: int = 500
|
||||
offset: int = 0
|
||||
|
||||
|
||||
@router.post(
|
||||
"/search",
|
||||
summary="Search and filter dataset rows",
|
||||
)
|
||||
async def search_dataset_rows(
|
||||
request: SearchRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Non-streaming version of data query."""
|
||||
from app.services.data_query import query_dataset
|
||||
|
||||
try:
|
||||
answer = await query_dataset(dataset_id, body.question, body.mode)
|
||||
return {"dataset_id": dataset_id, "question": body.question, "answer": answer, "mode": body.mode}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Query failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
if not request.dataset_id and not request.hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
return await search_rows(
|
||||
db,
|
||||
dataset_id=request.dataset_id,
|
||||
hunt_id=request.hunt_id,
|
||||
query=request.query,
|
||||
filters=request.filters,
|
||||
time_start=request.time_start,
|
||||
time_end=request.time_end,
|
||||
limit=request.limit,
|
||||
offset=request.offset,
|
||||
)
|
||||
|
||||
|
||||
# --- Job queue endpoints ---
|
||||
# ── MITRE ATT&CK ─────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/jobs")
|
||||
async def list_jobs(
|
||||
status: str | None = Query(None),
|
||||
job_type: str | None = Query(None),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
|
||||
@router.get(
|
||||
"/mitre-map",
|
||||
summary="Map dataset events to MITRE ATT&CK techniques",
|
||||
)
|
||||
async def get_mitre_map(
|
||||
dataset_id: str | None = Query(None),
|
||||
hunt_id: str | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""List all tracked jobs."""
|
||||
from app.services.job_queue import job_queue, JobStatus, JobType
|
||||
|
||||
s = JobStatus(status) if status else None
|
||||
t = JobType(job_type) if job_type else None
|
||||
jobs = job_queue.list_jobs(status=s, job_type=t, limit=limit)
|
||||
stats = job_queue.get_stats()
|
||||
return {"jobs": jobs, "stats": stats}
|
||||
if not dataset_id and not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
return await map_to_attack(db, dataset_id=dataset_id, hunt_id=hunt_id)
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}")
|
||||
async def get_job(job_id: str):
|
||||
"""Get status of a specific job."""
|
||||
from app.services.job_queue import job_queue
|
||||
|
||||
job = job_queue.get_job(job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
return job.to_dict()
|
||||
|
||||
|
||||
@router.delete("/jobs/{job_id}")
|
||||
async def cancel_job(job_id: str):
|
||||
"""Cancel a running or queued job."""
|
||||
from app.services.job_queue import job_queue
|
||||
|
||||
if job_queue.cancel_job(job_id):
|
||||
return {"status": "cancelled", "job_id": job_id}
|
||||
raise HTTPException(status_code=400, detail="Job cannot be cancelled (already complete or not found)")
|
||||
|
||||
|
||||
@router.post("/jobs/submit/{job_type}")
|
||||
async def submit_job(
|
||||
job_type: str,
|
||||
params: dict = {},
|
||||
@router.get(
|
||||
"/knowledge-graph",
|
||||
summary="Build entity-technique knowledge graph",
|
||||
)
|
||||
async def get_knowledge_graph(
|
||||
dataset_id: str | None = Query(None),
|
||||
hunt_id: str | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
<<<<<<< HEAD
|
||||
"""Submit a new job to the queue.
|
||||
|
||||
Job types: triage, host_profile, report, anomaly, query
|
||||
@@ -403,4 +328,9 @@ async def lb_health_check():
|
||||
"""Force a health check of both nodes."""
|
||||
from app.services.load_balancer import lb
|
||||
await lb.check_health()
|
||||
return lb.get_status()
|
||||
return lb.get_status()
|
||||
=======
|
||||
if not dataset_id and not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
return await build_knowledge_graph(db, dataset_id=dataset_id, hunt_id=hunt_id)
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
Reference in New Issue
Block a user