Mirror of https://github.com/mblanke/ThreatHunt.git (synced 2026-03-01 05:50:21 -05:00)
feat: host-centric network map, analysis dashboard, deduped inventory
- Rewrote NetworkMap to use deduplicated host inventory (163 hosts from 394K rows)
- New host_inventory.py service: scans datasets, groups by FQDN/ClientId, extracts IPs/users/OS
- New /api/network/host-inventory endpoint
- Added AnalysisDashboard with 6 tabs (IOC, anomaly, host profile, query, triage, reports)
- Added 16 analysis API endpoints with job queue and load balancer
- Added 4 AI/analysis ORM models (TriageResult, HostProfile, HuntReport, AnomalyResult)
- Filters system accounts (DWM-*, UMFD-*, LOCAL/NETWORK SERVICE)
- Infers OS from hostname patterns (W10-* -> Windows 10)
- Canvas 2D force-directed graph with host/external-IP node types
- Click popover shows hostname, FQDN, IPs, OS, users, datasets, connections
@@ -0,0 +1,112 @@
"""add processing_status and AI analysis tables

Revision ID: a1b2c3d4e5f6
Revises: 98ab619418bc
Create Date: 2026-02-19 18:00:00.000000

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


revision: str = "a1b2c3d4e5f6"
down_revision: Union[str, Sequence[str], None] = "98ab619418bc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Add columns to datasets table
    with op.batch_alter_table("datasets") as batch_op:
        batch_op.add_column(sa.Column("processing_status", sa.String(20), server_default="ready"))
        batch_op.add_column(sa.Column("artifact_type", sa.String(128), nullable=True))
        batch_op.add_column(sa.Column("error_message", sa.Text(), nullable=True))
        batch_op.add_column(sa.Column("file_path", sa.String(512), nullable=True))
        batch_op.create_index("ix_datasets_status", ["processing_status"])

    # Create triage_results table
    op.create_table(
        "triage_results",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("row_start", sa.Integer(), nullable=False),
        sa.Column("row_end", sa.Integer(), nullable=False),
        sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column("verdict", sa.String(20), nullable=False, server_default="pending"),
        sa.Column("findings", sa.JSON(), nullable=True),
        sa.Column("suspicious_indicators", sa.JSON(), nullable=True),
        sa.Column("mitre_techniques", sa.JSON(), nullable=True),
        sa.Column("model_used", sa.String(128), nullable=True),
        sa.Column("node_used", sa.String(64), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # Create host_profiles table
    op.create_table(
        "host_profiles",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("hostname", sa.String(256), nullable=False),
        sa.Column("fqdn", sa.String(512), nullable=True),
        sa.Column("client_id", sa.String(64), nullable=True),
        sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column("risk_level", sa.String(20), nullable=False, server_default="unknown"),
        sa.Column("artifact_summary", sa.JSON(), nullable=True),
        sa.Column("timeline_summary", sa.Text(), nullable=True),
        sa.Column("suspicious_findings", sa.JSON(), nullable=True),
        sa.Column("mitre_techniques", sa.JSON(), nullable=True),
        sa.Column("llm_analysis", sa.Text(), nullable=True),
        sa.Column("model_used", sa.String(128), nullable=True),
        sa.Column("node_used", sa.String(64), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # Create hunt_reports table
    op.create_table(
        "hunt_reports",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("status", sa.String(20), nullable=False, server_default="pending"),
        sa.Column("exec_summary", sa.Text(), nullable=True),
        sa.Column("full_report", sa.Text(), nullable=True),
        sa.Column("findings", sa.JSON(), nullable=True),
        sa.Column("recommendations", sa.JSON(), nullable=True),
        sa.Column("mitre_mapping", sa.JSON(), nullable=True),
        sa.Column("ioc_table", sa.JSON(), nullable=True),
        sa.Column("host_risk_summary", sa.JSON(), nullable=True),
        sa.Column("models_used", sa.JSON(), nullable=True),
        sa.Column("generation_time_ms", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # Create anomaly_results table
    op.create_table(
        "anomaly_results",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("row_id", sa.Integer(), sa.ForeignKey("dataset_rows.id", ondelete="CASCADE"), nullable=True),
        sa.Column("anomaly_score", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column("distance_from_centroid", sa.Float(), nullable=True),
        sa.Column("cluster_id", sa.Integer(), nullable=True),
        sa.Column("is_outlier", sa.Boolean(), nullable=False, server_default="0"),
        sa.Column("explanation", sa.Text(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )


def downgrade() -> None:
    op.drop_table("anomaly_results")
    op.drop_table("hunt_reports")
    op.drop_table("host_profiles")
    op.drop_table("triage_results")

    with op.batch_alter_table("datasets") as batch_op:
        batch_op.drop_index("ix_datasets_status")
        batch_op.drop_column("file_path")
        batch_op.drop_column("error_message")
        batch_op.drop_column("artifact_type")
        batch_op.drop_column("processing_status")
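To apply and verify this revision during development, the Alembic command API can be driven from Python. This is a minimal sketch, assuming the project's standard alembic.ini sits in the backend root (path is an assumption, not confirmed by the diff):

# Apply the migration programmatically (equivalent to `alembic upgrade head`).
# Assumes alembic.ini lives in the backend root; adjust the path if it does not.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.upgrade(cfg, "head")  # runs upgrade() above, creating the four new tables
# command.downgrade(cfg, "98ab619418bc")  # reverts to the previous revision if needed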
402  backend/app/api/routes/analysis.py  Normal file
@@ -0,0 +1,402 @@
"""Analysis API routes - triage, host profiles, reports, IOC extraction,
host grouping, anomaly detection, data query (SSE), and job management."""

from __future__ import annotations

import logging
from typing import Optional

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import get_db
from app.db.models import HostProfile, HuntReport, TriageResult

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/analysis", tags=["analysis"])


# --- Response models ---

class TriageResultResponse(BaseModel):
    id: str
    dataset_id: str
    row_start: int
    row_end: int
    risk_score: float
    verdict: str
    findings: list | None = None
    suspicious_indicators: list | None = None
    mitre_techniques: list | None = None
    model_used: str | None = None
    node_used: str | None = None

    class Config:
        from_attributes = True


class HostProfileResponse(BaseModel):
    id: str
    hunt_id: str
    hostname: str
    fqdn: str | None = None
    risk_score: float
    risk_level: str
    artifact_summary: dict | None = None
    timeline_summary: str | None = None
    suspicious_findings: list | None = None
    mitre_techniques: list | None = None
    llm_analysis: str | None = None
    model_used: str | None = None

    class Config:
        from_attributes = True


class HuntReportResponse(BaseModel):
    id: str
    hunt_id: str
    status: str
    exec_summary: str | None = None
    full_report: str | None = None
    findings: list | None = None
    recommendations: list | None = None
    mitre_mapping: dict | None = None
    ioc_table: list | None = None
    host_risk_summary: list | None = None
    models_used: list | None = None
    generation_time_ms: int | None = None

    class Config:
        from_attributes = True


class QueryRequest(BaseModel):
    question: str
    mode: str = "quick"  # quick or deep


# --- Triage endpoints ---

@router.get("/triage/{dataset_id}", response_model=list[TriageResultResponse])
async def get_triage_results(
    dataset_id: str,
    min_risk: float = Query(0.0, ge=0.0, le=10.0),
    db: AsyncSession = Depends(get_db),
):
    result = await db.execute(
        select(TriageResult)
        .where(TriageResult.dataset_id == dataset_id)
        .where(TriageResult.risk_score >= min_risk)
        .order_by(TriageResult.risk_score.desc())
    )
    return result.scalars().all()


@router.post("/triage/{dataset_id}")
async def trigger_triage(
    dataset_id: str,
    background_tasks: BackgroundTasks,
):
    async def _run():
        from app.services.triage import triage_dataset
        await triage_dataset(dataset_id)

    background_tasks.add_task(_run)
    return {"status": "triage_started", "dataset_id": dataset_id}


# --- Host profile endpoints ---

@router.get("/profiles/{hunt_id}", response_model=list[HostProfileResponse])
async def get_host_profiles(
    hunt_id: str,
    min_risk: float = Query(0.0, ge=0.0, le=10.0),
    db: AsyncSession = Depends(get_db),
):
    result = await db.execute(
        select(HostProfile)
        .where(HostProfile.hunt_id == hunt_id)
        .where(HostProfile.risk_score >= min_risk)
        .order_by(HostProfile.risk_score.desc())
    )
    return result.scalars().all()


@router.post("/profiles/{hunt_id}")
async def trigger_all_profiles(
    hunt_id: str,
    background_tasks: BackgroundTasks,
):
    async def _run():
        from app.services.host_profiler import profile_all_hosts
        await profile_all_hosts(hunt_id)

    background_tasks.add_task(_run)
    return {"status": "profiling_started", "hunt_id": hunt_id}


@router.post("/profiles/{hunt_id}/{hostname}")
async def trigger_single_profile(
    hunt_id: str,
    hostname: str,
    background_tasks: BackgroundTasks,
):
    async def _run():
        from app.services.host_profiler import profile_host
        await profile_host(hunt_id, hostname)

    background_tasks.add_task(_run)
    return {"status": "profiling_started", "hunt_id": hunt_id, "hostname": hostname}


# --- Report endpoints ---

@router.get("/reports/{hunt_id}", response_model=list[HuntReportResponse])
async def list_reports(
    hunt_id: str,
    db: AsyncSession = Depends(get_db),
):
    result = await db.execute(
        select(HuntReport)
        .where(HuntReport.hunt_id == hunt_id)
        .order_by(HuntReport.created_at.desc())
    )
    return result.scalars().all()


@router.get("/reports/{hunt_id}/{report_id}", response_model=HuntReportResponse)
async def get_report(
    hunt_id: str,
    report_id: str,
    db: AsyncSession = Depends(get_db),
):
    result = await db.execute(
        select(HuntReport)
        .where(HuntReport.id == report_id)
        .where(HuntReport.hunt_id == hunt_id)
    )
    report = result.scalar_one_or_none()
    if not report:
        raise HTTPException(status_code=404, detail="Report not found")
    return report


@router.post("/reports/{hunt_id}/generate")
async def trigger_report(
    hunt_id: str,
    background_tasks: BackgroundTasks,
):
    async def _run():
        from app.services.report_generator import generate_report
        await generate_report(hunt_id)

    background_tasks.add_task(_run)
    return {"status": "report_generation_started", "hunt_id": hunt_id}


# --- IOC extraction endpoints ---

@router.get("/iocs/{dataset_id}")
async def extract_iocs(
    dataset_id: str,
    max_rows: int = Query(5000, ge=1, le=50000),
    db: AsyncSession = Depends(get_db),
):
    """Extract IOCs (IPs, domains, hashes, etc.) from dataset rows."""
    from app.services.ioc_extractor import extract_iocs_from_dataset
    iocs = await extract_iocs_from_dataset(dataset_id, db, max_rows=max_rows)
    total = sum(len(v) for v in iocs.values())
    return {"dataset_id": dataset_id, "iocs": iocs, "total": total}


# --- Host grouping endpoints ---

@router.get("/hosts/{hunt_id}")
async def get_host_groups(
    hunt_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Group data by hostname across all datasets in a hunt."""
    from app.services.ioc_extractor import extract_host_groups
    groups = await extract_host_groups(hunt_id, db)
    return {"hunt_id": hunt_id, "hosts": groups}


# --- Anomaly detection endpoints ---

@router.get("/anomalies/{dataset_id}")
async def get_anomalies(
    dataset_id: str,
    outliers_only: bool = Query(False),
    db: AsyncSession = Depends(get_db),
):
    """Get anomaly detection results for a dataset."""
    from app.db.models import AnomalyResult
    stmt = select(AnomalyResult).where(AnomalyResult.dataset_id == dataset_id)
    if outliers_only:
        stmt = stmt.where(AnomalyResult.is_outlier == True)
    stmt = stmt.order_by(AnomalyResult.anomaly_score.desc())
    result = await db.execute(stmt)
    rows = result.scalars().all()
    return [
        {
            "id": r.id,
            "dataset_id": r.dataset_id,
            "row_id": r.row_id,
            "anomaly_score": r.anomaly_score,
            "distance_from_centroid": r.distance_from_centroid,
            "cluster_id": r.cluster_id,
            "is_outlier": r.is_outlier,
            "explanation": r.explanation,
        }
        for r in rows
    ]


@router.post("/anomalies/{dataset_id}")
async def trigger_anomaly_detection(
    dataset_id: str,
    k: int = Query(3, ge=2, le=20),
    threshold: float = Query(0.35, ge=0.1, le=0.9),
    background_tasks: BackgroundTasks = None,
):
    """Trigger embedding-based anomaly detection on a dataset."""
    async def _run():
        from app.services.anomaly_detector import detect_anomalies
        await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)

    if background_tasks:
        background_tasks.add_task(_run)
        return {"status": "anomaly_detection_started", "dataset_id": dataset_id}
    else:
        from app.services.anomaly_detector import detect_anomalies
        results = await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
        return {"status": "complete", "dataset_id": dataset_id, "count": len(results)}


# --- Natural language data query (SSE streaming) ---

@router.post("/query/{dataset_id}")
async def query_dataset_endpoint(
    dataset_id: str,
    body: QueryRequest,
):
    """Ask a natural language question about a dataset.

    Returns an SSE stream with token-by-token LLM response.
    Event types: status, metadata, token, error, done
    """
    from app.services.data_query import query_dataset_stream

    return StreamingResponse(
        query_dataset_stream(dataset_id, body.question, body.mode),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@router.post("/query/{dataset_id}/sync")
async def query_dataset_sync(
    dataset_id: str,
    body: QueryRequest,
):
    """Non-streaming version of data query."""
    from app.services.data_query import query_dataset

    try:
        answer = await query_dataset(dataset_id, body.question, body.mode)
        return {"dataset_id": dataset_id, "question": body.question, "answer": answer, "mode": body.mode}
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        logger.error(f"Query failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


# --- Job queue endpoints ---

@router.get("/jobs")
async def list_jobs(
    status: str | None = Query(None),
    job_type: str | None = Query(None),
    limit: int = Query(50, ge=1, le=200),
):
    """List all tracked jobs."""
    from app.services.job_queue import job_queue, JobStatus, JobType

    s = JobStatus(status) if status else None
    t = JobType(job_type) if job_type else None
    jobs = job_queue.list_jobs(status=s, job_type=t, limit=limit)
    stats = job_queue.get_stats()
    return {"jobs": jobs, "stats": stats}


@router.get("/jobs/{job_id}")
async def get_job(job_id: str):
    """Get status of a specific job."""
    from app.services.job_queue import job_queue

    job = job_queue.get_job(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    return job.to_dict()


@router.delete("/jobs/{job_id}")
async def cancel_job(job_id: str):
    """Cancel a running or queued job."""
    from app.services.job_queue import job_queue

    if job_queue.cancel_job(job_id):
        return {"status": "cancelled", "job_id": job_id}
    raise HTTPException(status_code=400, detail="Job cannot be cancelled (already complete or not found)")


@router.post("/jobs/submit/{job_type}")
async def submit_job(
    job_type: str,
    params: dict = {},
):
    """Submit a new job to the queue.

    Job types: triage, host_profile, report, anomaly, query
    Params vary by type (e.g., dataset_id, hunt_id, question, mode).
    """
    from app.services.job_queue import job_queue, JobType

    try:
        jt = JobType(job_type)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid job_type: {job_type}. Valid: {[t.value for t in JobType]}",
        )

    job = job_queue.submit(jt, **params)
    return {"job_id": job.id, "status": job.status.value, "job_type": job_type}


# --- Load balancer status ---

@router.get("/lb/status")
async def lb_status():
    """Get load balancer status for both nodes."""
    from app.services.load_balancer import lb
    return lb.get_status()


@router.post("/lb/check")
async def lb_health_check():
    """Force a health check of both nodes."""
    from app.services.load_balancer import lb
    await lb.check_health()
    return lb.get_status()
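A minimal client sketch for the routes above, assuming the backend is served on localhost:8000 and using a hypothetical dataset id "ds_123" (both placeholders, not values from this commit):

# Trigger anomaly detection, read back outliers, and run a synchronous data query.
import asyncio
import httpx

async def run_anomaly_and_query() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # Kick off background anomaly detection; results appear once the job finishes,
        # so in practice the GET below is polled rather than called immediately.
        await client.post("/api/analysis/anomalies/ds_123", params={"k": 3, "threshold": 0.35})
        outliers = await client.get("/api/analysis/anomalies/ds_123", params={"outliers_only": True})
        print(len(outliers.json()), "outliers so far")

        # Non-streaming natural-language query against the same dataset.
        answer = await client.post(
            "/api/analysis/query/ds_123/sync",
            json={"question": "Which processes have suspicious command lines?", "mode": "quick"},
        )
        print(answer.json()["answer"])

asyncio.run(run_anomaly_and_query())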
28  backend/app/api/routes/network.py  Normal file
@@ -0,0 +1,28 @@
"""Network topology API - host inventory endpoint."""

import logging

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import get_db
from app.services.host_inventory import build_host_inventory

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/network", tags=["network"])


@router.get("/host-inventory")
async def get_host_inventory(
    hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
    db: AsyncSession = Depends(get_db),
):
    """Build a deduplicated host inventory from all datasets in a hunt.

    Returns unique hosts with hostname, IPs, OS, logged-in users, and
    network connections derived from netstat/connection data.
    """
    result = await build_host_inventory(hunt_id, db)
    return result
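Fetching the deduplicated inventory that NetworkMap consumes is a single GET. A short sketch, assuming a local server and a placeholder hunt id; the "stats.total_hosts" field is taken from the handler above:

# Pull the host inventory for one hunt (hunt id and URL are placeholders).
import httpx

resp = httpx.get(
    "http://localhost:8000/api/network/host-inventory",
    params={"hunt_id": "hunt_abc"},
    timeout=60.0,
)
inventory = resp.json()
print(inventory["stats"]["total_hosts"], "unique hosts")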
@@ -3,6 +3,7 @@
Uses async SQLAlchemy with aiosqlite for local dev and asyncpg for production PostgreSQL.
"""

from sqlalchemy import event
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
@@ -12,12 +13,32 @@ from sqlalchemy.orm import DeclarativeBase

from app.config import settings

engine = create_async_engine(
    settings.DATABASE_URL,
_is_sqlite = settings.DATABASE_URL.startswith("sqlite")

_engine_kwargs: dict = dict(
    echo=settings.DEBUG,
    future=True,
)

if _is_sqlite:
    _engine_kwargs["connect_args"] = {"timeout": 30}
    _engine_kwargs["pool_size"] = 1
    _engine_kwargs["max_overflow"] = 0

engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)


@event.listens_for(engine.sync_engine, "connect")
def _set_sqlite_pragmas(dbapi_conn, connection_record):
    """Enable WAL mode and tune busy-timeout for SQLite connections."""
    if _is_sqlite:
        cursor = dbapi_conn.cursor()
        cursor.execute("PRAGMA journal_mode=WAL")
        cursor.execute("PRAGMA busy_timeout=5000")
        cursor.execute("PRAGMA synchronous=NORMAL")
        cursor.close()


async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
@@ -51,4 +72,4 @@ async def init_db() -> None:

async def dispose_db() -> None:
    """Dispose of the engine connection pool."""
    await engine.dispose()
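To confirm the connect-listener actually switched a dev SQLite database into WAL mode, the pragma can be read back through the same engine. A small sketch, assuming DATABASE_URL points at SQLite (on PostgreSQL this check does not apply):

# Verify the WAL pragma on a SQLite dev database.
import asyncio
from sqlalchemy import text

from app.db import engine

async def show_journal_mode() -> None:
    async with engine.connect() as conn:
        mode = await conn.scalar(text("PRAGMA journal_mode"))
        print("journal_mode =", mode)  # expected: "wal" after the listener runs

asyncio.run(show_journal_mode())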
@@ -1,7 +1,7 @@
"""SQLAlchemy ORM models for ThreatHunt.

All persistent entities: datasets, hunts, conversations, annotations,
hypotheses, enrichment results, and users.
hypotheses, enrichment results, users, and AI analysis tables.
"""

import uuid
@@ -32,8 +32,7 @@ def _new_id() -> str:
    return uuid.uuid4().hex


# ── Users ──────────────────────────────────────────────────────────────
# -- Users ---

class User(Base):
    __tablename__ = "users"
@@ -42,17 +41,15 @@ class User(Base):
    username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
    email: Mapped[str] = mapped_column(String(256), unique=True, nullable=False)
    hashed_password: Mapped[str] = mapped_column(String(256), nullable=False)
    role: Mapped[str] = mapped_column(String(16), default="analyst")  # analyst | admin | viewer
    role: Mapped[str] = mapped_column(String(16), default="analyst")
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    # relationships
    hunts: Mapped[list["Hunt"]] = relationship(back_populates="owner", lazy="selectin")
    annotations: Mapped[list["Annotation"]] = relationship(back_populates="author", lazy="selectin")


# ── Hunts ──────────────────────────────────────────────────────────────
# -- Hunts ---

class Hunt(Base):
    __tablename__ = "hunts"
@@ -60,7 +57,7 @@ class Hunt(Base):
    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    name: Mapped[str] = mapped_column(String(256), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    status: Mapped[str] = mapped_column(String(32), default="active")  # active | closed | archived
    status: Mapped[str] = mapped_column(String(32), default="active")
    owner_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("users.id"), nullable=True
    )
@@ -69,15 +66,15 @@
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    owner: Mapped[Optional["User"]] = relationship(back_populates="hunts", lazy="selectin")
    datasets: Mapped[list["Dataset"]] = relationship(back_populates="hunt", lazy="selectin")
    conversations: Mapped[list["Conversation"]] = relationship(back_populates="hunt", lazy="selectin")
    hypotheses: Mapped[list["Hypothesis"]] = relationship(back_populates="hunt", lazy="selectin")
    host_profiles: Mapped[list["HostProfile"]] = relationship(back_populates="hunt", lazy="noload")
    reports: Mapped[list["HuntReport"]] = relationship(back_populates="hunt", lazy="noload")


# ── Datasets ───────────────────────────────────────────────────────────
# -- Datasets ---

class Dataset(Base):
    __tablename__ = "datasets"
@@ -85,36 +82,44 @@ class Dataset(Base):
    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
    filename: Mapped[str] = mapped_column(String(512), nullable=False)
    source_tool: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)  # velociraptor, etc.
    source_tool: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    row_count: Mapped[int] = mapped_column(Integer, default=0)
    column_schema: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    normalized_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    ioc_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # auto-detected IOC columns
    ioc_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    file_size_bytes: Mapped[int] = mapped_column(Integer, default=0)
    encoding: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    delimiter: Mapped[Optional[str]] = mapped_column(String(4), nullable=True)
    time_range_start: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    time_range_end: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)

    # New Phase 1-2 columns
    processing_status: Mapped[str] = mapped_column(String(20), default="ready")
    artifact_type: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    file_path: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)

    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    uploaded_by: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    # relationships
    hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="datasets", lazy="selectin")
    rows: Mapped[list["DatasetRow"]] = relationship(
        back_populates="dataset", lazy="noload", cascade="all, delete-orphan"
    )
    triage_results: Mapped[list["TriageResult"]] = relationship(
        back_populates="dataset", lazy="noload", cascade="all, delete-orphan"
    )

    __table_args__ = (
        Index("ix_datasets_hunt", "hunt_id"),
        Index("ix_datasets_status", "processing_status"),
    )


class DatasetRow(Base):
    """Individual row from a CSV dataset, stored as JSON blob."""
    __tablename__ = "dataset_rows"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
@@ -125,7 +130,6 @@ class DatasetRow(Base):
    data: Mapped[dict] = mapped_column(JSON, nullable=False)
    normalized_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)

    # relationships
    dataset: Mapped["Dataset"] = relationship(back_populates="rows")
    annotations: Mapped[list["Annotation"]] = relationship(
        back_populates="row", lazy="noload"
@@ -137,8 +141,7 @@ class DatasetRow(Base):
    )


# ── Conversations ─────────────────────────────────────────────────────
# -- Conversations ---

class Conversation(Base):
    __tablename__ = "conversations"
@@ -156,7 +159,6 @@ class Conversation(Base):
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="conversations", lazy="selectin")
    messages: Mapped[list["Message"]] = relationship(
        back_populates="conversation", lazy="selectin", cascade="all, delete-orphan",
@@ -171,16 +173,15 @@ class Message(Base):
    conversation_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False
    )
    role: Mapped[str] = mapped_column(String(16), nullable=False)  # user | agent | system
    role: Mapped[str] = mapped_column(String(16), nullable=False)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)  # wile | roadrunner | cluster
    node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    token_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    latency_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    response_meta: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    # relationships
    conversation: Mapped["Conversation"] = relationship(back_populates="messages")

    __table_args__ = (
@@ -188,8 +189,7 @@ class Message(Base):
    )


# ── Annotations ───────────────────────────────────────────────────────
# -- Annotations ---

class Annotation(Base):
    __tablename__ = "annotations"
@@ -205,19 +205,14 @@ class Annotation(Base):
        String(32), ForeignKey("users.id"), nullable=True
    )
    text: Mapped[str] = mapped_column(Text, nullable=False)
    severity: Mapped[str] = mapped_column(
        String(16), default="info"
    )  # info | low | medium | high | critical
    tag: Mapped[Optional[str]] = mapped_column(
        String(32), nullable=True
    )  # suspicious | benign | needs-review
    severity: Mapped[str] = mapped_column(String(16), default="info")
    tag: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    highlight_color: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    row: Mapped[Optional["DatasetRow"]] = relationship(back_populates="annotations")
    author: Mapped[Optional["User"]] = relationship(back_populates="annotations")

@@ -227,8 +222,7 @@ class Annotation(Base):
    )


# ── Hypotheses ────────────────────────────────────────────────────────
# -- Hypotheses ---

class Hypothesis(Base):
    __tablename__ = "hypotheses"
@@ -240,9 +234,7 @@ class Hypothesis(Base):
    title: Mapped[str] = mapped_column(String(256), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    mitre_technique: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    status: Mapped[str] = mapped_column(
        String(16), default="draft"
    )  # draft | active | confirmed | rejected
    status: Mapped[str] = mapped_column(String(16), default="draft")
    evidence_row_ids: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    evidence_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
@@ -250,7 +242,6 @@ class Hypothesis(Base):
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="hypotheses", lazy="selectin")

    __table_args__ = (
@@ -258,21 +249,16 @@ class Hypothesis(Base):
    )


# ── Enrichment Results ────────────────────────────────────────────────
# -- Enrichment Results ---

class EnrichmentResult(Base):
    __tablename__ = "enrichment_results"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    ioc_value: Mapped[str] = mapped_column(String(512), nullable=False, index=True)
    ioc_type: Mapped[str] = mapped_column(
        String(32), nullable=False
    )  # ip | hash_md5 | hash_sha1 | hash_sha256 | domain | url
    source: Mapped[str] = mapped_column(String(32), nullable=False)  # virustotal | abuseipdb | shodan | ai
    verdict: Mapped[Optional[str]] = mapped_column(
        String(16), nullable=True
    )  # clean | suspicious | malicious | unknown
    ioc_type: Mapped[str] = mapped_column(String(32), nullable=False)
    source: Mapped[str] = mapped_column(String(32), nullable=False)
    verdict: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
    confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    raw_result: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
@@ -287,28 +273,24 @@ class EnrichmentResult(Base):
    )


# ── AUP Keyword Themes & Keywords ────────────────────────────────────
# -- AUP Keyword Themes & Keywords ---

class KeywordTheme(Base):
    """A named category of keywords for AUP scanning (e.g. gambling, gaming)."""
    __tablename__ = "keyword_themes"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    name: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
    color: Mapped[str] = mapped_column(String(16), default="#9e9e9e")  # hex chip color
    color: Mapped[str] = mapped_column(String(16), default="#9e9e9e")
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    is_builtin: Mapped[bool] = mapped_column(Boolean, default=False)  # seed-provided
    is_builtin: Mapped[bool] = mapped_column(Boolean, default=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    # relationships
    keywords: Mapped[list["Keyword"]] = relationship(
        back_populates="theme", lazy="selectin", cascade="all, delete-orphan"
    )


class Keyword(Base):
    """Individual keyword / pattern belonging to a theme."""
    __tablename__ = "keywords"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
@@ -319,10 +301,102 @@ class Keyword(Base):
    is_regex: Mapped[bool] = mapped_column(Boolean, default=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    # relationships
    theme: Mapped["KeywordTheme"] = relationship(back_populates="keywords")

    __table_args__ = (
        Index("ix_keywords_theme", "theme_id"),
        Index("ix_keywords_value", "value"),
    )


# -- AI Analysis Tables (Phase 2) ---

class TriageResult(Base):
    __tablename__ = "triage_results"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    dataset_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True
    )
    row_start: Mapped[int] = mapped_column(Integer, nullable=False)
    row_end: Mapped[int] = mapped_column(Integer, nullable=False)
    risk_score: Mapped[float] = mapped_column(Float, default=0.0)
    verdict: Mapped[str] = mapped_column(String(20), default="pending")
    findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    suspicious_indicators: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    mitre_techniques: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    dataset: Mapped["Dataset"] = relationship(back_populates="triage_results")


class HostProfile(Base):
    __tablename__ = "host_profiles"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    hunt_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True
    )
    hostname: Mapped[str] = mapped_column(String(256), nullable=False)
    fqdn: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
    client_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    risk_score: Mapped[float] = mapped_column(Float, default=0.0)
    risk_level: Mapped[str] = mapped_column(String(20), default="unknown")
    artifact_summary: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    timeline_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    suspicious_findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    mitre_techniques: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    llm_analysis: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    hunt: Mapped["Hunt"] = relationship(back_populates="host_profiles")


class HuntReport(Base):
    __tablename__ = "hunt_reports"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    hunt_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True
    )
    status: Mapped[str] = mapped_column(String(20), default="pending")
    exec_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    full_report: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    recommendations: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    mitre_mapping: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    ioc_table: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    host_risk_summary: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    models_used: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    generation_time_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    hunt: Mapped["Hunt"] = relationship(back_populates="reports")


class AnomalyResult(Base):
    __tablename__ = "anomaly_results"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    dataset_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True
    )
    row_id: Mapped[Optional[int]] = mapped_column(
        Integer, ForeignKey("dataset_rows.id", ondelete="CASCADE"), nullable=True
    )
    anomaly_score: Mapped[float] = mapped_column(Float, default=0.0)
    distance_from_centroid: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    cluster_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    is_outlier: Mapped[bool] = mapped_column(Boolean, default=False)
    explanation: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
@@ -1,10 +1,12 @@
"""ThreatHunt backend application.

Wires together: database, CORS, agent routes, dataset routes, hunt routes,
annotation/hypothesis routes. DB tables are auto-created on startup.
annotation/hypothesis routes, analysis routes, network routes, job queue,
load balancer. DB tables are auto-created on startup.
"""

import logging
import os
from contextlib import asynccontextmanager

from fastapi import FastAPI
@@ -21,6 +23,8 @@ from app.api.routes.correlation import router as correlation_router
from app.api.routes.reports import router as reports_router
from app.api.routes.auth import router as auth_router
from app.api.routes.keywords import router as keywords_router
from app.api.routes.analysis import router as analysis_router
from app.api.routes.network import router as network_router

logger = logging.getLogger(__name__)

@@ -28,17 +32,45 @@ logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup / shutdown lifecycle."""
    logger.info("Starting ThreatHunt API …")
    logger.info("Starting ThreatHunt API ...")
    await init_db()
    logger.info("Database initialised")

    # Ensure uploads directory exists
    os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
    logger.info("Upload dir: %s", os.path.abspath(settings.UPLOAD_DIR))

    # Seed default AUP keyword themes
    from app.db import async_session_factory
    from app.services.keyword_defaults import seed_defaults
    async with async_session_factory() as seed_db:
        await seed_defaults(seed_db)
    logger.info("AUP keyword defaults checked")

    # Start job queue (Phase 10)
    from app.services.job_queue import job_queue, register_all_handlers
    register_all_handlers()
    await job_queue.start()
    logger.info("Job queue started (%d workers)", job_queue._max_workers)

    # Start load balancer health loop (Phase 10)
    from app.services.load_balancer import lb
    await lb.start_health_loop(interval=30.0)
    logger.info("Load balancer health loop started")

    yield
    logger.info("Shutting down …")

    logger.info("Shutting down ...")
    # Stop job queue
    from app.services.job_queue import job_queue as jq
    await jq.stop()
    logger.info("Job queue stopped")

    # Stop load balancer
    from app.services.load_balancer import lb as _lb
    await _lb.stop_health_loop()
    logger.info("Load balancer stopped")

    from app.agents.providers_v2 import cleanup_client
    from app.services.enrichment import enrichment_engine
    await cleanup_client()
@@ -46,15 +78,13 @@ async def lifespan(app: FastAPI):
    await dispose_db()


# Create FastAPI application
app = FastAPI(
    title="ThreatHunt API",
    description="Analyst-assist threat hunting platform powered by Wile & Roadrunner LLM cluster",
    version="0.3.0",
    version=settings.APP_VERSION,
    lifespan=lifespan,
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
@@ -74,11 +104,12 @@ app.include_router(enrichment_router)
app.include_router(correlation_router)
app.include_router(reports_router)
app.include_router(keywords_router)
app.include_router(analysis_router)
app.include_router(network_router)


@app.get("/", tags=["health"])
async def root():
    """API health check."""
    return {
        "service": "ThreatHunt API",
        "version": settings.APP_VERSION,
@@ -89,4 +120,4 @@ async def root():
            "roadrunner": settings.roadrunner_url,
            "openwebui": settings.OPENWEBUI_URL,
        },
    }
199  backend/app/services/anomaly_detector.py  Normal file
@@ -0,0 +1,199 @@
"""Embedding-based anomaly detection using Roadrunner's bge-m3 model.

Converts dataset rows to embeddings, clusters them, and flags outliers
that deviate significantly from the cluster centroids. Uses cosine
distance and simple k-means-like centroid computation.
"""

import asyncio
import json
import logging
import math
from typing import Optional

import httpx
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.db import async_session_factory
from app.db.models import AnomalyResult, Dataset, DatasetRow

logger = logging.getLogger(__name__)

EMBED_URL = f"{settings.roadrunner_url}/api/embed"
EMBED_MODEL = "bge-m3"
BATCH_SIZE = 32  # rows per embedding batch
MAX_ROWS = 2000  # cap for anomaly detection


# --- math helpers (no numpy required) ---

def _dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def _norm(v: list[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def _cosine_distance(a: list[float], b: list[float]) -> float:
    na, nb = _norm(a), _norm(b)
    if na == 0 or nb == 0:
        return 1.0
    return 1.0 - _dot(a, b) / (na * nb)


def _mean_vector(vectors: list[list[float]]) -> list[float]:
    if not vectors:
        return []
    dim = len(vectors[0])
    n = len(vectors)
    return [sum(v[i] for v in vectors) / n for i in range(dim)]


def _row_to_text(data: dict) -> str:
    """Flatten a row dict to a single string for embedding."""
    parts = []
    for k, v in data.items():
        sv = str(v).strip()
        if sv and sv.lower() not in ('none', 'null', ''):
            parts.append(f"{k}={sv}")
    return " | ".join(parts)[:2000]  # cap length


async def _embed_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]:
    """Get embeddings from Roadrunner's Ollama API."""
    resp = await client.post(
        EMBED_URL,
        json={"model": EMBED_MODEL, "input": texts},
        timeout=120.0,
    )
    resp.raise_for_status()
    data = resp.json()
    # Ollama returns {"embeddings": [[...], ...]}
    return data.get("embeddings", [])


def _simple_cluster(
    embeddings: list[list[float]],
    k: int = 3,
    max_iter: int = 20,
) -> tuple[list[int], list[list[float]]]:
    """Simple k-means clustering (no numpy dependency).

    Returns (assignments, centroids).
    """
    n = len(embeddings)
    if n <= k:
        return list(range(n)), embeddings[:]

    # Init centroids: evenly spaced indices
    step = max(n // k, 1)
    centroids = [embeddings[i * step % n] for i in range(k)]
    assignments = [0] * n

    for _ in range(max_iter):
        # Assign to nearest centroid
        new_assignments = []
        for emb in embeddings:
            dists = [_cosine_distance(emb, c) for c in centroids]
            new_assignments.append(dists.index(min(dists)))

        if new_assignments == assignments:
            break
        assignments = new_assignments

        # Recompute centroids
        for ci in range(k):
            members = [embeddings[j] for j in range(n) if assignments[j] == ci]
            if members:
                centroids[ci] = _mean_vector(members)

    return assignments, centroids


async def detect_anomalies(
    dataset_id: str,
    k: int = 3,
    outlier_threshold: float = 0.35,
) -> list[dict]:
    """Run embedding-based anomaly detection on a dataset.

    1. Load rows  2. Embed via bge-m3  3. Cluster  4. Flag outliers.
    """
    async with async_session_factory() as db:
        # Load rows
        result = await db.execute(
            select(DatasetRow.id, DatasetRow.row_index, DatasetRow.data)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .limit(MAX_ROWS)
        )
        rows = result.all()
        if not rows:
            logger.info("No rows for anomaly detection in dataset %s", dataset_id)
            return []

        row_ids = [r[0] for r in rows]
        row_indices = [r[1] for r in rows]
        texts = [_row_to_text(r[2]) for r in rows]

        logger.info("Anomaly detection: %d rows, embedding with %s", len(texts), EMBED_MODEL)

        # Embed in batches
        all_embeddings: list[list[float]] = []
        async with httpx.AsyncClient() as client:
            for i in range(0, len(texts), BATCH_SIZE):
                batch = texts[i : i + BATCH_SIZE]
                try:
                    embs = await _embed_batch(batch, client)
                    all_embeddings.extend(embs)
                except Exception as e:
                    logger.error("Embedding batch %d failed: %s", i, e)
                    # Fill with zeros so indices stay aligned
                    all_embeddings.extend([[0.0] * 1024] * len(batch))

        if not all_embeddings or len(all_embeddings) != len(texts):
            logger.error("Embedding count mismatch")
            return []

        # Cluster
        actual_k = min(k, len(all_embeddings))
        assignments, centroids = _simple_cluster(all_embeddings, k=actual_k)

        # Compute distances from centroid
        anomalies: list[dict] = []
        for idx, (emb, cluster_id) in enumerate(zip(all_embeddings, assignments)):
            dist = _cosine_distance(emb, centroids[cluster_id])
            is_outlier = dist > outlier_threshold
            anomalies.append({
                "row_id": row_ids[idx],
                "row_index": row_indices[idx],
                "anomaly_score": round(dist, 4),
                "distance_from_centroid": round(dist, 4),
                "cluster_id": cluster_id,
                "is_outlier": is_outlier,
            })

        # Save to DB
        outlier_count = 0
        for a in anomalies:
            ar = AnomalyResult(
                dataset_id=dataset_id,
                row_id=a["row_id"],
                anomaly_score=a["anomaly_score"],
                distance_from_centroid=a["distance_from_centroid"],
                cluster_id=a["cluster_id"],
                is_outlier=a["is_outlier"],
            )
            db.add(ar)
            if a["is_outlier"]:
                outlier_count += 1

        await db.commit()
        logger.info(
            "Anomaly detection complete: %d rows, %d outliers (threshold=%.2f)",
            len(anomalies), outlier_count, outlier_threshold,
        )

        return sorted(anomalies, key=lambda x: x["anomaly_score"], reverse=True)
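The clustering and outlier logic can be exercised without the database or the embedding service. A toy sketch using the module's own helpers on hand-made 2-D vectors (real runs embed rows with bge-m3; these vectors are stand-ins):

# Toy demonstration of the clustering/outlier math above.
from app.services.anomaly_detector import _cosine_distance, _simple_cluster

vectors = [
    [1.0, 0.0], [0.9, 0.1], [1.0, 0.1],   # tight cluster
    [0.0, 1.0], [0.1, 0.9],               # second cluster
    [-1.0, -1.0],                         # clear outlier
]
assignments, centroids = _simple_cluster(vectors, k=2)
for vec, cid in zip(vectors, assignments):
    dist = _cosine_distance(vec, centroids[cid])
    flag = "outlier" if dist > 0.35 else ""
    print(vec, "cluster", cid, "distance", round(dist, 3), flag)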
81
backend/app/services/artifact_classifier.py
Normal file
81
backend/app/services/artifact_classifier.py
Normal file
@@ -0,0 +1,81 @@
|
"""Artifact classifier - identify Velociraptor artifact types from CSV headers."""

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)

# (required_columns, artifact_type)
FINGERPRINTS: list[tuple[set[str], str]] = [
    ({"Pid", "Name", "CommandLine", "Exe"}, "Windows.System.Pslist"),
    ({"Pid", "Name", "Ppid", "CommandLine"}, "Windows.System.Pslist"),
    ({"Laddr.IP", "Raddr.IP", "Status", "Pid"}, "Windows.Network.Netstat"),
    ({"Laddr", "Raddr", "Status", "Pid"}, "Windows.Network.Netstat"),
    ({"FamilyString", "TypeString", "Status", "Pid"}, "Windows.Network.Netstat"),
    ({"ServiceName", "DisplayName", "StartMode", "PathName"}, "Windows.System.Services"),
    ({"DisplayName", "PathName", "ServiceDll", "StartMode"}, "Windows.System.Services"),
    ({"OSPath", "Size", "Mtime", "Hash"}, "Windows.Search.FileFinder"),
    ({"FullPath", "Size", "Mtime"}, "Windows.Search.FileFinder"),
    ({"PrefetchFileName", "RunCount", "LastRunTimes"}, "Windows.Forensics.Prefetch"),
    ({"Executable", "RunCount", "LastRunTimes"}, "Windows.Forensics.Prefetch"),
    ({"KeyPath", "Type", "Data"}, "Windows.Registry.Finder"),
    ({"Key", "Type", "Value"}, "Windows.Registry.Finder"),
    ({"EventTime", "Channel", "EventID", "EventData"}, "Windows.EventLogs.EvtxHunter"),
    ({"TimeCreated", "Channel", "EventID", "Provider"}, "Windows.EventLogs.EvtxHunter"),
    ({"Entry", "Category", "Profile", "Launch String"}, "Windows.Sys.Autoruns"),
    ({"Entry", "Category", "LaunchString"}, "Windows.Sys.Autoruns"),
    ({"Name", "Record", "Type", "TTL"}, "Windows.Network.DNS"),
    ({"QueryName", "QueryType", "QueryResults"}, "Windows.Network.DNS"),
    ({"Path", "MD5", "SHA1", "SHA256"}, "Windows.Analysis.Hash"),
    ({"Md5", "Sha256", "FullPath"}, "Windows.Analysis.Hash"),
    ({"Name", "Actions", "NextRunTime", "Path"}, "Windows.System.TaskScheduler"),
    ({"Name", "Uid", "Gid", "Description"}, "Windows.Sys.Users"),
    ({"os_info.hostname", "os_info.system"}, "Server.Information.Client"),
    ({"ClientId", "os_info.fqdn"}, "Server.Information.Client"),
    ({"Pid", "Name", "Cmdline", "Exe"}, "Linux.Sys.Pslist"),
    ({"Laddr", "Raddr", "Status", "FamilyString"}, "Linux.Network.Netstat"),
    ({"Namespace", "ClassName", "PropertyName"}, "Windows.System.WMI"),
    ({"RemoteAddress", "RemoteMACAddress", "InterfaceAlias"}, "Windows.Network.ArpCache"),
    ({"URL", "Title", "VisitCount", "LastVisitTime"}, "Windows.Applications.BrowserHistory"),
    ({"Url", "Title", "Visits"}, "Windows.Applications.BrowserHistory"),
]

VELOCIRAPTOR_META = {"_Source", "ClientId", "FlowId", "Fqdn", "HuntId"}

CATEGORY_MAP = {
    "Pslist": "process",
    "Netstat": "network",
    "Services": "persistence",
    "FileFinder": "filesystem",
    "Prefetch": "execution",
    "Registry": "persistence",
    "EvtxHunter": "eventlog",
    "EventLogs": "eventlog",
    "Autoruns": "persistence",
    "DNS": "network",
    "Hash": "filesystem",
    "TaskScheduler": "persistence",
    "Users": "account",
    "Client": "system",
    "WMI": "persistence",
    "ArpCache": "network",
    "BrowserHistory": "application",
}


def classify_artifact(columns: list[str]) -> str:
    """Return the first artifact type whose required columns are all present."""
    col_set = set(columns)
    for required, artifact_type in FINGERPRINTS:
        if required.issubset(col_set):
            return artifact_type
    if VELOCIRAPTOR_META.intersection(col_set):
        return "Velociraptor.Unknown"
    return "Unknown"


def get_artifact_category(artifact_type: str) -> str:
    """Map an artifact type to a coarse category (process, network, persistence, ...)."""
    for key, category in CATEGORY_MAP.items():
        if key.lower() in artifact_type.lower():
            return category
    return "unknown"
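A minimal usage sketch for the classifier above (illustrative only, not part of the committed file; the column list mimics a Netstat export):

if __name__ == "__main__":
    netstat_cols = ["Pid", "Name", "Laddr.IP", "Raddr.IP", "Status", "Fqdn", "ClientId"]
    print(classify_artifact(netstat_cols))                    # Windows.Network.Netstat
    print(get_artifact_category("Windows.Network.Netstat"))   # network
    print(classify_artifact(["Foo", "Bar"]))                  # Unknown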
238  backend/app/services/data_query.py  Normal file
@@ -0,0 +1,238 @@
"""Natural-language data query service with SSE streaming.

Lets analysts ask questions about dataset rows in plain English.
Routes to the fast model (Roadrunner) for quick queries and the heavy model (Wile)
for deep analysis. Supports streaming via OllamaProvider.generate_stream().
"""

from __future__ import annotations

import asyncio
import json
import logging
import time
from typing import AsyncIterator

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.db import async_session_factory
from app.db.models import Dataset, DatasetRow

logger = logging.getLogger(__name__)

# Maximum rows to include in the context window
MAX_CONTEXT_ROWS = 60
MAX_ROW_TEXT_CHARS = 300


def _rows_to_text(rows: list[dict], columns: list[str]) -> str:
    """Convert dataset rows to a compact text table for the LLM context."""
    if not rows:
        return "(no rows)"
    # Header
    header = " | ".join(columns[:20])  # cap columns to avoid overflow
    lines = [header, "-" * min(len(header), 120)]
    for row in rows[:MAX_CONTEXT_ROWS]:
        vals = []
        for c in columns[:20]:
            v = str(row.get(c, ""))
            if len(v) > 80:
                v = v[:77] + "..."
            vals.append(v)
        line = " | ".join(vals)
        if len(line) > MAX_ROW_TEXT_CHARS:
            line = line[:MAX_ROW_TEXT_CHARS] + "..."
        lines.append(line)
    return "\n".join(lines)


QUERY_SYSTEM_PROMPT = """You are a cybersecurity data analyst assistant for ThreatHunt.
You have been given a sample of rows from a forensic artifact dataset (Velociraptor, etc.).

Your job:
- Answer the analyst's question about this data accurately and concisely
- Point out suspicious patterns, anomalies, or indicators of compromise
- Reference MITRE ATT&CK techniques when relevant
- Suggest follow-up queries or pivots
- If you cannot answer from the data provided, say so clearly

Rules:
- Be factual - only reference data you can see
- Use forensic terminology appropriate for SOC/DFIR analysts
- Format your answer with clear sections using markdown
- If the data seems benign, say so - do not fabricate threats"""


async def _load_dataset_context(
    dataset_id: str,
    db: AsyncSession,
    sample_size: int = MAX_CONTEXT_ROWS,
) -> tuple[dict, str, int]:
    """Load dataset metadata + sample rows for context.

    Returns (metadata_dict, rows_text, total_row_count).
    """
    ds = await db.get(Dataset, dataset_id)
    if not ds:
        raise ValueError(f"Dataset {dataset_id} not found")

    # Get total count
    count_q = await db.execute(
        select(func.count()).where(DatasetRow.dataset_id == dataset_id)
    )
    total = count_q.scalar() or 0

    # Sample rows - get first batch + some from the middle
    half = sample_size // 2
    result = await db.execute(
        select(DatasetRow)
        .where(DatasetRow.dataset_id == dataset_id)
        .order_by(DatasetRow.row_index)
        .limit(half)
    )
    first_rows = result.scalars().all()

    # If the dataset is large, also sample from the middle
    middle_rows = []
    if total > sample_size:
        mid_offset = total // 2
        result2 = await db.execute(
            select(DatasetRow)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(mid_offset)
            .limit(sample_size - half)
        )
        middle_rows = result2.scalars().all()
    else:
        result2 = await db.execute(
            select(DatasetRow)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(half)
            .limit(sample_size - half)
        )
        middle_rows = result2.scalars().all()

    all_rows = first_rows + middle_rows
    row_dicts = [r.data if isinstance(r.data, dict) else {} for r in all_rows]

    columns = list(ds.column_schema.keys()) if ds.column_schema else []
    if not columns and row_dicts:
        columns = list(row_dicts[0].keys())

    rows_text = _rows_to_text(row_dicts, columns)

    metadata = {
        "name": ds.name,
        "filename": ds.filename,
        "source_tool": ds.source_tool,
        "artifact_type": getattr(ds, "artifact_type", None),
        "row_count": total,
        "columns": columns[:30],
        "sample_rows_shown": len(all_rows),
    }
    return metadata, rows_text, total


async def query_dataset(
    dataset_id: str,
    question: str,
    mode: str = "quick",
) -> str:
    """Non-streaming query: returns the full answer text."""
    from app.agents.providers_v2 import OllamaProvider, Node

    async with async_session_factory() as db:
        meta, rows_text, total = await _load_dataset_context(dataset_id, db)

    prompt = _build_prompt(question, meta, rows_text, total)

    if mode == "deep":
        provider = OllamaProvider(settings.DEFAULT_HEAVY_MODEL, Node.WILE)
        max_tokens = 4096
    else:
        provider = OllamaProvider(settings.DEFAULT_FAST_MODEL, Node.ROADRUNNER)
        max_tokens = 2048

    result = await provider.generate(
        prompt,
        system=QUERY_SYSTEM_PROMPT,
        max_tokens=max_tokens,
        temperature=0.3,
    )
    return result.get("response", "No response generated.")


async def query_dataset_stream(
    dataset_id: str,
    question: str,
    mode: str = "quick",
) -> AsyncIterator[str]:
    """Streaming query: yields SSE-formatted events."""
    from app.agents.providers_v2 import OllamaProvider, Node

    start = time.monotonic()

    # Send initial metadata event
    yield f"data: {json.dumps({'type': 'status', 'message': 'Loading dataset...'})}\n\n"

    async with async_session_factory() as db:
        meta, rows_text, total = await _load_dataset_context(dataset_id, db)

    yield f"data: {json.dumps({'type': 'metadata', 'dataset': meta})}\n\n"
    yield f"data: {json.dumps({'type': 'status', 'message': f'Querying LLM ({mode} mode)...'})}\n\n"

    prompt = _build_prompt(question, meta, rows_text, total)

    if mode == "deep":
        provider = OllamaProvider(settings.DEFAULT_HEAVY_MODEL, Node.WILE)
        max_tokens = 4096
        model_name = settings.DEFAULT_HEAVY_MODEL
        node_name = "wile"
    else:
        provider = OllamaProvider(settings.DEFAULT_FAST_MODEL, Node.ROADRUNNER)
        max_tokens = 2048
        model_name = settings.DEFAULT_FAST_MODEL
        node_name = "roadrunner"

    # Stream tokens
    token_count = 0
    try:
        async for token in provider.generate_stream(
            prompt,
            system=QUERY_SYSTEM_PROMPT,
            max_tokens=max_tokens,
            temperature=0.3,
        ):
            token_count += 1
            yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
    except Exception as e:
        logger.error(f"Streaming error: {e}")
        yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"

    elapsed_ms = int((time.monotonic() - start) * 1000)
    yield f"data: {json.dumps({'type': 'done', 'tokens': token_count, 'elapsed_ms': elapsed_ms, 'model': model_name, 'node': node_name})}\n\n"


def _build_prompt(question: str, meta: dict, rows_text: str, total: int) -> str:
    """Construct the full prompt with data context."""
    parts = [
        f"## Dataset: {meta['name']}",
        f"- Source: {meta.get('source_tool', 'unknown')}",
        f"- Artifact type: {meta.get('artifact_type', 'unknown')}",
        f"- Total rows: {total}",
        f"- Columns: {', '.join(meta.get('columns', []))}",
        f"- Showing {meta['sample_rows_shown']} sample rows below",
        "",
        "## Sample Data",
        "```",
        rows_text,
        "```",
        "",
        "## Analyst Question",
        question,
    ]
    return "\n".join(parts)
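A sketch of how a route could expose the streaming query above (illustrative; the route path and router wiring are assumptions, not part of this commit):

from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from app.services.data_query import query_dataset_stream

router = APIRouter()


@router.get("/api/analysis/datasets/{dataset_id}/query/stream")
async def stream_query(dataset_id: str, question: str, mode: str = "quick"):
    # Each yielded chunk is already a complete SSE "data: {...}\n\n" frame.
    return StreamingResponse(
        query_dataset_stream(dataset_id, question, mode),
        media_type="text/event-stream",
    )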
290  backend/app/services/host_inventory.py  Normal file
@@ -0,0 +1,290 @@
"""Host Inventory Service - builds a deduplicated host-centric network view.

Scans all datasets in a hunt to identify unique hosts, their IPs, OS,
logged-in users, and network connections between them.
"""

import re
import logging
from collections import defaultdict
from typing import Any

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import Dataset, DatasetRow

logger = logging.getLogger(__name__)

# --- Column-name patterns (Velociraptor + generic forensic tools) ---

_HOST_ID_RE = re.compile(
    r'^(client_?id|clientid|agent_?id|endpoint_?id|host_?id|sensor_?id)$', re.I)
_FQDN_RE = re.compile(
    r'^(fqdn|fully_?qualified|computer_?name|hostname|host_?name|host|'
    r'system_?name|machine_?name|nodename|workstation)$', re.I)
_USERNAME_RE = re.compile(
    r'^(user|username|user_?name|logon_?name|account_?name|owner|'
    r'logged_?in_?user|sam_?account_?name|samaccountname)$', re.I)
_LOCAL_IP_RE = re.compile(
    r'^(laddr\.?ip|laddr|local_?addr(ess)?|src_?ip|source_?ip)$', re.I)
_REMOTE_IP_RE = re.compile(
    r'^(raddr\.?ip|raddr|remote_?addr(ess)?|dst_?ip|dest_?ip)$', re.I)
_REMOTE_PORT_RE = re.compile(
    r'^(raddr\.?port|rport|remote_?port|dst_?port|dest_?port)$', re.I)
_OS_RE = re.compile(
    r'^(os|operating_?system|os_?version|os_?name|platform|os_?type|os_?build)$', re.I)
_IP_VALID_RE = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$')

_IGNORE_IPS = frozenset({
    '0.0.0.0', '::', '::1', '127.0.0.1', '', '-', '*', 'None', 'null',
})
_SYSTEM_DOMAINS = frozenset({
    'NT AUTHORITY', 'NT SERVICE', 'FONT DRIVER HOST', 'WINDOW MANAGER',
})
_SYSTEM_USERS = frozenset({
    'SYSTEM', 'LOCAL SERVICE', 'NETWORK SERVICE',
    'UMFD-0', 'UMFD-1', 'DWM-1', 'DWM-2', 'DWM-3',
})


def _is_valid_ip(v: str) -> bool:
    if not v or v in _IGNORE_IPS:
        return False
    return bool(_IP_VALID_RE.match(v))


def _clean(v: Any) -> str:
    s = str(v or '').strip()
    return s if s and s not in ('-', 'None', 'null', '') else ''


_SYSTEM_USER_RE = re.compile(
    r'^(SYSTEM|LOCAL SERVICE|NETWORK SERVICE|DWM-\d+|UMFD-\d+)$', re.I)


def _extract_username(raw: str) -> str:
    """Clean username, stripping domain prefixes and filtering system accounts."""
    if not raw:
        return ''
    name = raw.strip()
    if '\\' in name:
        domain, _, name = name.rpartition('\\')
        name = name.strip()
        if domain.strip().upper() in _SYSTEM_DOMAINS:
            if not name or _SYSTEM_USER_RE.match(name):
                return ''
    if _SYSTEM_USER_RE.match(name):
        return ''
    return name or ''


def _infer_os(fqdn: str) -> str:
    u = fqdn.upper()
    if 'W10-' in u or 'WIN10' in u:
        return 'Windows 10'
    if 'W11-' in u or 'WIN11' in u:
        return 'Windows 11'
    if 'W7-' in u or 'WIN7' in u:
        return 'Windows 7'
    if 'SRV' in u or 'SERVER' in u or 'DC-' in u:
        return 'Windows Server'
    if any(k in u for k in ('LINUX', 'UBUNTU', 'CENTOS', 'RHEL', 'DEBIAN')):
        return 'Linux'
    if 'MAC' in u or 'DARWIN' in u:
        return 'macOS'
    return 'Windows'


def _identify_columns(ds: Dataset) -> dict:
    norm = ds.normalized_columns or {}
    schema = ds.column_schema or {}
    raw_cols = list(schema.keys()) if schema else list(norm.keys())

    result = {
        'host_id': [], 'fqdn': [], 'username': [],
        'local_ip': [], 'remote_ip': [], 'remote_port': [], 'os': [],
    }

    for col in raw_cols:
        canonical = (norm.get(col) or '').lower()
        lower = col.lower()

        if _HOST_ID_RE.match(lower) or (canonical == 'hostname' and lower not in ('hostname', 'host_name', 'host')):
            result['host_id'].append(col)

        if _FQDN_RE.match(lower) or canonical == 'fqdn':
            result['fqdn'].append(col)

        if _USERNAME_RE.match(lower) or canonical in ('username', 'user'):
            result['username'].append(col)

        if _LOCAL_IP_RE.match(lower):
            result['local_ip'].append(col)
        elif _REMOTE_IP_RE.match(lower):
            result['remote_ip'].append(col)

        if _REMOTE_PORT_RE.match(lower):
            result['remote_port'].append(col)

        if _OS_RE.match(lower) or canonical == 'os':
            result['os'].append(col)

    return result


async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
    """Build a deduplicated host inventory from all datasets in a hunt.

    Returns dict with 'hosts', 'connections', and 'stats'.
    Each host has: id, hostname, fqdn, client_id, ips, os, users, datasets, row_count.
    """
    ds_result = await db.execute(
        select(Dataset).where(Dataset.hunt_id == hunt_id)
    )
    all_datasets = ds_result.scalars().all()

    if not all_datasets:
        return {"hosts": [], "connections": [], "stats": {
            "total_hosts": 0, "total_datasets_scanned": 0,
            "total_rows_scanned": 0,
        }}

    hosts: dict[str, dict] = {}        # fqdn -> host record
    ip_to_host: dict[str, str] = {}    # local-ip -> fqdn
    connections: dict[tuple, int] = defaultdict(int)
    total_rows = 0
    ds_with_hosts = 0

    for ds in all_datasets:
        cols = _identify_columns(ds)
        if not cols['fqdn'] and not cols['host_id']:
            continue
        ds_with_hosts += 1

        batch_size = 5000
        offset = 0
        while True:
            rr = await db.execute(
                select(DatasetRow)
                .where(DatasetRow.dataset_id == ds.id)
                .order_by(DatasetRow.row_index)
                .offset(offset).limit(batch_size)
            )
            rows = rr.scalars().all()
            if not rows:
                break

            for ro in rows:
                data = ro.data or {}
                total_rows += 1

                fqdn = ''
                for c in cols['fqdn']:
                    fqdn = _clean(data.get(c))
                    if fqdn:
                        break
                client_id = ''
                for c in cols['host_id']:
                    client_id = _clean(data.get(c))
                    if client_id:
                        break

                if not fqdn and not client_id:
                    continue

                host_key = fqdn or client_id

                if host_key not in hosts:
                    short = fqdn.split('.')[0] if fqdn and '.' in fqdn else fqdn
                    hosts[host_key] = {
                        'id': host_key,
                        'hostname': short or client_id,
                        'fqdn': fqdn,
                        'client_id': client_id,
                        'ips': set(),
                        'os': '',
                        'users': set(),
                        'datasets': set(),
                        'row_count': 0,
                    }

                h = hosts[host_key]
                h['datasets'].add(ds.name)
                h['row_count'] += 1
                if client_id and not h['client_id']:
                    h['client_id'] = client_id

                for c in cols['username']:
                    u = _extract_username(_clean(data.get(c)))
                    if u:
                        h['users'].add(u)

                for c in cols['local_ip']:
                    ip = _clean(data.get(c))
                    if _is_valid_ip(ip):
                        h['ips'].add(ip)
                        ip_to_host[ip] = host_key

                for c in cols['os']:
                    ov = _clean(data.get(c))
                    if ov and not h['os']:
                        h['os'] = ov

                for c in cols['remote_ip']:
                    rip = _clean(data.get(c))
                    if _is_valid_ip(rip):
                        rport = ''
                        for pc in cols['remote_port']:
                            rport = _clean(data.get(pc))
                            if rport:
                                break
                        connections[(host_key, rip, rport)] += 1

            offset += batch_size
            if len(rows) < batch_size:
                break

    # Post-process hosts
    for h in hosts.values():
        if not h['os'] and h['fqdn']:
            h['os'] = _infer_os(h['fqdn'])
        h['ips'] = sorted(h['ips'])
        h['users'] = sorted(h['users'])
        h['datasets'] = sorted(h['datasets'])

    # Build connections, resolving IPs to host keys
    conn_list = []
    seen = set()
    for (src, dst_ip, dst_port), cnt in connections.items():
        if dst_ip in _IGNORE_IPS:
            continue
        dst_host = ip_to_host.get(dst_ip, '')
        if dst_host == src:
            continue
        key = tuple(sorted([src, dst_host or dst_ip]))
        if key in seen:
            continue
        seen.add(key)
        conn_list.append({
            'source': src,
            'target': dst_host or dst_ip,
            'target_ip': dst_ip,
            'port': dst_port,
            'count': cnt,
        })

    host_list = sorted(hosts.values(), key=lambda x: x['row_count'], reverse=True)

    return {
        "hosts": host_list,
        "connections": conn_list,
        "stats": {
            "total_hosts": len(host_list),
            "total_datasets_scanned": len(all_datasets),
            "datasets_with_hosts": ds_with_hosts,
            "total_rows_scanned": total_rows,
            "hosts_with_ips": sum(1 for h in host_list if h['ips']),
            "hosts_with_users": sum(1 for h in host_list if h['users']),
        },
    }
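For reference, a single deduplicated host record returned by build_host_inventory() looks roughly like this (field values are illustrative, not taken from any real dataset):

# {
#     "id": "W10-ALICE.corp.local",
#     "hostname": "W10-ALICE",
#     "fqdn": "W10-ALICE.corp.local",
#     "client_id": "C.1234abcd",
#     "ips": ["10.0.1.15"],
#     "os": "Windows 10",
#     "users": ["alice"],
#     "datasets": ["netstat_all_hosts", "pslist_all_hosts"],
#     "row_count": 2411,
# }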
198  backend/app/services/host_profiler.py  Normal file
@@ -0,0 +1,198 @@
"""Host profiler - per-host deep threat analysis via Wile heavy models."""

from __future__ import annotations

import asyncio
import json
import logging

import httpx
from sqlalchemy import select

from app.config import settings
from app.db.engine import async_session
from app.db.models import Dataset, DatasetRow, HostProfile, TriageResult

logger = logging.getLogger(__name__)

HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
WILE_URL = f"{settings.wile_url}/api/generate"


async def _get_triage_summary(db, dataset_id: str) -> str:
    result = await db.execute(
        select(TriageResult)
        .where(TriageResult.dataset_id == dataset_id)
        .where(TriageResult.risk_score >= 3.0)
        .order_by(TriageResult.risk_score.desc())
        .limit(10)
    )
    triages = result.scalars().all()
    if not triages:
        return "No significant triage findings."
    lines = []
    for t in triages:
        lines.append(
            f"- Rows {t.row_start}-{t.row_end}: risk={t.risk_score:.1f} "
            f"verdict={t.verdict} findings={json.dumps(t.findings, default=str)[:300]}"
        )
    return "\n".join(lines)


async def _collect_host_data(db, hunt_id: str, hostname: str, fqdn: str | None = None) -> dict:
    result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
    datasets = result.scalars().all()

    host_data: dict[str, list[dict]] = {}
    triage_parts: list[str] = []

    for ds in datasets:
        artifact_type = getattr(ds, "artifact_type", None) or "Unknown"
        rows_result = await db.execute(
            select(DatasetRow).where(DatasetRow.dataset_id == ds.id).limit(500)
        )
        rows = rows_result.scalars().all()

        matching = []
        for r in rows:
            data = r.normalized_data or r.data
            row_host = (
                data.get("hostname", "") or data.get("Fqdn", "")
                or data.get("ClientId", "") or data.get("client_id", "")
            )
            if hostname.lower() in str(row_host).lower():
                matching.append(data)
            elif fqdn and fqdn.lower() in str(row_host).lower():
                matching.append(data)

        if matching:
            host_data[artifact_type] = matching[:50]
            triage_info = await _get_triage_summary(db, ds.id)
            triage_parts.append(f"\n### {artifact_type} ({len(matching)} rows)\n{triage_info}")

    return {
        "artifacts": host_data,
        "triage_summary": "\n".join(triage_parts) or "No triage data.",
        "artifact_count": sum(len(v) for v in host_data.values()),
    }


async def profile_host(
    hunt_id: str, hostname: str, fqdn: str | None = None, client_id: str | None = None,
) -> None:
    logger.info("Profiling host %s in hunt %s", hostname, hunt_id)

    async with async_session() as db:
        host_data = await _collect_host_data(db, hunt_id, hostname, fqdn)
        if host_data["artifact_count"] == 0:
            logger.info("No data found for host %s, skipping", hostname)
            return

        system_prompt = (
            "You are a senior threat hunting analyst performing deep host analysis.\n"
            "You receive consolidated forensic artifacts and prior triage results for a single host.\n\n"
            "Provide a comprehensive host threat profile as JSON:\n"
            "- risk_score: 0.0 (clean) to 10.0 (actively compromised)\n"
            "- risk_level: low/medium/high/critical\n"
            "- suspicious_findings: list of specific concerns\n"
            "- mitre_techniques: list of MITRE ATT&CK technique IDs\n"
            "- timeline_summary: brief timeline of suspicious activity\n"
            "- analysis: detailed narrative assessment\n\n"
            "Consider: cross-artifact correlation, attack patterns, LOLBins, anomalies.\n"
            "Respond with valid JSON only."
        )

        artifact_summary = {}
        for art_type, rows in host_data["artifacts"].items():
            artifact_summary[art_type] = [
                {k: str(v)[:150] for k, v in row.items() if v} for row in rows[:20]
            ]

        prompt = (
            f"Host: {hostname}\nFQDN: {fqdn or 'unknown'}\n\n"
            f"## Prior Triage Results\n{host_data['triage_summary']}\n\n"
            f"## Artifact Data ({host_data['artifact_count']} total rows)\n"
            f"{json.dumps(artifact_summary, indent=1, default=str)[:8000]}\n\n"
            "Provide your comprehensive host threat profile as JSON."
        )

        try:
            async with httpx.AsyncClient(timeout=300.0) as client:
                resp = await client.post(
                    WILE_URL,
                    json={
                        "model": HEAVY_MODEL,
                        "prompt": prompt,
                        "system": system_prompt,
                        "stream": False,
                        "options": {"temperature": 0.3, "num_predict": 4096},
                    },
                )
                resp.raise_for_status()
                llm_text = resp.json().get("response", "")

            from app.services.triage import _parse_llm_response
            parsed = _parse_llm_response(llm_text)

            profile = HostProfile(
                hunt_id=hunt_id,
                hostname=hostname,
                fqdn=fqdn,
                client_id=client_id,
                risk_score=float(parsed.get("risk_score", 0.0)),
                risk_level=parsed.get("risk_level", "low"),
                artifact_summary={a: len(r) for a, r in host_data["artifacts"].items()},
                timeline_summary=parsed.get("timeline_summary", ""),
                suspicious_findings=parsed.get("suspicious_findings", []),
                mitre_techniques=parsed.get("mitre_techniques", []),
                llm_analysis=parsed.get("analysis", llm_text[:5000]),
                model_used=HEAVY_MODEL,
                node_used="wile",
            )
            db.add(profile)
            await db.commit()
            logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level)

        except Exception as e:
            logger.error("Failed to profile host %s: %s", hostname, e)
            profile = HostProfile(
                hunt_id=hunt_id, hostname=hostname, fqdn=fqdn,
                risk_score=0.0, risk_level="unknown",
                llm_analysis=f"Error: {e}",
                model_used=HEAVY_MODEL, node_used="wile",
            )
            db.add(profile)
            await db.commit()


async def profile_all_hosts(hunt_id: str) -> None:
    logger.info("Starting host profiling for hunt %s", hunt_id)

    async with async_session() as db:
        result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
        datasets = result.scalars().all()

        hostnames: dict[str, str | None] = {}
        for ds in datasets:
            rows_result = await db.execute(
                select(DatasetRow).where(DatasetRow.dataset_id == ds.id).limit(2000)
            )
            for r in rows_result.scalars().all():
                data = r.normalized_data or r.data
                host = data.get("hostname") or data.get("Fqdn") or data.get("Hostname")
                if host and str(host).strip():
                    h = str(host).strip()
                    if h not in hostnames:
                        hostnames[h] = data.get("fqdn") or data.get("Fqdn")

    logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id)

    semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY)

    async def _bounded(hostname: str, fqdn: str | None):
        async with semaphore:
            await profile_host(hunt_id, hostname, fqdn)

    tasks = [_bounded(h, f) for h, f in hostnames.items()]
    await asyncio.gather(*tasks, return_exceptions=True)
    logger.info("Host profiling complete for hunt %s (%d hosts)", hunt_id, len(hostnames))
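A minimal way to run profiling outside the job queue, e.g. from a one-off script (illustrative; the hunt id is a placeholder):

import asyncio

from app.services.host_profiler import profile_all_hosts

if __name__ == "__main__":
    # Normally submitted through the job queue (see job_queue.py below);
    # this simply drives the same coroutine directly.
    asyncio.run(profile_all_hosts("hunt-1234"))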
210  backend/app/services/ioc_extractor.py  Normal file
@@ -0,0 +1,210 @@
"""IOC extraction service - extract indicators of compromise from dataset rows.

Identifies: IPv4/IPv6 addresses, domain names, MD5/SHA1/SHA256 hashes,
email addresses, URLs, and file paths that look suspicious.
"""

import re
import logging
from collections import defaultdict
from typing import Optional

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import Dataset, DatasetRow

logger = logging.getLogger(__name__)

# Patterns

_IPV4 = re.compile(
    r'\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b'
)
_IPV6 = re.compile(r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b')
_DOMAIN = re.compile(
    r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)'
    r'+(?:com|net|org|io|info|biz|co|us|uk|de|ru|cn|cc|tk|xyz|top|'
    r'online|site|club|win|work|download|stream|gdn|bid|review|racing|'
    r'loan|date|faith|accountant|cricket|science|trade|party|men)\b',
    re.IGNORECASE,
)
_MD5 = re.compile(r'\b[0-9a-fA-F]{32}\b')
_SHA1 = re.compile(r'\b[0-9a-fA-F]{40}\b')
_SHA256 = re.compile(r'\b[0-9a-fA-F]{64}\b')
_EMAIL = re.compile(r'\b[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}\b')
_URL = re.compile(r'https?://[^\s<>"\']+', re.IGNORECASE)

# Private / reserved IPs to skip
_PRIVATE_NETS = re.compile(
    r'^(10\.|172\.(1[6-9]|2\d|3[01])\.|192\.168\.|127\.|0\.|255\.)'
)

PATTERNS = {
    'ipv4': _IPV4,
    'ipv6': _IPV6,
    'domain': _DOMAIN,
    'md5': _MD5,
    'sha1': _SHA1,
    'sha256': _SHA256,
    'email': _EMAIL,
    'url': _URL,
}


def _is_private_ip(ip: str) -> bool:
    return bool(_PRIVATE_NETS.match(ip))


def extract_iocs_from_text(text: str, skip_private: bool = True) -> dict[str, set[str]]:
    """Extract all IOC types from a block of text."""
    result: dict[str, set[str]] = defaultdict(set)
    for ioc_type, pattern in PATTERNS.items():
        for match in pattern.findall(text):
            val = match.strip().lower() if ioc_type != 'url' else match.strip()
            # Filter private IPs
            if ioc_type == 'ipv4' and skip_private and _is_private_ip(val):
                continue
            result[ioc_type].add(val)
    return result


async def extract_iocs_from_dataset(
    dataset_id: str,
    db: AsyncSession,
    max_rows: int = 5000,
    skip_private: bool = True,
) -> dict[str, list[str]]:
    """Extract IOCs from all rows of a dataset.

    Returns {ioc_type: [sorted unique values]}.
    """
    # Load rows in batches
    all_iocs: dict[str, set[str]] = defaultdict(set)
    offset = 0
    batch_size = 500

    while offset < max_rows:
        result = await db.execute(
            select(DatasetRow.data)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(offset)
            .limit(batch_size)
        )
        rows = result.scalars().all()
        if not rows:
            break

        for data in rows:
            # Flatten all values to a single string for scanning
            text = ' '.join(str(v) for v in data.values()) if isinstance(data, dict) else str(data)
            batch_iocs = extract_iocs_from_text(text, skip_private)
            for ioc_type, values in batch_iocs.items():
                all_iocs[ioc_type].update(values)

        offset += batch_size

    # Convert sets to sorted lists
    return {k: sorted(v) for k, v in all_iocs.items() if v}


async def extract_host_groups(
    hunt_id: str,
    db: AsyncSession,
) -> list[dict]:
    """Group all data by hostname across datasets in a hunt.

    Returns a list of host group dicts with dataset count, total rows,
    artifact types, and time range.
    """
    # Get all datasets for this hunt
    result = await db.execute(
        select(Dataset).where(Dataset.hunt_id == hunt_id)
    )
    ds_list = result.scalars().all()
    if not ds_list:
        return []

    # Known host columns (check normalized data first, then raw)
    HOST_COLS = [
        'hostname', 'host', 'computer_name', 'computername', 'system',
        'machine', 'device_name', 'devicename', 'endpoint',
        'ClientId', 'Fqdn', 'client_id', 'fqdn',
    ]

    hosts: dict[str, dict] = {}

    for ds in ds_list:
        # Sample first few rows to find the host column
        sample_result = await db.execute(
            select(DatasetRow.data, DatasetRow.normalized_data)
            .where(DatasetRow.dataset_id == ds.id)
            .limit(5)
        )
        samples = sample_result.all()
        if not samples:
            continue

        # Find which host column exists
        host_col = None
        for row_data, norm_data in samples:
            check = norm_data if norm_data else row_data
            if not isinstance(check, dict):
                continue
            for col in HOST_COLS:
                if col in check and check[col]:
                    host_col = col
                    break
            if host_col:
                break

        if not host_col:
            continue

        # Count rows per host in this dataset
        all_rows_result = await db.execute(
            select(DatasetRow.data, DatasetRow.normalized_data)
            .where(DatasetRow.dataset_id == ds.id)
        )
        all_rows = all_rows_result.all()
        for row_data, norm_data in all_rows:
            check = norm_data if norm_data else row_data
            if not isinstance(check, dict):
                continue
            host_val = check.get(host_col, '')
            if not host_val or not isinstance(host_val, str):
                continue
            host_val = host_val.strip()
            if not host_val:
                continue

            if host_val not in hosts:
                hosts[host_val] = {
                    'hostname': host_val,
                    'dataset_ids': set(),
                    'total_rows': 0,
                    'artifact_types': set(),
                    'first_seen': None,
                    'last_seen': None,
                }
            hosts[host_val]['dataset_ids'].add(ds.id)
            hosts[host_val]['total_rows'] += 1
            if ds.artifact_type:
                hosts[host_val]['artifact_types'].add(ds.artifact_type)

    # Convert to output format
    result_list = []
    for h in sorted(hosts.values(), key=lambda x: x['total_rows'], reverse=True):
        result_list.append({
            'hostname': h['hostname'],
            'dataset_count': len(h['dataset_ids']),
            'total_rows': h['total_rows'],
            'artifact_types': sorted(h['artifact_types']),
            'first_seen': None,   # TODO: extract from timestamp columns
            'last_seen': None,
            'risk_score': None,   # TODO: link to host profiles
        })

    return result_list
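A quick sketch of the text-level extractor in isolation (the sample string and its indicators are made up for illustration):

from app.services.ioc_extractor import extract_iocs_from_text

sample = (
    "rundll32.exe contacted 203.0.113.45 and bad-domain.top, dropped payload "
    "with sha256 e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
)
iocs = extract_iocs_from_text(sample)
# -> {'ipv4': {'203.0.113.45'},
#     'domain': {'bad-domain.top'},
#     'sha256': {'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'}}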
316  backend/app/services/job_queue.py  Normal file
@@ -0,0 +1,316 @@
"""Async job queue for background AI tasks.

Manages triage, profiling, report generation, anomaly detection,
and data queries as trackable jobs with status, progress, and
cancellation support.
"""

from __future__ import annotations

import asyncio
import logging
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Coroutine, Optional

logger = logging.getLogger(__name__)


class JobStatus(str, Enum):
    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class JobType(str, Enum):
    TRIAGE = "triage"
    HOST_PROFILE = "host_profile"
    REPORT = "report"
    ANOMALY = "anomaly"
    QUERY = "query"


@dataclass
class Job:
    id: str
    job_type: JobType
    status: JobStatus = JobStatus.QUEUED
    progress: float = 0.0  # 0-100
    message: str = ""
    result: Any = None
    error: str | None = None
    created_at: float = field(default_factory=time.time)
    started_at: float | None = None
    completed_at: float | None = None
    params: dict = field(default_factory=dict)
    _cancel_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)

    @property
    def elapsed_ms(self) -> int:
        end = self.completed_at or time.time()
        start = self.started_at or self.created_at
        return int((end - start) * 1000)

    def to_dict(self) -> dict:
        return {
            "id": self.id,
            "job_type": self.job_type.value,
            "status": self.status.value,
            "progress": round(self.progress, 1),
            "message": self.message,
            "error": self.error,
            "created_at": self.created_at,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "elapsed_ms": self.elapsed_ms,
            "params": self.params,
        }

    @property
    def is_cancelled(self) -> bool:
        return self._cancel_event.is_set()

    def cancel(self):
        self._cancel_event.set()
        self.status = JobStatus.CANCELLED
        self.completed_at = time.time()
        self.message = "Cancelled by user"


class JobQueue:
    """In-memory async job queue with concurrency control.

    Jobs are tracked by ID and can be listed, polled, or cancelled.
    A configurable number of workers process jobs from the queue.
    """

    def __init__(self, max_workers: int = 3):
        self._jobs: dict[str, Job] = {}
        self._queue: asyncio.Queue[str] = asyncio.Queue()
        self._max_workers = max_workers
        self._workers: list[asyncio.Task] = []
        self._handlers: dict[JobType, Callable] = {}
        self._started = False

    def register_handler(
        self,
        job_type: JobType,
        handler: Callable[[Job], Coroutine],
    ):
        """Register an async handler for a job type.

        Handler signature: async def handler(job: Job) -> Any
        The handler can update job.progress and job.message during execution.
        It should check job.is_cancelled periodically and return early.
        """
        self._handlers[job_type] = handler
        logger.info(f"Registered handler for {job_type.value}")

    async def start(self):
        """Start worker tasks."""
        if self._started:
            return
        self._started = True
        for i in range(self._max_workers):
            task = asyncio.create_task(self._worker(i))
            self._workers.append(task)
        logger.info(f"Job queue started with {self._max_workers} workers")

    async def stop(self):
        """Stop all workers."""
        self._started = False
        for w in self._workers:
            w.cancel()
        await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        logger.info("Job queue stopped")

    def submit(self, job_type: JobType, **params) -> Job:
        """Submit a new job. Returns the Job object immediately."""
        job = Job(
            id=str(uuid.uuid4()),
            job_type=job_type,
            params=params,
        )
        self._jobs[job.id] = job
        self._queue.put_nowait(job.id)
        logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
        return job

    def get_job(self, job_id: str) -> Job | None:
        return self._jobs.get(job_id)

    def cancel_job(self, job_id: str) -> bool:
        job = self._jobs.get(job_id)
        if not job:
            return False
        if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
            return False
        job.cancel()
        return True

    def list_jobs(
        self,
        status: JobStatus | None = None,
        job_type: JobType | None = None,
        limit: int = 50,
    ) -> list[dict]:
        """List jobs, newest first."""
        jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
        if status:
            jobs = [j for j in jobs if j.status == status]
        if job_type:
            jobs = [j for j in jobs if j.job_type == job_type]
        return [j.to_dict() for j in jobs[:limit]]

    def get_stats(self) -> dict:
        """Get queue statistics."""
        by_status = {}
        for j in self._jobs.values():
            by_status[j.status.value] = by_status.get(j.status.value, 0) + 1
        return {
            "total": len(self._jobs),
            "queued": self._queue.qsize(),
            "by_status": by_status,
            "workers": self._max_workers,
            "active_workers": sum(
                1 for j in self._jobs.values() if j.status == JobStatus.RUNNING
            ),
        }

    def cleanup(self, max_age_seconds: float = 3600):
        """Remove old completed/failed/cancelled jobs."""
        now = time.time()
        to_remove = [
            jid for jid, j in self._jobs.items()
            if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
            and (now - j.created_at) > max_age_seconds
        ]
        for jid in to_remove:
            del self._jobs[jid]
        if to_remove:
            logger.info(f"Cleaned up {len(to_remove)} old jobs")

    async def _worker(self, worker_id: int):
        """Worker loop: pull jobs from the queue and execute handlers."""
        logger.info(f"Worker {worker_id} started")
        while self._started:
            try:
                job_id = await asyncio.wait_for(self._queue.get(), timeout=5.0)
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break

            job = self._jobs.get(job_id)
            if not job or job.is_cancelled:
                continue

            handler = self._handlers.get(job.job_type)
            if not handler:
                job.status = JobStatus.FAILED
                job.error = f"No handler for {job.job_type.value}"
                job.completed_at = time.time()
                logger.error(f"No handler for job type {job.job_type.value}")
                continue

            job.status = JobStatus.RUNNING
            job.started_at = time.time()
            job.message = "Running..."
            logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")

            try:
                result = await handler(job)
                if not job.is_cancelled:
                    job.status = JobStatus.COMPLETED
                    job.progress = 100.0
                    job.result = result
                    job.message = "Completed"
                    job.completed_at = time.time()
                    logger.info(
                        f"Worker {worker_id}: completed {job.id} "
                        f"in {job.elapsed_ms}ms"
                    )
            except Exception as e:
                if not job.is_cancelled:
                    job.status = JobStatus.FAILED
                    job.error = str(e)
                    job.message = f"Failed: {e}"
                    job.completed_at = time.time()
                    logger.error(
                        f"Worker {worker_id}: failed {job.id}: {e}",
                        exc_info=True,
                    )


# Singleton + job handlers

job_queue = JobQueue(max_workers=3)


async def _handle_triage(job: Job):
    """Triage handler."""
    from app.services.triage import triage_dataset
    dataset_id = job.params.get("dataset_id")
    job.message = f"Triaging dataset {dataset_id}"
    results = await triage_dataset(dataset_id)
    return {"count": len(results) if results else 0}


async def _handle_host_profile(job: Job):
    """Host profiling handler."""
    from app.services.host_profiler import profile_all_hosts, profile_host
    hunt_id = job.params.get("hunt_id")
    hostname = job.params.get("hostname")
    if hostname:
        job.message = f"Profiling host {hostname}"
        await profile_host(hunt_id, hostname)
        return {"hostname": hostname}
    else:
        job.message = f"Profiling all hosts in hunt {hunt_id}"
        await profile_all_hosts(hunt_id)
        return {"hunt_id": hunt_id}


async def _handle_report(job: Job):
    """Report generation handler."""
    from app.services.report_generator import generate_report
    hunt_id = job.params.get("hunt_id")
    job.message = f"Generating report for hunt {hunt_id}"
    report_id = await generate_report(hunt_id)
    return {"report_id": report_id}


async def _handle_anomaly(job: Job):
    """Anomaly detection handler."""
    from app.services.anomaly_detector import detect_anomalies
    dataset_id = job.params.get("dataset_id")
    k = job.params.get("k", 3)
    threshold = job.params.get("threshold", 0.35)
    job.message = f"Detecting anomalies in dataset {dataset_id}"
    results = await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
    return {"count": len(results) if results else 0}


async def _handle_query(job: Job):
    """Data query handler (non-streaming)."""
    from app.services.data_query import query_dataset
    dataset_id = job.params.get("dataset_id")
    question = job.params.get("question", "")
    mode = job.params.get("mode", "quick")
    job.message = f"Querying dataset {dataset_id}"
    answer = await query_dataset(dataset_id, question, mode)
    return {"answer": answer}


def register_all_handlers():
    """Register all job handlers."""
    job_queue.register_handler(JobType.TRIAGE, _handle_triage)
    job_queue.register_handler(JobType.HOST_PROFILE, _handle_host_profile)
    job_queue.register_handler(JobType.REPORT, _handle_report)
    job_queue.register_handler(JobType.ANOMALY, _handle_anomaly)
    job_queue.register_handler(JobType.QUERY, _handle_query)
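A sketch of the intended lifecycle from application startup code and an API handler (illustrative; where the startup hook lives is an assumption, not part of this commit):

from app.services.job_queue import JobType, job_queue, register_all_handlers


async def on_startup() -> None:
    register_all_handlers()
    await job_queue.start()


async def enqueue_triage(dataset_id: str) -> dict:
    # submit() returns immediately; clients then poll the job id for progress.
    job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset_id)
    return job.to_dict()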
193  backend/app/services/load_balancer.py  Normal file
@@ -0,0 +1,193 @@
"""Smart load balancer for Wile & Roadrunner LLM nodes.

Tracks active jobs per node, health status, and routes new work
to the least-busy healthy node. Periodically pings both nodes
to maintain an up-to-date health map.
"""

from __future__ import annotations

import asyncio
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

from app.config import settings

logger = logging.getLogger(__name__)


class NodeId(str, Enum):
    WILE = "wile"
    ROADRUNNER = "roadrunner"


class WorkloadTier(str, Enum):
    """What kind of workload is this?"""
    HEAVY = "heavy"        # 70B models, deep analysis, reports
    FAST = "fast"          # 7-14B models, triage, quick queries
    EMBEDDING = "embed"    # bge-m3 embeddings
    ANY = "any"


@dataclass
class NodeStatus:
    node_id: NodeId
    url: str
    healthy: bool = True
    last_check: float = 0.0
    active_jobs: int = 0
    total_completed: int = 0
    total_errors: int = 0
    avg_latency_ms: float = 0.0
    _latencies: list[float] = field(default_factory=list)

    def record_completion(self, latency_ms: float):
        self.active_jobs = max(0, self.active_jobs - 1)
        self.total_completed += 1
        self._latencies.append(latency_ms)
        # Rolling average of last 50
        if len(self._latencies) > 50:
            self._latencies = self._latencies[-50:]
        self.avg_latency_ms = sum(self._latencies) / len(self._latencies)

    def record_error(self):
        self.active_jobs = max(0, self.active_jobs - 1)
        self.total_errors += 1

    def record_start(self):
        self.active_jobs += 1


class LoadBalancer:
    """Routes LLM work to the least-busy healthy node.

    Node capabilities:
    - Wile: Heavy models (70B), code models (32B)
    - Roadrunner: Fast models (7-14B), embeddings (bge-m3), vision
    """

    # Which nodes can handle which tiers
    TIER_NODES = {
        WorkloadTier.HEAVY: [NodeId.WILE],
        WorkloadTier.FAST: [NodeId.ROADRUNNER, NodeId.WILE],
        WorkloadTier.EMBEDDING: [NodeId.ROADRUNNER],
        WorkloadTier.ANY: [NodeId.ROADRUNNER, NodeId.WILE],
    }

    def __init__(self):
        self._nodes: dict[NodeId, NodeStatus] = {
            NodeId.WILE: NodeStatus(
                node_id=NodeId.WILE,
                url=f"http://{settings.WILE_HOST}:{settings.WILE_OLLAMA_PORT}",
            ),
            NodeId.ROADRUNNER: NodeStatus(
                node_id=NodeId.ROADRUNNER,
                url=f"http://{settings.ROADRUNNER_HOST}:{settings.ROADRUNNER_OLLAMA_PORT}",
            ),
        }
        self._lock = asyncio.Lock()
        self._health_task: Optional[asyncio.Task] = None

    async def start_health_loop(self, interval: float = 30.0):
        """Start background health-check loop."""
        if self._health_task and not self._health_task.done():
            return
        self._health_task = asyncio.create_task(self._health_loop(interval))
        logger.info("Load balancer health loop started (%.0fs interval)", interval)

    async def stop_health_loop(self):
        if self._health_task:
            self._health_task.cancel()
            try:
                await self._health_task
            except asyncio.CancelledError:
                pass
            self._health_task = None

    async def _health_loop(self, interval: float):
        while True:
            try:
                await self.check_health()
            except Exception as e:
                logger.warning(f"Health check error: {e}")
            await asyncio.sleep(interval)

    async def check_health(self):
        """Ping both nodes and update status."""
        import httpx
        async with httpx.AsyncClient(timeout=5) as client:
            for nid, status in self._nodes.items():
                try:
                    resp = await client.get(f"{status.url}/api/tags")
                    status.healthy = resp.status_code == 200
                except Exception:
                    status.healthy = False
                status.last_check = time.time()
                logger.debug(
                    f"Health: {nid.value} = {'OK' if status.healthy else 'DOWN'} "
                    f"(active={status.active_jobs})"
                )

    def select_node(self, tier: WorkloadTier = WorkloadTier.ANY) -> NodeId:
        """Select the best node for a workload tier.

        Strategy: among healthy nodes that support the tier,
        pick the one with the fewest active jobs.
        Falls back to any candidate if none are healthy.
        """
        candidates = self.TIER_NODES.get(tier, [NodeId.ROADRUNNER, NodeId.WILE])

        # Filter to healthy candidates
        healthy = [
            nid for nid in candidates
            if self._nodes[nid].healthy
        ]

        if not healthy:
            logger.warning(f"No healthy nodes for tier {tier.value}, using first candidate")
            healthy = candidates

        # Pick the least busy
        best = min(healthy, key=lambda nid: self._nodes[nid].active_jobs)
        return best

    def acquire(self, tier: WorkloadTier = WorkloadTier.ANY) -> NodeId:
        """Select a node and mark a job started."""
        node = self.select_node(tier)
        self._nodes[node].record_start()
        logger.info(
            f"LB: dispatched {tier.value} -> {node.value} "
            f"(active={self._nodes[node].active_jobs})"
        )
        return node

    def release(self, node: NodeId, latency_ms: float = 0, error: bool = False):
        """Mark a job completed on a node."""
        status = self._nodes.get(node)
        if not status:
            return
        if error:
            status.record_error()
        else:
            status.record_completion(latency_ms)

    def get_status(self) -> dict:
        """Get current load balancer status."""
        return {
            nid.value: {
                "healthy": s.healthy,
                "active_jobs": s.active_jobs,
                "total_completed": s.total_completed,
                "total_errors": s.total_errors,
                "avg_latency_ms": round(s.avg_latency_ms, 1),
                "last_check": s.last_check,
            }
            for nid, s in self._nodes.items()
        }


# Singleton
lb = LoadBalancer()
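How a caller would be expected to wrap an LLM call with the balancer (sketch; some_provider_generate is a hypothetical stand-in for whatever provider call is used, not a helper defined in this commit):

import time

from app.services.load_balancer import WorkloadTier, lb


async def balanced_generate(prompt: str) -> str:
    node = lb.acquire(WorkloadTier.FAST)
    start = time.monotonic()
    try:
        text = await some_provider_generate(node, prompt)  # stand-in for the real provider call
    except Exception:
        lb.release(node, error=True)
        raise
    lb.release(node, latency_ms=(time.monotonic() - start) * 1000)
    return text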
198
backend/app/services/report_generator.py
Normal file
198
backend/app/services/report_generator.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Report generator - debate-powered hunt report generation using Wile + Roadrunner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config import settings
|
||||
from app.db.engine import async_session
|
||||
from app.db.models import (
|
||||
Dataset, HostProfile, HuntReport, TriageResult,
|
||||
)
|
||||
from app.services.triage import _parse_llm_response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WILE_URL = f"{settings.wile_url}/api/generate"
|
||||
ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate"
|
||||
HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
|
||||
FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M"
|
||||
|
||||
|
||||
async def _llm_call(url: str, model: str, system: str, prompt: str, timeout: float = 300.0) -> str:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
resp = await client.post(
|
||||
url,
|
||||
json={
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"system": system,
|
||||
"stream": False,
|
||||
"options": {"temperature": 0.3, "num_predict": 8192},
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json().get("response", "")
|
||||
|
||||
|
||||
async def _gather_evidence(db, hunt_id: str) -> dict:
|
||||
ds_result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
|
||||
datasets = ds_result.scalars().all()
|
||||
|
||||
dataset_summary = []
|
||||
all_triage = []
|
||||
for ds in datasets:
|
||||
ds_info = {
|
||||
"name": ds.name,
|
||||
"artifact_type": getattr(ds, "artifact_type", "Unknown"),
|
||||
"row_count": ds.row_count or 0,
|
||||
}
|
||||
dataset_summary.append(ds_info)
|
||||
|
||||
triage_result = await db.execute(
|
||||
select(TriageResult)
|
||||
.where(TriageResult.dataset_id == ds.id)
|
||||
.where(TriageResult.risk_score >= 3.0)
|
||||
.order_by(TriageResult.risk_score.desc())
|
||||
.limit(15)
|
||||
)
|
||||
for t in triage_result.scalars().all():
|
||||
all_triage.append({
|
||||
"dataset": ds.name,
|
||||
"artifact_type": ds_info["artifact_type"],
|
||||
"rows": f"{t.row_start}-{t.row_end}",
|
||||
"risk_score": t.risk_score,
|
||||
"verdict": t.verdict,
|
||||
"findings": t.findings[:5] if t.findings else [],
|
||||
"indicators": t.suspicious_indicators[:5] if t.suspicious_indicators else [],
|
||||
"mitre": t.mitre_techniques or [],
|
||||
})
|
||||
|
||||
profile_result = await db.execute(
|
||||
select(HostProfile)
|
||||
.where(HostProfile.hunt_id == hunt_id)
|
||||
.order_by(HostProfile.risk_score.desc())
|
||||
)
|
||||
profiles = profile_result.scalars().all()
|
||||
host_summaries = []
|
||||
for p in profiles:
|
||||
host_summaries.append({
|
||||
"hostname": p.hostname,
|
||||
"risk_score": p.risk_score,
|
||||
"risk_level": p.risk_level,
|
||||
"findings": p.suspicious_findings[:5] if p.suspicious_findings else [],
|
||||
"mitre": p.mitre_techniques or [],
|
||||
"timeline": (p.timeline_summary or "")[:300],
|
||||
})
|
||||
|
||||
return {
|
||||
"datasets": dataset_summary,
|
||||
"triage_findings": all_triage[:30],
|
||||
"host_profiles": host_summaries,
|
||||
"total_datasets": len(datasets),
|
||||
"total_rows": sum(d["row_count"] for d in dataset_summary),
|
||||
"high_risk_hosts": len([h for h in host_summaries if h["risk_score"] >= 7.0]),
|
||||
}
|


async def generate_report(hunt_id: str) -> None:
    """Three-phase report pipeline: Wile drafts, Roadrunner critiques, Wile synthesizes."""
    logger.info("Generating report for hunt %s", hunt_id)
    start = time.monotonic()

    async with async_session() as db:
        report = HuntReport(
            hunt_id=hunt_id,
            status="generating",
            models_used=[HEAVY_MODEL, FAST_MODEL],
        )
        db.add(report)
        await db.commit()
        await db.refresh(report)
        report_id = report.id

        try:
            evidence = await _gather_evidence(db, hunt_id)
            evidence_text = json.dumps(evidence, indent=1, default=str)[:12000]

            # Phase 1: Wile initial analysis
            logger.info("Report phase 1: Wile initial analysis")
            phase1 = await _llm_call(
                WILE_URL, HEAVY_MODEL,
                system=(
                    "You are a senior threat intelligence analyst writing a hunt report.\n"
                    "Analyze all evidence and produce a structured threat assessment.\n"
                    "Include: executive summary, detailed findings per host, MITRE mapping,\n"
                    "IOC table, risk rankings, and actionable recommendations.\n"
                    "Use markdown formatting. Be thorough and specific."
                ),
                prompt=f"Hunt evidence:\n{evidence_text}\n\nProduce your initial threat assessment.",
            )

            # Phase 2: Roadrunner critical review
            logger.info("Report phase 2: Roadrunner critical review")
            phase2 = await _llm_call(
                ROADRUNNER_URL, FAST_MODEL,
                system=(
                    "You are a critical reviewer of threat hunt reports.\n"
                    "Review the initial assessment and identify:\n"
                    "- Missing correlations or overlooked indicators\n"
                    "- False positive risks or overblown findings\n"
                    "- Additional MITRE techniques that should be mapped\n"
                    "- Gaps in recommendations\n"
                    "Be specific and constructive. Respond in markdown."
                ),
                prompt=f"Evidence:\n{evidence_text[:4000]}\n\nInitial Assessment:\n{phase1[:6000]}\n\nProvide your critical review.",
                timeout=120.0,
            )

            # Phase 3: Wile final synthesis
            logger.info("Report phase 3: Wile final synthesis")
            synthesis_prompt = (
                f"Original evidence:\n{evidence_text[:6000]}\n\n"
                f"Initial assessment:\n{phase1[:5000]}\n\n"
                f"Critical review:\n{phase2[:3000]}\n\n"
                "Produce the FINAL hunt report incorporating the review feedback.\n"
                "Return JSON with these keys:\n"
                "- executive_summary: 2-3 paragraph executive summary\n"
                "- findings: list of {title, severity, description, evidence, mitre_ids}\n"
                "- recommendations: list of {priority, action, rationale}\n"
                "- mitre_mapping: dict of technique_id -> {name, description, evidence}\n"
                "- ioc_table: list of {type, value, context, confidence}\n"
                "- host_risk_summary: list of {hostname, risk_score, risk_level, key_findings}\n"
                "Respond with valid JSON only."
            )
            phase3_text = await _llm_call(
                WILE_URL, HEAVY_MODEL,
                system="You are producing the final, definitive threat hunt report. Incorporate all feedback. Respond with valid JSON only.",
                prompt=synthesis_prompt,
            )

            parsed = _parse_llm_response(phase3_text)
            elapsed_ms = int((time.monotonic() - start) * 1000)

            full_report = (
                f"# Threat Hunt Report\n\n{phase1}\n\n---\n"
                f"## Review Notes\n{phase2}\n\n---\n"
                f"## Final Synthesis\n{phase3_text}"
            )

            report.status = "complete"
            report.exec_summary = parsed.get("executive_summary", phase1[:2000])
            report.full_report = full_report
            report.findings = parsed.get("findings", [])
            report.recommendations = parsed.get("recommendations", [])
            report.mitre_mapping = parsed.get("mitre_mapping", {})
            report.ioc_table = parsed.get("ioc_table", [])
            report.host_risk_summary = parsed.get("host_risk_summary", [])
            report.generation_time_ms = elapsed_ms
            await db.commit()

            logger.info("Report %s complete in %dms", report_id, elapsed_ms)

        except Exception as e:
            logger.error("Report generation failed for hunt %s: %s", hunt_id, e)
            report.status = "error"
            report.exec_summary = f"Report generation failed: {e}"
            report.generation_time_ms = int((time.monotonic() - start) * 1000)
            await db.commit()
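`_llm_call` is defined earlier in this service and does not appear in this hunk. As orientation only, here is a sketch of what such a helper might look like, assuming the same non-streaming `/api/generate` payload shape that `triage_dataset` (in the next file) posts to Roadrunner; the signature and default timeout are assumptions, not the actual implementation:

import httpx

async def _llm_call(url: str, model: str, *, system: str, prompt: str,
                    timeout: float = 300.0) -> str:
    # Assumed payload shape: mirrors the non-streaming request used by triage_dataset.
    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(url, json={
            "model": model,
            "prompt": prompt,
            "system": system,
            "stream": False,
        })
        resp.raise_for_status()
        return resp.json().get("response", "")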

backend/app/services/triage.py (new file, 170 lines)
@@ -0,0 +1,170 @@
"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""

from __future__ import annotations

import json
import logging
import re

import httpx
from sqlalchemy import func, select

from app.config import settings
from app.db.engine import async_session
from app.db.models import Dataset, DatasetRow, TriageResult

logger = logging.getLogger(__name__)

DEFAULT_FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M"
ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate"

# Per-artifact hunting guidance injected into the triage system prompt.
ARTIFACT_FOCUS = {
    "Windows.System.Pslist": "Look for: suspicious parent-child, LOLBins, unsigned, injection indicators, abnormal paths.",
    "Windows.Network.Netstat": "Look for: C2 beaconing, unusual ports, connections to rare IPs, non-browser high-port listeners.",
    "Windows.System.Services": "Look for: services in temp dirs, misspelled names, unsigned ServiceDll, unusual start modes.",
    "Windows.Forensics.Prefetch": "Look for: recon tools, lateral movement tools, rarely-run executables with high run counts.",
    "Windows.EventLogs.EvtxHunter": "Look for: logon type 10/3 anomalies, service installs, PowerShell script blocks, clearing.",
    "Windows.Sys.Autoruns": "Look for: recently added entries, entries in temp/user dirs, encoded commands, suspicious DLLs.",
    "Windows.Registry.Finder": "Look for: run keys, image file execution options, hidden services, encoded payloads.",
    "Windows.Search.FileFinder": "Look for: files in unusual locations, recently modified system files, known tool names.",
}


def _parse_llm_response(text: str) -> dict:
    """Best-effort JSON extraction: fenced block first, then outermost braces, then raw text."""
    text = text.strip()
    fence = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL)
    if fence:
        text = fence.group(1).strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        brace = text.find("{")
        bracket = text.rfind("}")
        if brace != -1 and bracket != -1 and bracket > brace:
            try:
                return json.loads(text[brace : bracket + 1])
            except json.JSONDecodeError:
                pass
    return {"raw_response": text[:3000]}
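A quick illustration of the parser's three fallbacks, using hypothetical model outputs:

fenced = '```json\n{"risk_score": 7.5, "verdict": "suspicious"}\n```'
chatty = 'Sure! Here is the result: {"verdict": "clean"} Hope that helps.'
noise = "I could not analyze this batch."

assert _parse_llm_response(fenced)["verdict"] == "suspicious"  # fenced JSON extracted
assert _parse_llm_response(chatty)["verdict"] == "clean"       # outermost braces salvaged
assert "raw_response" in _parse_llm_response(noise)            # unparseable -> raw text kept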


async def triage_dataset(dataset_id: str) -> None:
    """Batch a dataset's rows through the fast Roadrunner model and persist verdicts."""
    logger.info("Starting triage for dataset %s", dataset_id)

    async with async_session() as db:
        ds_result = await db.execute(
            select(Dataset).where(Dataset.id == dataset_id)
        )
        dataset = ds_result.scalar_one_or_none()
        if not dataset:
            logger.error("Dataset %s not found", dataset_id)
            return

        artifact_type = getattr(dataset, "artifact_type", None) or "Unknown"
        focus = ARTIFACT_FOCUS.get(artifact_type, "Analyze for any suspicious indicators.")

        count_result = await db.execute(
            select(func.count()).where(DatasetRow.dataset_id == dataset_id)
        )
        total_rows = count_result.scalar() or 0

        batch_size = settings.TRIAGE_BATCH_SIZE
        suspicious_count = 0
        offset = 0

        while offset < total_rows:
            if suspicious_count >= settings.TRIAGE_MAX_SUSPICIOUS_ROWS:
                logger.info("Reached suspicious row cap for dataset %s", dataset_id)
                break

            rows_result = await db.execute(
                select(DatasetRow)
                .where(DatasetRow.dataset_id == dataset_id)
                .order_by(DatasetRow.row_number)
                .offset(offset)
                .limit(batch_size)
            )
            rows = rows_result.scalars().all()
            if not rows:
                break

            # Compact each row: drop empty fields, cap values at 200 chars.
            batch_data = []
            for r in rows:
                data = r.normalized_data or r.data
                compact = {k: str(v)[:200] for k, v in data.items() if v}
                batch_data.append(compact)

            system_prompt = f"""You are a cybersecurity triage analyst. Analyze this batch of {artifact_type} forensic data.
{focus}

Return JSON with:
- risk_score: 0.0 (benign) to 10.0 (critical threat)
- verdict: "clean", "suspicious", "malicious", or "inconclusive"
- findings: list of key observations
- suspicious_indicators: list of specific IOCs or anomalies
- mitre_techniques: list of MITRE ATT&CK IDs if applicable

Be precise. Only flag genuinely suspicious items. Respond with valid JSON only."""

            prompt = f"Rows {offset+1}-{offset+len(rows)} of {total_rows}:\n{json.dumps(batch_data, default=str)[:6000]}"

            try:
                async with httpx.AsyncClient(timeout=120.0) as client:
                    resp = await client.post(
                        ROADRUNNER_URL,
                        json={
                            "model": DEFAULT_FAST_MODEL,
                            "prompt": prompt,
                            "system": system_prompt,
                            "stream": False,
                            "options": {"temperature": 0.2, "num_predict": 2048},
                        },
                    )
                    resp.raise_for_status()
                    result = resp.json()
                    llm_text = result.get("response", "")

                parsed = _parse_llm_response(llm_text)
                risk = float(parsed.get("risk_score", 0.0))

                triage = TriageResult(
                    dataset_id=dataset_id,
                    row_start=offset,
                    row_end=offset + len(rows) - 1,
                    risk_score=risk,
                    verdict=parsed.get("verdict", "inconclusive"),
                    findings=parsed.get("findings", []),
                    suspicious_indicators=parsed.get("suspicious_indicators", []),
                    mitre_techniques=parsed.get("mitre_techniques", []),
                    model_used=DEFAULT_FAST_MODEL,
                    node_used="roadrunner",
                )
                db.add(triage)
                await db.commit()

                if risk >= settings.TRIAGE_ESCALATION_THRESHOLD:
                    suspicious_count += len(rows)

                logger.debug(
                    "Triage batch %d-%d: risk=%.1f verdict=%s",
                    offset, offset + len(rows) - 1, risk, triage.verdict,
                )

            except Exception as e:
                logger.error("Triage batch %d failed: %s", offset, e)
                triage = TriageResult(
                    dataset_id=dataset_id,
                    row_start=offset,
                    row_end=offset + len(rows) - 1,
                    risk_score=0.0,
                    verdict="error",
                    findings=[f"Error: {e}"],
                    model_used=DEFAULT_FAST_MODEL,
                    node_used="roadrunner",
                )
                db.add(triage)
                await db.commit()

            offset += batch_size

    logger.info("Triage complete for dataset %s", dataset_id)
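Since the service exposes a single coroutine, a caller only needs an event loop. A minimal driver sketch (the dataset id is hypothetical; in the app this is scheduled through the job queue rather than run directly):

import asyncio

async def main() -> None:
    # Hypothetical dataset id; real ids come from the datasets table.
    await triage_dataset("0123456789abcdef0123456789abcdef")

if __name__ == "__main__":
    asyncio.run(main())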
@@ -13,5 +13,5 @@ if __name__ == "__main__":
         "app.main:app",
         host="0.0.0.0",
         port=8000,
-        reload=True,
+        reload=False,
     )
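For context, the surrounding entrypoint presumably reads as follows after this hunk; the `import uvicorn` line is an assumption, since it falls outside the changed range:

import uvicorn  # assumed import; not shown in the hunk

if __name__ == "__main__":
    uvicorn.run(
        "app.main:app",
        host="0.0.0.0",
        port=8000,
        reload=False,  # hot-reload off; True is only appropriate for local development
    )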