diff --git a/backend/alembic/versions/a1b2c3d4e5f6_add_processing_status_ai_tables.py b/backend/alembic/versions/a1b2c3d4e5f6_add_processing_status_ai_tables.py
new file mode 100644
index 0000000..700524d
--- /dev/null
+++ b/backend/alembic/versions/a1b2c3d4e5f6_add_processing_status_ai_tables.py
@@ -0,0 +1,112 @@
+"""add processing_status and AI analysis tables
+
+Revision ID: a1b2c3d4e5f6
+Revises: 98ab619418bc
+Create Date: 2026-02-19 18:00:00.000000
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+revision: str = "a1b2c3d4e5f6"
+down_revision: Union[str, Sequence[str], None] = "98ab619418bc"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # Add columns to datasets table
+    with op.batch_alter_table("datasets") as batch_op:
+        batch_op.add_column(sa.Column("processing_status", sa.String(20), nullable=False, server_default="ready"))
+        batch_op.add_column(sa.Column("artifact_type", sa.String(128), nullable=True))
+        batch_op.add_column(sa.Column("error_message", sa.Text(), nullable=True))
+        batch_op.add_column(sa.Column("file_path", sa.String(512), nullable=True))
+        batch_op.create_index("ix_datasets_status", ["processing_status"])
+
+ # Create triage_results table
+ op.create_table(
+ "triage_results",
+ sa.Column("id", sa.String(32), primary_key=True),
+ sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True),
+ sa.Column("row_start", sa.Integer(), nullable=False),
+ sa.Column("row_end", sa.Integer(), nullable=False),
+ sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"),
+ sa.Column("verdict", sa.String(20), nullable=False, server_default="pending"),
+ sa.Column("findings", sa.JSON(), nullable=True),
+ sa.Column("suspicious_indicators", sa.JSON(), nullable=True),
+ sa.Column("mitre_techniques", sa.JSON(), nullable=True),
+ sa.Column("model_used", sa.String(128), nullable=True),
+ sa.Column("node_used", sa.String(64), nullable=True),
+ sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
+ )
+
+ # Create host_profiles table
+ op.create_table(
+ "host_profiles",
+ sa.Column("id", sa.String(32), primary_key=True),
+ sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True),
+ sa.Column("hostname", sa.String(256), nullable=False),
+ sa.Column("fqdn", sa.String(512), nullable=True),
+ sa.Column("client_id", sa.String(64), nullable=True),
+ sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"),
+ sa.Column("risk_level", sa.String(20), nullable=False, server_default="unknown"),
+ sa.Column("artifact_summary", sa.JSON(), nullable=True),
+ sa.Column("timeline_summary", sa.Text(), nullable=True),
+ sa.Column("suspicious_findings", sa.JSON(), nullable=True),
+ sa.Column("mitre_techniques", sa.JSON(), nullable=True),
+ sa.Column("llm_analysis", sa.Text(), nullable=True),
+ sa.Column("model_used", sa.String(128), nullable=True),
+ sa.Column("node_used", sa.String(64), nullable=True),
+ sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
+ sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
+ )
+
+ # Create hunt_reports table
+ op.create_table(
+ "hunt_reports",
+ sa.Column("id", sa.String(32), primary_key=True),
+ sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True),
+ sa.Column("status", sa.String(20), nullable=False, server_default="pending"),
+ sa.Column("exec_summary", sa.Text(), nullable=True),
+ sa.Column("full_report", sa.Text(), nullable=True),
+ sa.Column("findings", sa.JSON(), nullable=True),
+ sa.Column("recommendations", sa.JSON(), nullable=True),
+ sa.Column("mitre_mapping", sa.JSON(), nullable=True),
+ sa.Column("ioc_table", sa.JSON(), nullable=True),
+ sa.Column("host_risk_summary", sa.JSON(), nullable=True),
+ sa.Column("models_used", sa.JSON(), nullable=True),
+ sa.Column("generation_time_ms", sa.Integer(), nullable=True),
+ sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
+ sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
+ )
+
+ # Create anomaly_results table
+ op.create_table(
+ "anomaly_results",
+ sa.Column("id", sa.String(32), primary_key=True),
+ sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True),
+ sa.Column("row_id", sa.String(32), sa.ForeignKey("dataset_rows.id", ondelete="CASCADE"), nullable=True),
+ sa.Column("anomaly_score", sa.Float(), nullable=False, server_default="0.0"),
+ sa.Column("distance_from_centroid", sa.Float(), nullable=True),
+ sa.Column("cluster_id", sa.Integer(), nullable=True),
+ sa.Column("is_outlier", sa.Boolean(), nullable=False, server_default="0"),
+ sa.Column("explanation", sa.Text(), nullable=True),
+ sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
+ )
+
+
+def downgrade() -> None:
+    op.drop_table("anomaly_results")
+    op.drop_table("hunt_reports")
+    op.drop_table("host_profiles")
+    op.drop_table("triage_results")
+
+    with op.batch_alter_table("datasets") as batch_op:
+        batch_op.drop_index("ix_datasets_status")
+        batch_op.drop_column("file_path")
+        batch_op.drop_column("error_message")
+        batch_op.drop_column("artifact_type")
+        batch_op.drop_column("processing_status")
diff --git a/backend/app/api/routes/analysis.py b/backend/app/api/routes/analysis.py
new file mode 100644
index 0000000..ff6f7fc
--- /dev/null
+++ b/backend/app/api/routes/analysis.py
@@ -0,0 +1,402 @@
+"""Analysis API routes - triage, host profiles, reports, IOC extraction,
+host grouping, anomaly detection, data query (SSE), and job management."""
+
+from __future__ import annotations
+
+import logging
+from typing import Optional
+
+from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db import get_db
+from app.db.models import HostProfile, HuntReport, TriageResult
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/api/analysis", tags=["analysis"])
+
+
+# --- Response models ---
+
+class TriageResultResponse(BaseModel):
+ id: str
+ dataset_id: str
+ row_start: int
+ row_end: int
+ risk_score: float
+ verdict: str
+ findings: list | None = None
+ suspicious_indicators: list | None = None
+ mitre_techniques: list | None = None
+ model_used: str | None = None
+ node_used: str | None = None
+
+ class Config:
+ from_attributes = True
+
+
+class HostProfileResponse(BaseModel):
+ id: str
+ hunt_id: str
+ hostname: str
+ fqdn: str | None = None
+ risk_score: float
+ risk_level: str
+ artifact_summary: dict | None = None
+ timeline_summary: str | None = None
+ suspicious_findings: list | None = None
+ mitre_techniques: list | None = None
+ llm_analysis: str | None = None
+ model_used: str | None = None
+
+ class Config:
+ from_attributes = True
+
+
+class HuntReportResponse(BaseModel):
+ id: str
+ hunt_id: str
+ status: str
+ exec_summary: str | None = None
+ full_report: str | None = None
+ findings: list | None = None
+ recommendations: list | None = None
+ mitre_mapping: dict | None = None
+ ioc_table: list | None = None
+ host_risk_summary: list | None = None
+ models_used: list | None = None
+ generation_time_ms: int | None = None
+
+ class Config:
+ from_attributes = True
+
+
+class QueryRequest(BaseModel):
+ question: str
+ mode: str = "quick" # quick or deep
+
+
+# --- Triage endpoints ---
+
+@router.get("/triage/{dataset_id}", response_model=list[TriageResultResponse])
+async def get_triage_results(
+ dataset_id: str,
+ min_risk: float = Query(0.0, ge=0.0, le=10.0),
+ db: AsyncSession = Depends(get_db),
+):
+ result = await db.execute(
+ select(TriageResult)
+ .where(TriageResult.dataset_id == dataset_id)
+ .where(TriageResult.risk_score >= min_risk)
+ .order_by(TriageResult.risk_score.desc())
+ )
+ return result.scalars().all()
+
+
+@router.post("/triage/{dataset_id}")
+async def trigger_triage(
+ dataset_id: str,
+ background_tasks: BackgroundTasks,
+):
+ async def _run():
+ from app.services.triage import triage_dataset
+ await triage_dataset(dataset_id)
+
+ background_tasks.add_task(_run)
+ return {"status": "triage_started", "dataset_id": dataset_id}
+
+
+# --- Host profile endpoints ---
+
+@router.get("/profiles/{hunt_id}", response_model=list[HostProfileResponse])
+async def get_host_profiles(
+ hunt_id: str,
+ min_risk: float = Query(0.0, ge=0.0, le=10.0),
+ db: AsyncSession = Depends(get_db),
+):
+ result = await db.execute(
+ select(HostProfile)
+ .where(HostProfile.hunt_id == hunt_id)
+ .where(HostProfile.risk_score >= min_risk)
+ .order_by(HostProfile.risk_score.desc())
+ )
+ return result.scalars().all()
+
+
+@router.post("/profiles/{hunt_id}")
+async def trigger_all_profiles(
+ hunt_id: str,
+ background_tasks: BackgroundTasks,
+):
+ async def _run():
+ from app.services.host_profiler import profile_all_hosts
+ await profile_all_hosts(hunt_id)
+
+ background_tasks.add_task(_run)
+ return {"status": "profiling_started", "hunt_id": hunt_id}
+
+
+@router.post("/profiles/{hunt_id}/{hostname}")
+async def trigger_single_profile(
+ hunt_id: str,
+ hostname: str,
+ background_tasks: BackgroundTasks,
+):
+ async def _run():
+ from app.services.host_profiler import profile_host
+ await profile_host(hunt_id, hostname)
+
+ background_tasks.add_task(_run)
+ return {"status": "profiling_started", "hunt_id": hunt_id, "hostname": hostname}
+
+
+# --- Report endpoints ---
+
+@router.get("/reports/{hunt_id}", response_model=list[HuntReportResponse])
+async def list_reports(
+ hunt_id: str,
+ db: AsyncSession = Depends(get_db),
+):
+ result = await db.execute(
+ select(HuntReport)
+ .where(HuntReport.hunt_id == hunt_id)
+ .order_by(HuntReport.created_at.desc())
+ )
+ return result.scalars().all()
+
+
+@router.get("/reports/{hunt_id}/{report_id}", response_model=HuntReportResponse)
+async def get_report(
+ hunt_id: str,
+ report_id: str,
+ db: AsyncSession = Depends(get_db),
+):
+ result = await db.execute(
+ select(HuntReport)
+ .where(HuntReport.id == report_id)
+ .where(HuntReport.hunt_id == hunt_id)
+ )
+ report = result.scalar_one_or_none()
+ if not report:
+ raise HTTPException(status_code=404, detail="Report not found")
+ return report
+
+
+@router.post("/reports/{hunt_id}/generate")
+async def trigger_report(
+ hunt_id: str,
+ background_tasks: BackgroundTasks,
+):
+ async def _run():
+ from app.services.report_generator import generate_report
+ await generate_report(hunt_id)
+
+ background_tasks.add_task(_run)
+ return {"status": "report_generation_started", "hunt_id": hunt_id}
+
+
+# --- IOC extraction endpoints ---
+
+@router.get("/iocs/{dataset_id}")
+async def extract_iocs(
+ dataset_id: str,
+ max_rows: int = Query(5000, ge=1, le=50000),
+ db: AsyncSession = Depends(get_db),
+):
+ """Extract IOCs (IPs, domains, hashes, etc.) from dataset rows."""
+ from app.services.ioc_extractor import extract_iocs_from_dataset
+ iocs = await extract_iocs_from_dataset(dataset_id, db, max_rows=max_rows)
+ total = sum(len(v) for v in iocs.values())
+ return {"dataset_id": dataset_id, "iocs": iocs, "total": total}
+
+
+# --- Host grouping endpoints ---
+
+@router.get("/hosts/{hunt_id}")
+async def get_host_groups(
+ hunt_id: str,
+ db: AsyncSession = Depends(get_db),
+):
+ """Group data by hostname across all datasets in a hunt."""
+ from app.services.ioc_extractor import extract_host_groups
+ groups = await extract_host_groups(hunt_id, db)
+ return {"hunt_id": hunt_id, "hosts": groups}
+
+
+# --- Anomaly detection endpoints ---
+
+@router.get("/anomalies/{dataset_id}")
+async def get_anomalies(
+    dataset_id: str,
+    outliers_only: bool = Query(False),
+    db: AsyncSession = Depends(get_db),
+):
+    """Get anomaly detection results for a dataset."""
+    from app.db.models import AnomalyResult
+    stmt = select(AnomalyResult).where(AnomalyResult.dataset_id == dataset_id)
+    if outliers_only:
+        stmt = stmt.where(AnomalyResult.is_outlier.is_(True))
+    stmt = stmt.order_by(AnomalyResult.anomaly_score.desc())
+    result = await db.execute(stmt)
+    rows = result.scalars().all()
+    return [
+        {
+            "id": r.id,
+            "dataset_id": r.dataset_id,
+            "row_id": r.row_id,
+            "anomaly_score": r.anomaly_score,
+            "distance_from_centroid": r.distance_from_centroid,
+            "cluster_id": r.cluster_id,
+            "is_outlier": r.is_outlier,
+            "explanation": r.explanation,
+        }
+        for r in rows
+    ]
+
+
+@router.post("/anomalies/{dataset_id}")
+async def trigger_anomaly_detection(
+ dataset_id: str,
+ k: int = Query(3, ge=2, le=20),
+ threshold: float = Query(0.35, ge=0.1, le=0.9),
+ background_tasks: BackgroundTasks = None,
+):
+ """Trigger embedding-based anomaly detection on a dataset."""
+ async def _run():
+ from app.services.anomaly_detector import detect_anomalies
+ await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
+
+ if background_tasks:
+ background_tasks.add_task(_run)
+ return {"status": "anomaly_detection_started", "dataset_id": dataset_id}
+ else:
+ from app.services.anomaly_detector import detect_anomalies
+ results = await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
+ return {"status": "complete", "dataset_id": dataset_id, "count": len(results)}
+
+
+# --- Natural language data query (SSE streaming) ---
+
+@router.post("/query/{dataset_id}")
+async def query_dataset_endpoint(
+ dataset_id: str,
+ body: QueryRequest,
+):
+ """Ask a natural language question about a dataset.
+
+ Returns an SSE stream with token-by-token LLM response.
+ Event types: status, metadata, token, error, done
+ """
+ from app.services.data_query import query_dataset_stream
+
+ return StreamingResponse(
+ query_dataset_stream(dataset_id, body.question, body.mode),
+ media_type="text/event-stream",
+ headers={
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ "X-Accel-Buffering": "no",
+ },
+ )
+
+
+@router.post("/query/{dataset_id}/sync")
+async def query_dataset_sync(
+    dataset_id: str,
+    body: QueryRequest,
+):
+    """Non-streaming version of data query."""
+    from app.services.data_query import query_dataset
+
+    try:
+        answer = await query_dataset(dataset_id, body.question, body.mode)
+        return {"dataset_id": dataset_id, "question": body.question, "answer": answer, "mode": body.mode}
+    except ValueError as e:
+        raise HTTPException(status_code=404, detail=str(e)) from e
+    except Exception as e:
+        logger.error(f"Query failed: {e}", exc_info=True)
+        raise HTTPException(status_code=500, detail=str(e)) from e
+
+
+# --- Job queue endpoints ---
+
+@router.get("/jobs")
+async def list_jobs(
+ status: str | None = Query(None),
+ job_type: str | None = Query(None),
+ limit: int = Query(50, ge=1, le=200),
+):
+ """List all tracked jobs."""
+ from app.services.job_queue import job_queue, JobStatus, JobType
+
+ s = JobStatus(status) if status else None
+ t = JobType(job_type) if job_type else None
+ jobs = job_queue.list_jobs(status=s, job_type=t, limit=limit)
+ stats = job_queue.get_stats()
+ return {"jobs": jobs, "stats": stats}
+
+
+@router.get("/jobs/{job_id}")
+async def get_job(job_id: str):
+ """Get status of a specific job."""
+ from app.services.job_queue import job_queue
+
+ job = job_queue.get_job(job_id)
+ if not job:
+ raise HTTPException(status_code=404, detail="Job not found")
+ return job.to_dict()
+
+
+@router.delete("/jobs/{job_id}")
+async def cancel_job(job_id: str):
+ """Cancel a running or queued job."""
+ from app.services.job_queue import job_queue
+
+ if job_queue.cancel_job(job_id):
+ return {"status": "cancelled", "job_id": job_id}
+ raise HTTPException(status_code=400, detail="Job cannot be cancelled (already complete or not found)")
+
+
+@router.post("/jobs/submit/{job_type}")
+async def submit_job(
+    job_type: str,
+    params: Optional[dict] = None,
+):
+    """Submit a new job to the queue.
+
+    Job types: triage, host_profile, report, anomaly, query
+    Params vary by type (e.g., dataset_id, hunt_id, question, mode).
+    """
+    from app.services.job_queue import job_queue, JobType
+
+    try:
+        jt = JobType(job_type)
+    except ValueError:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Invalid job_type: {job_type}. Valid: {[t.value for t in JobType]}",
+        )
+
+    job = job_queue.submit(jt, **(params or {}))
+    return {"job_id": job.id, "status": job.status.value, "job_type": job_type}
+
+
+# --- Load balancer status ---
+
+@router.get("/lb/status")
+async def lb_status():
+    """Get load balancer status for both nodes."""
+    from app.services.load_balancer import lb
+    return lb.get_status()
+
+
+@router.post("/lb/check")
+async def lb_health_check():
+    """Force a health check of both nodes."""
+    from app.services.load_balancer import lb
+    await lb.check_health()
+    return lb.get_status()
diff --git a/backend/app/api/routes/network.py b/backend/app/api/routes/network.py
new file mode 100644
index 0000000..65d47ad
--- /dev/null
+++ b/backend/app/api/routes/network.py
@@ -0,0 +1,25 @@
+"""Network topology API - host inventory endpoint."""
+
+import logging
+
+from fastapi import APIRouter, Depends, Query
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db import get_db
+from app.services.host_inventory import build_host_inventory
+
+logger = logging.getLogger(__name__)
+router = APIRouter(prefix="/api/network", tags=["network"])
+
+
+@router.get("/host-inventory")
+async def get_host_inventory(
+    hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
+    db: AsyncSession = Depends(get_db),
+):
+    """Build a deduplicated host inventory from all datasets in a hunt.
+
+    Returns unique hosts with hostname, IPs, OS, logged-in users, and
+    network connections derived from netstat/connection data.
+    """
+    return await build_host_inventory(hunt_id, db)
diff --git a/backend/app/db/engine.py b/backend/app/db/engine.py
index afd2968..dbfc86c 100644
--- a/backend/app/db/engine.py
+++ b/backend/app/db/engine.py
@@ -3,6 +3,7 @@
 Uses async SQLAlchemy with aiosqlite for local dev and asyncpg for production PostgreSQL.
 """
+from sqlalchemy import event
 from sqlalchemy.ext.asyncio import (
     AsyncSession,
     async_sessionmaker,
@@ -12,12 +13,32 @@ from sqlalchemy.orm import DeclarativeBase
 from app.config import settings
-engine = create_async_engine(
-    settings.DATABASE_URL,
+_is_sqlite = settings.DATABASE_URL.startswith("sqlite")
+
+_engine_kwargs: dict = dict(
     echo=settings.DEBUG,
     future=True,
 )
+if _is_sqlite:
+    _engine_kwargs["connect_args"] = {"timeout": 30}
+    _engine_kwargs["pool_size"] = 1
+    _engine_kwargs["max_overflow"] = 0
+
+engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)
+
+
+@event.listens_for(engine.sync_engine, "connect")
+def _set_sqlite_pragmas(dbapi_conn, connection_record):
+    """Enable WAL mode and tune busy-timeout for SQLite connections."""
+    if _is_sqlite:
+        cursor = dbapi_conn.cursor()
+        cursor.execute("PRAGMA journal_mode=WAL")
+        cursor.execute("PRAGMA busy_timeout=5000")
+        cursor.execute("PRAGMA synchronous=NORMAL")
+        cursor.close()
+
+
 async_session_factory = async_sessionmaker(
     engine,
     class_=AsyncSession,
diff --git a/backend/app/db/models.py b/backend/app/db/models.py
index f786b94..f2ab0ff 100644
--- a/backend/app/db/models.py
+++ b/backend/app/db/models.py
@@ -1,7 +1,7 @@
"""SQLAlchemy ORM models for ThreatHunt.
All persistent entities: datasets, hunts, conversations, annotations,
-hypotheses, enrichment results, and users.
+hypotheses, enrichment results, users, and AI analysis tables.
"""
import uuid
@@ -32,8 +32,7 @@ def _new_id() -> str:
return uuid.uuid4().hex
-# ── Users ──────────────────────────────────────────────────────────────
-
+# -- Users ---
class User(Base):
__tablename__ = "users"
@@ -42,17 +41,15 @@ class User(Base):
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
email: Mapped[str] = mapped_column(String(256), unique=True, nullable=False)
hashed_password: Mapped[str] = mapped_column(String(256), nullable=False)
- role: Mapped[str] = mapped_column(String(16), default="analyst") # analyst | admin | viewer
+ role: Mapped[str] = mapped_column(String(16), default="analyst")
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
- # relationships
hunts: Mapped[list["Hunt"]] = relationship(back_populates="owner", lazy="selectin")
annotations: Mapped[list["Annotation"]] = relationship(back_populates="author", lazy="selectin")
-# ── Hunts ──────────────────────────────────────────────────────────────
-
+# -- Hunts ---
class Hunt(Base):
__tablename__ = "hunts"
@@ -60,7 +57,7 @@ class Hunt(Base):
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
name: Mapped[str] = mapped_column(String(256), nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
- status: Mapped[str] = mapped_column(String(32), default="active") # active | closed | archived
+ status: Mapped[str] = mapped_column(String(32), default="active")
owner_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("users.id"), nullable=True
)
@@ -69,15 +66,15 @@ class Hunt(Base):
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
- # relationships
owner: Mapped[Optional["User"]] = relationship(back_populates="hunts", lazy="selectin")
datasets: Mapped[list["Dataset"]] = relationship(back_populates="hunt", lazy="selectin")
conversations: Mapped[list["Conversation"]] = relationship(back_populates="hunt", lazy="selectin")
hypotheses: Mapped[list["Hypothesis"]] = relationship(back_populates="hunt", lazy="selectin")
+ host_profiles: Mapped[list["HostProfile"]] = relationship(back_populates="hunt", lazy="noload")
+ reports: Mapped[list["HuntReport"]] = relationship(back_populates="hunt", lazy="noload")
-# ── Datasets ───────────────────────────────────────────────────────────
-
+# -- Datasets ---
class Dataset(Base):
__tablename__ = "datasets"
@@ -85,36 +82,44 @@ class Dataset(Base):
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
filename: Mapped[str] = mapped_column(String(512), nullable=False)
- source_tool: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # velociraptor, etc.
+ source_tool: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
row_count: Mapped[int] = mapped_column(Integer, default=0)
column_schema: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
normalized_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
- ioc_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) # auto-detected IOC columns
+ ioc_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
file_size_bytes: Mapped[int] = mapped_column(Integer, default=0)
encoding: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
delimiter: Mapped[Optional[str]] = mapped_column(String(4), nullable=True)
time_range_start: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
time_range_end: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
+ # New Phase 1-2 columns
+ processing_status: Mapped[str] = mapped_column(String(20), default="ready")
+ artifact_type: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
+ error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ file_path: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
+
hunt_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("hunts.id"), nullable=True
)
uploaded_by: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
- # relationships
hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="datasets", lazy="selectin")
rows: Mapped[list["DatasetRow"]] = relationship(
back_populates="dataset", lazy="noload", cascade="all, delete-orphan"
)
+ triage_results: Mapped[list["TriageResult"]] = relationship(
+ back_populates="dataset", lazy="noload", cascade="all, delete-orphan"
+ )
__table_args__ = (
Index("ix_datasets_hunt", "hunt_id"),
+ Index("ix_datasets_status", "processing_status"),
)
class DatasetRow(Base):
- """Individual row from a CSV dataset, stored as JSON blob."""
__tablename__ = "dataset_rows"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
@@ -125,7 +130,6 @@ class DatasetRow(Base):
data: Mapped[dict] = mapped_column(JSON, nullable=False)
normalized_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
- # relationships
dataset: Mapped["Dataset"] = relationship(back_populates="rows")
annotations: Mapped[list["Annotation"]] = relationship(
back_populates="row", lazy="noload"
@@ -137,8 +141,7 @@ class DatasetRow(Base):
)
-# ── Conversations ─────────────────────────────────────────────────────
-
+# -- Conversations ---
class Conversation(Base):
__tablename__ = "conversations"
@@ -156,7 +159,6 @@ class Conversation(Base):
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
- # relationships
hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="conversations", lazy="selectin")
messages: Mapped[list["Message"]] = relationship(
back_populates="conversation", lazy="selectin", cascade="all, delete-orphan",
@@ -171,16 +173,15 @@ class Message(Base):
conversation_id: Mapped[str] = mapped_column(
String(32), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False
)
- role: Mapped[str] = mapped_column(String(16), nullable=False) # user | agent | system
+ role: Mapped[str] = mapped_column(String(16), nullable=False)
content: Mapped[str] = mapped_column(Text, nullable=False)
model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
- node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # wile | roadrunner | cluster
+ node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
token_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
latency_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
response_meta: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
- # relationships
conversation: Mapped["Conversation"] = relationship(back_populates="messages")
__table_args__ = (
@@ -188,8 +189,7 @@ class Message(Base):
)
-# ── Annotations ───────────────────────────────────────────────────────
-
+# -- Annotations ---
class Annotation(Base):
__tablename__ = "annotations"
@@ -205,19 +205,14 @@ class Annotation(Base):
String(32), ForeignKey("users.id"), nullable=True
)
text: Mapped[str] = mapped_column(Text, nullable=False)
- severity: Mapped[str] = mapped_column(
- String(16), default="info"
- ) # info | low | medium | high | critical
- tag: Mapped[Optional[str]] = mapped_column(
- String(32), nullable=True
- ) # suspicious | benign | needs-review
+ severity: Mapped[str] = mapped_column(String(16), default="info")
+ tag: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
highlight_color: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
- # relationships
row: Mapped[Optional["DatasetRow"]] = relationship(back_populates="annotations")
author: Mapped[Optional["User"]] = relationship(back_populates="annotations")
@@ -227,8 +222,7 @@ class Annotation(Base):
)
-# ── Hypotheses ────────────────────────────────────────────────────────
-
+# -- Hypotheses ---
class Hypothesis(Base):
__tablename__ = "hypotheses"
@@ -240,9 +234,7 @@ class Hypothesis(Base):
title: Mapped[str] = mapped_column(String(256), nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
mitre_technique: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
- status: Mapped[str] = mapped_column(
- String(16), default="draft"
- ) # draft | active | confirmed | rejected
+ status: Mapped[str] = mapped_column(String(16), default="draft")
evidence_row_ids: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
evidence_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
@@ -250,7 +242,6 @@ class Hypothesis(Base):
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
- # relationships
hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="hypotheses", lazy="selectin")
__table_args__ = (
@@ -258,21 +249,16 @@ class Hypothesis(Base):
)
-# ── Enrichment Results ────────────────────────────────────────────────
-
+# -- Enrichment Results ---
class EnrichmentResult(Base):
__tablename__ = "enrichment_results"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
ioc_value: Mapped[str] = mapped_column(String(512), nullable=False, index=True)
- ioc_type: Mapped[str] = mapped_column(
- String(32), nullable=False
- ) # ip | hash_md5 | hash_sha1 | hash_sha256 | domain | url
- source: Mapped[str] = mapped_column(String(32), nullable=False) # virustotal | abuseipdb | shodan | ai
- verdict: Mapped[Optional[str]] = mapped_column(
- String(16), nullable=True
- ) # clean | suspicious | malicious | unknown
+ ioc_type: Mapped[str] = mapped_column(String(32), nullable=False)
+ source: Mapped[str] = mapped_column(String(32), nullable=False)
+ verdict: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
raw_result: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
@@ -287,28 +273,24 @@ class EnrichmentResult(Base):
)
-# ── AUP Keyword Themes & Keywords ────────────────────────────────────
-
+# -- AUP Keyword Themes & Keywords ---
class KeywordTheme(Base):
- """A named category of keywords for AUP scanning (e.g. gambling, gaming)."""
__tablename__ = "keyword_themes"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
name: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
- color: Mapped[str] = mapped_column(String(16), default="#9e9e9e") # hex chip color
+ color: Mapped[str] = mapped_column(String(16), default="#9e9e9e")
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
- is_builtin: Mapped[bool] = mapped_column(Boolean, default=False) # seed-provided
+ is_builtin: Mapped[bool] = mapped_column(Boolean, default=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
- # relationships
keywords: Mapped[list["Keyword"]] = relationship(
back_populates="theme", lazy="selectin", cascade="all, delete-orphan"
)
class Keyword(Base):
- """Individual keyword / pattern belonging to a theme."""
__tablename__ = "keywords"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
@@ -319,10 +301,102 @@ class Keyword(Base):
is_regex: Mapped[bool] = mapped_column(Boolean, default=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
- # relationships
theme: Mapped["KeywordTheme"] = relationship(back_populates="keywords")
__table_args__ = (
Index("ix_keywords_theme", "theme_id"),
Index("ix_keywords_value", "value"),
)
+
+
+# -- AI Analysis Tables (Phase 2) ---
+
class TriageResult(Base):
    """AI triage verdict for a contiguous chunk of dataset rows.

    Each record covers rows ``row_start``..``row_end`` of one dataset and
    stores the LLM's risk score, verdict string and JSON finding payloads,
    plus which model/node produced the analysis.
    """

    __tablename__ = "triage_results"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    # Owning dataset; rows are removed with it (ON DELETE CASCADE).
    dataset_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True
    )
    # Row range of the dataset chunk this verdict applies to.
    row_start: Mapped[int] = mapped_column(Integer, nullable=False)
    row_end: Mapped[int] = mapped_column(Integer, nullable=False)
    risk_score: Mapped[float] = mapped_column(Float, default=0.0)
    verdict: Mapped[str] = mapped_column(String(20), default="pending")
    # Free-form JSON payloads emitted by the triage model.
    findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    suspicious_indicators: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    mitre_techniques: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    # Provenance: which LLM and which cluster node produced this result.
    model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    dataset: Mapped["Dataset"] = relationship(back_populates="triage_results")
+
+
class HostProfile(Base):
    """Per-host AI threat profile within a hunt.

    Aggregates artifact/timeline summaries, suspicious findings, MITRE
    technique mappings and the LLM's narrative analysis for one host,
    plus model/node provenance.  Deleted with the hunt (CASCADE FK).
    """

    __tablename__ = "host_profiles"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    hunt_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True
    )
    hostname: Mapped[str] = mapped_column(String(256), nullable=False)
    fqdn: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
    # Presumably the endpoint agent's client id (e.g. Velociraptor) — confirm against callers.
    client_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    risk_score: Mapped[float] = mapped_column(Float, default=0.0)
    risk_level: Mapped[str] = mapped_column(String(20), default="unknown")
    # Structured JSON summaries produced by the profiling pipeline.
    artifact_summary: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    timeline_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    suspicious_findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    mitre_techniques: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    # Free-text narrative from the LLM.
    llm_analysis: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # Provenance: which LLM and which cluster node produced this profile.
    model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    hunt: Mapped["Hunt"] = relationship(back_populates="host_profiles")
+
+
class HuntReport(Base):
    """Generated hunt-level report (exec summary, findings, IOC table, ...).

    ``status`` tracks the generation lifecycle (default "pending"); the
    JSON columns hold the structured report sections and ``models_used``
    records which LLMs contributed.  Deleted with the hunt (CASCADE FK).
    """

    __tablename__ = "hunt_reports"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    hunt_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True
    )
    status: Mapped[str] = mapped_column(String(20), default="pending")
    # Report sections: prose + structured JSON.
    exec_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    full_report: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    recommendations: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    mitre_mapping: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    ioc_table: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    host_risk_summary: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    models_used: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    # Presumably wall-clock generation time in milliseconds — confirm in the report service.
    generation_time_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    hunt: Mapped["Hunt"] = relationship(back_populates="reports")
+
+
class AnomalyResult(Base):
    """Embedding-based anomaly score for a single dataset row.

    Persisted by the anomaly-detection service: the row's cosine distance
    from its cluster centroid, cluster assignment, and an outlier flag.
    Note: unlike TriageResult, no ORM relationship is declared here —
    results are queried by ``dataset_id`` directly.
    """

    __tablename__ = "anomaly_results"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    dataset_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True
    )
    # Optional link to the specific row; nullable so dataset-level results are possible.
    row_id: Mapped[Optional[int]] = mapped_column(
        Integer, ForeignKey("dataset_rows.id", ondelete="CASCADE"), nullable=True
    )
    anomaly_score: Mapped[float] = mapped_column(Float, default=0.0)
    distance_from_centroid: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    cluster_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    is_outlier: Mapped[bool] = mapped_column(Boolean, default=False)
    explanation: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
\ No newline at end of file
diff --git a/backend/app/main.py b/backend/app/main.py
index bd75ed7..6c81ddb 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -1,10 +1,12 @@
"""ThreatHunt backend application.
Wires together: database, CORS, agent routes, dataset routes, hunt routes,
-annotation/hypothesis routes. DB tables are auto-created on startup.
+annotation/hypothesis routes, analysis routes, network routes, job queue,
+load balancer. DB tables are auto-created on startup.
"""
import logging
+import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
@@ -21,6 +23,8 @@ from app.api.routes.correlation import router as correlation_router
from app.api.routes.reports import router as reports_router
from app.api.routes.auth import router as auth_router
from app.api.routes.keywords import router as keywords_router
+from app.api.routes.analysis import router as analysis_router
+from app.api.routes.network import router as network_router
logger = logging.getLogger(__name__)
@@ -28,17 +32,45 @@ logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Startup / shutdown lifecycle."""
- logger.info("Starting ThreatHunt API …")
+ logger.info("Starting ThreatHunt API ...")
await init_db()
logger.info("Database initialised")
+
+ # Ensure uploads directory exists
+ os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
+ logger.info("Upload dir: %s", os.path.abspath(settings.UPLOAD_DIR))
+
# Seed default AUP keyword themes
from app.db import async_session_factory
from app.services.keyword_defaults import seed_defaults
async with async_session_factory() as seed_db:
await seed_defaults(seed_db)
logger.info("AUP keyword defaults checked")
+
+ # Start job queue (Phase 10)
+ from app.services.job_queue import job_queue, register_all_handlers
+ register_all_handlers()
+ await job_queue.start()
+ logger.info("Job queue started (%d workers)", job_queue._max_workers)
+
+ # Start load balancer health loop (Phase 10)
+ from app.services.load_balancer import lb
+ await lb.start_health_loop(interval=30.0)
+ logger.info("Load balancer health loop started")
+
yield
- logger.info("Shutting down …")
+
+ logger.info("Shutting down ...")
+ # Stop job queue
+ from app.services.job_queue import job_queue as jq
+ await jq.stop()
+ logger.info("Job queue stopped")
+
+ # Stop load balancer
+ from app.services.load_balancer import lb as _lb
+ await _lb.stop_health_loop()
+ logger.info("Load balancer stopped")
+
from app.agents.providers_v2 import cleanup_client
from app.services.enrichment import enrichment_engine
await cleanup_client()
@@ -46,15 +78,13 @@ async def lifespan(app: FastAPI):
await dispose_db()
-# Create FastAPI application
app = FastAPI(
title="ThreatHunt API",
description="Analyst-assist threat hunting platform powered by Wile & Roadrunner LLM cluster",
- version="0.3.0",
+ version=settings.APP_VERSION,
lifespan=lifespan,
)
-# Configure CORS
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
@@ -74,11 +104,12 @@ app.include_router(enrichment_router)
app.include_router(correlation_router)
app.include_router(reports_router)
app.include_router(keywords_router)
+app.include_router(analysis_router)
+app.include_router(network_router)
@app.get("/", tags=["health"])
async def root():
- """API health check."""
return {
"service": "ThreatHunt API",
"version": settings.APP_VERSION,
@@ -89,4 +120,4 @@ async def root():
"roadrunner": settings.roadrunner_url,
"openwebui": settings.OPENWEBUI_URL,
},
- }
+ }
\ No newline at end of file
diff --git a/backend/app/services/anomaly_detector.py b/backend/app/services/anomaly_detector.py
new file mode 100644
index 0000000..69b7593
--- /dev/null
+++ b/backend/app/services/anomaly_detector.py
@@ -0,0 +1,199 @@
+"""Embedding-based anomaly detection using Roadrunner's bge-m3 model.
+
+Converts dataset rows to embeddings, clusters them, and flags outliers
+that deviate significantly from the cluster centroids. Uses cosine
+distance and simple k-means-like centroid computation.
+"""
+
+import asyncio
+import json
+import logging
+import math
+from typing import Optional
+
+import httpx
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.config import settings
+from app.db import async_session_factory
+from app.db.models import AnomalyResult, Dataset, DatasetRow
+
+logger = logging.getLogger(__name__)
+
+EMBED_URL = f"{settings.roadrunner_url}/api/embed"
+EMBED_MODEL = "bge-m3"
+BATCH_SIZE = 32 # rows per embedding batch
+MAX_ROWS = 2000 # cap for anomaly detection
+
+# --- math helpers (no numpy required) ---
+
+def _dot(a: list[float], b: list[float]) -> float:
+ return sum(x * y for x, y in zip(a, b))
+
+
+def _norm(v: list[float]) -> float:
+ return math.sqrt(sum(x * x for x in v))
+
+
+def _cosine_distance(a: list[float], b: list[float]) -> float:
+ na, nb = _norm(a), _norm(b)
+ if na == 0 or nb == 0:
+ return 1.0
+ return 1.0 - _dot(a, b) / (na * nb)
+
+
+def _mean_vector(vectors: list[list[float]]) -> list[float]:
+ if not vectors:
+ return []
+ dim = len(vectors[0])
+ n = len(vectors)
+ return [sum(v[i] for v in vectors) / n for i in range(dim)]
+
+
+def _row_to_text(data: dict) -> str:
+ """Flatten a row dict to a single string for embedding."""
+ parts = []
+ for k, v in data.items():
+ sv = str(v).strip()
+ if sv and sv.lower() not in ('none', 'null', ''):
+ parts.append(f"{k}={sv}")
+ return " | ".join(parts)[:2000] # cap length
+
+
async def _embed_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]:
    """Get embeddings from Roadrunner's Ollama API.

    Posts one batch of texts to ``EMBED_URL`` and returns one embedding
    vector per input.  HTTP errors propagate via ``raise_for_status()``;
    the caller catches them and zero-fills to keep indices aligned.
    """
    resp = await client.post(
        EMBED_URL,
        json={"model": EMBED_MODEL, "input": texts},
        timeout=120.0,  # generous: embedding a full batch can be slow
    )
    resp.raise_for_status()
    data = resp.json()
    # Ollama returns {"embeddings": [[...], ...]}
    return data.get("embeddings", [])
+
+
def _simple_cluster(
    embeddings: list[list[float]],
    k: int = 3,
    max_iter: int = 20,
) -> tuple[list[int], list[list[float]]]:
    """Lightweight k-means over cosine distance (no numpy dependency).

    Returns (assignments, centroids).  With k or fewer points, every
    point becomes its own singleton cluster.
    """
    n = len(embeddings)
    if n <= k:
        # Degenerate case: one cluster per point.
        return list(range(n)), embeddings[:]

    # Seed centroids at evenly spaced points through the data.
    stride = max(n // k, 1)
    centroids = [embeddings[(i * stride) % n] for i in range(k)]
    assignments = [0] * n

    for _ in range(max_iter):
        # Assignment step: each point goes to its nearest centroid
        # (ties resolved to the lowest cluster index, as before).
        updated = [
            min(range(k), key=lambda ci: _cosine_distance(point, centroids[ci]))
            for point in embeddings
        ]

        if updated == assignments:
            break  # converged
        assignments = updated

        # Update step: centroid = mean of members; empty clusters keep
        # their previous centroid.
        for ci in range(k):
            members = [emb for emb, a in zip(embeddings, assignments) if a == ci]
            if members:
                centroids[ci] = _mean_vector(members)

    return assignments, centroids
+
+
async def detect_anomalies(
    dataset_id: str,
    k: int = 3,
    outlier_threshold: float = 0.35,
) -> list[dict]:
    """Run embedding-based anomaly detection on a dataset.

    1. Load rows 2. Embed via bge-m3 3. Cluster 4. Flag outliers.

    Args:
        dataset_id: dataset to scan (at most MAX_ROWS rows are loaded).
        k: requested number of k-means clusters (clamped to the row count).
        outlier_threshold: cosine distance from the cluster centroid above
            which a row is flagged as an outlier.

    Returns:
        Per-row result dicts sorted by anomaly_score descending; the same
        results are persisted as AnomalyResult records.
    """
    async with async_session_factory() as db:
        # Load rows
        result = await db.execute(
            select(DatasetRow.id, DatasetRow.row_index, DatasetRow.data)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .limit(MAX_ROWS)
        )
        rows = result.all()
        if not rows:
            logger.info("No rows for anomaly detection in dataset %s", dataset_id)
            return []

        row_ids = [r[0] for r in rows]
        row_indices = [r[1] for r in rows]
        texts = [_row_to_text(r[2]) for r in rows]

        logger.info("Anomaly detection: %d rows, embedding with %s", len(texts), EMBED_MODEL)

        # Embed in batches
        all_embeddings: list[list[float]] = []
        async with httpx.AsyncClient() as client:
            for i in range(0, len(texts), BATCH_SIZE):
                batch = texts[i : i + BATCH_SIZE]
                try:
                    embs = await _embed_batch(batch, client)
                    all_embeddings.extend(embs)
                except Exception as e:
                    logger.error("Embedding batch %d failed: %s", i, e)
                    # Fill with zeros so indices stay aligned
                    # (zero vectors sit at distance 1.0 from every centroid).
                    # NOTE(review): assumes 1024-dim embeddings (bge-m3) — confirm if model changes.
                    all_embeddings.extend([[0.0] * 1024] * len(batch))

        if not all_embeddings or len(all_embeddings) != len(texts):
            logger.error("Embedding count mismatch")
            return []

        # Cluster
        actual_k = min(k, len(all_embeddings))
        assignments, centroids = _simple_cluster(all_embeddings, k=actual_k)

        # Compute distances from centroid
        anomalies: list[dict] = []
        for idx, (emb, cluster_id) in enumerate(zip(all_embeddings, assignments)):
            dist = _cosine_distance(emb, centroids[cluster_id])
            is_outlier = dist > outlier_threshold
            anomalies.append({
                "row_id": row_ids[idx],
                "row_index": row_indices[idx],
                "anomaly_score": round(dist, 4),
                "distance_from_centroid": round(dist, 4),
                "cluster_id": cluster_id,
                "is_outlier": is_outlier,
            })

        # Save to DB
        outlier_count = 0
        for a in anomalies:
            ar = AnomalyResult(
                dataset_id=dataset_id,
                row_id=a["row_id"],
                anomaly_score=a["anomaly_score"],
                distance_from_centroid=a["distance_from_centroid"],
                cluster_id=a["cluster_id"],
                is_outlier=a["is_outlier"],
            )
            db.add(ar)
            if a["is_outlier"]:
                outlier_count += 1

        await db.commit()
        logger.info(
            "Anomaly detection complete: %d rows, %d outliers (threshold=%.2f)",
            len(anomalies), outlier_count, outlier_threshold,
        )

    return sorted(anomalies, key=lambda x: x["anomaly_score"], reverse=True)
\ No newline at end of file
diff --git a/backend/app/services/artifact_classifier.py b/backend/app/services/artifact_classifier.py
new file mode 100644
index 0000000..d88bf32
--- /dev/null
+++ b/backend/app/services/artifact_classifier.py
@@ -0,0 +1,81 @@
+"""Artifact classifier - identify Velociraptor artifact types from CSV headers."""
+
+from __future__ import annotations
+
+import logging
+
+logger = logging.getLogger(__name__)
+
# (required_columns, artifact_type) — first fingerprint whose columns are
# all present wins, so more specific entries come first.
FINGERPRINTS: list[tuple[set[str], str]] = [
    ({"Pid", "Name", "CommandLine", "Exe"}, "Windows.System.Pslist"),
    ({"Pid", "Name", "Ppid", "CommandLine"}, "Windows.System.Pslist"),
    ({"Laddr.IP", "Raddr.IP", "Status", "Pid"}, "Windows.Network.Netstat"),
    ({"Laddr", "Raddr", "Status", "Pid"}, "Windows.Network.Netstat"),
    ({"FamilyString", "TypeString", "Status", "Pid"}, "Windows.Network.Netstat"),
    ({"ServiceName", "DisplayName", "StartMode", "PathName"}, "Windows.System.Services"),
    ({"DisplayName", "PathName", "ServiceDll", "StartMode"}, "Windows.System.Services"),
    ({"OSPath", "Size", "Mtime", "Hash"}, "Windows.Search.FileFinder"),
    ({"FullPath", "Size", "Mtime"}, "Windows.Search.FileFinder"),
    ({"PrefetchFileName", "RunCount", "LastRunTimes"}, "Windows.Forensics.Prefetch"),
    ({"Executable", "RunCount", "LastRunTimes"}, "Windows.Forensics.Prefetch"),
    ({"KeyPath", "Type", "Data"}, "Windows.Registry.Finder"),
    ({"Key", "Type", "Value"}, "Windows.Registry.Finder"),
    ({"EventTime", "Channel", "EventID", "EventData"}, "Windows.EventLogs.EvtxHunter"),
    ({"TimeCreated", "Channel", "EventID", "Provider"}, "Windows.EventLogs.EvtxHunter"),
    ({"Entry", "Category", "Profile", "Launch String"}, "Windows.Sys.Autoruns"),
    ({"Entry", "Category", "LaunchString"}, "Windows.Sys.Autoruns"),
    ({"Name", "Record", "Type", "TTL"}, "Windows.Network.DNS"),
    ({"QueryName", "QueryType", "QueryResults"}, "Windows.Network.DNS"),
    ({"Path", "MD5", "SHA1", "SHA256"}, "Windows.Analysis.Hash"),
    ({"Md5", "Sha256", "FullPath"}, "Windows.Analysis.Hash"),
    ({"Name", "Actions", "NextRunTime", "Path"}, "Windows.System.TaskScheduler"),
    ({"Name", "Uid", "Gid", "Description"}, "Windows.Sys.Users"),
    ({"os_info.hostname", "os_info.system"}, "Server.Information.Client"),
    ({"ClientId", "os_info.fqdn"}, "Server.Information.Client"),
    ({"Pid", "Name", "Cmdline", "Exe"}, "Linux.Sys.Pslist"),
    ({"Laddr", "Raddr", "Status", "FamilyString"}, "Linux.Network.Netstat"),
    ({"Namespace", "ClassName", "PropertyName"}, "Windows.System.WMI"),
    ({"RemoteAddress", "RemoteMACAddress", "InterfaceAlias"}, "Windows.Network.ArpCache"),
    ({"URL", "Title", "VisitCount", "LastVisitTime"}, "Windows.Applications.BrowserHistory"),
    ({"Url", "Title", "Visits"}, "Windows.Applications.BrowserHistory"),
]

# Metadata columns Velociraptor adds to every export.
VELOCIRAPTOR_META = {"_Source", "ClientId", "FlowId", "Fqdn", "HuntId"}

# Substring of the artifact type -> coarse category.
CATEGORY_MAP = {
    "Pslist": "process",
    "Netstat": "network",
    "Services": "persistence",
    "FileFinder": "filesystem",
    "Prefetch": "execution",
    "Registry": "persistence",
    "EvtxHunter": "eventlog",
    "EventLogs": "eventlog",
    "Autoruns": "persistence",
    "DNS": "network",
    "Hash": "filesystem",
    "TaskScheduler": "persistence",
    "Users": "account",
    "Client": "system",
    "WMI": "persistence",
    "ArpCache": "network",
    "BrowserHistory": "application",
}


def classify_artifact(columns: list[str]) -> str:
    """Map a CSV header list to a known Velociraptor artifact type.

    The first fingerprint whose required columns are all present wins.
    Headers that only carry Velociraptor metadata columns are tagged
    "Velociraptor.Unknown"; anything else is "Unknown".
    """
    present = set(columns)
    match = next(
        (artifact for needed, artifact in FINGERPRINTS if needed <= present),
        None,
    )
    if match is not None:
        return match
    if not VELOCIRAPTOR_META.isdisjoint(present):
        return "Velociraptor.Unknown"
    return "Unknown"


def get_artifact_category(artifact_type: str) -> str:
    """Coarse category ("process", "network", ...) for an artifact type."""
    haystack = artifact_type.lower()
    for fragment, category in CATEGORY_MAP.items():
        if fragment.lower() in haystack:
            return category
    return "unknown"
\ No newline at end of file
diff --git a/backend/app/services/data_query.py b/backend/app/services/data_query.py
new file mode 100644
index 0000000..f13cdf5
--- /dev/null
+++ b/backend/app/services/data_query.py
@@ -0,0 +1,238 @@
+"""Natural-language data query service with SSE streaming.
+
+Lets analysts ask questions about dataset rows in plain English.
+Routes to fast model (Roadrunner) for quick queries, heavy model (Wile)
+for deep analysis. Supports streaming via OllamaProvider.generate_stream().
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import time
+from typing import AsyncIterator
+
+from sqlalchemy import select, func
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.config import settings
+from app.db import async_session_factory
+from app.db.models import Dataset, DatasetRow
+
+logger = logging.getLogger(__name__)
+
+# Maximum rows to include in context window
+MAX_CONTEXT_ROWS = 60
+MAX_ROW_TEXT_CHARS = 300
+
+
+def _rows_to_text(rows: list[dict], columns: list[str]) -> str:
+ """Convert dataset rows to a compact text table for the LLM context."""
+ if not rows:
+ return "(no rows)"
+ # Header
+ header = " | ".join(columns[:20]) # cap columns to avoid overflow
+ lines = [header, "-" * min(len(header), 120)]
+ for row in rows[:MAX_CONTEXT_ROWS]:
+ vals = []
+ for c in columns[:20]:
+ v = str(row.get(c, ""))
+ if len(v) > 80:
+ v = v[:77] + "..."
+ vals.append(v)
+ line = " | ".join(vals)
+ if len(line) > MAX_ROW_TEXT_CHARS:
+ line = line[:MAX_ROW_TEXT_CHARS] + "..."
+ lines.append(line)
+ return "\n".join(lines)
+
+
# System prompt shared by query_dataset() and query_dataset_stream().
QUERY_SYSTEM_PROMPT = """You are a cybersecurity data analyst assistant for ThreatHunt.
You have been given a sample of rows from a forensic artifact dataset (Velociraptor, etc.).

Your job:
- Answer the analyst's question about this data accurately and concisely
- Point out suspicious patterns, anomalies, or indicators of compromise
- Reference MITRE ATT&CK techniques when relevant
- Suggest follow-up queries or pivots
- If you cannot answer from the data provided, say so clearly

Rules:
- Be factual - only reference data you can see
- Use forensic terminology appropriate for SOC/DFIR analysts
- Format your answer with clear sections using markdown
- If the data seems benign, say so - do not fabricate threats"""
+
+
async def _load_dataset_context(
    dataset_id: str,
    db: AsyncSession,
    sample_size: int = MAX_CONTEXT_ROWS,
) -> tuple[dict, str, int]:
    """Load dataset metadata + sample rows for context.

    Returns (metadata_dict, rows_text, total_row_count).

    Raises:
        ValueError: if no dataset with ``dataset_id`` exists.
    """
    ds = await db.get(Dataset, dataset_id)
    if not ds:
        raise ValueError(f"Dataset {dataset_id} not found")

    # Get total count
    count_q = await db.execute(
        select(func.count()).where(DatasetRow.dataset_id == dataset_id)
    )
    total = count_q.scalar() or 0

    # Sample rows - get first batch + some from the middle
    half = sample_size // 2
    result = await db.execute(
        select(DatasetRow)
        .where(DatasetRow.dataset_id == dataset_id)
        .order_by(DatasetRow.row_index)
        .limit(half)
    )
    first_rows = result.scalars().all()

    # If dataset is large, also sample from the middle
    middle_rows = []
    if total > sample_size:
        # Large dataset: jump to the midpoint so the sample isn't just the head
        # (mid_offset > half here, so the two batches cannot overlap).
        mid_offset = total // 2
        result2 = await db.execute(
            select(DatasetRow)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(mid_offset)
            .limit(sample_size - half)
        )
        middle_rows = result2.scalars().all()
    else:
        # Small dataset: just continue reading after the first batch.
        result2 = await db.execute(
            select(DatasetRow)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(half)
            .limit(sample_size - half)
        )
        middle_rows = result2.scalars().all()

    all_rows = first_rows + middle_rows
    row_dicts = [r.data if isinstance(r.data, dict) else {} for r in all_rows]

    # Prefer the stored schema for column order; fall back to the first row's keys.
    columns = list(ds.column_schema.keys()) if ds.column_schema else []
    if not columns and row_dicts:
        columns = list(row_dicts[0].keys())

    rows_text = _rows_to_text(row_dicts, columns)

    metadata = {
        "name": ds.name,
        "filename": ds.filename,
        "source_tool": ds.source_tool,
        # getattr: artifact_type may not exist on older Dataset rows/models.
        "artifact_type": getattr(ds, "artifact_type", None),
        "row_count": total,
        "columns": columns[:30],
        "sample_rows_shown": len(all_rows),
    }
    return metadata, rows_text, total
+
+
async def query_dataset(
    dataset_id: str,
    question: str,
    mode: str = "quick",
) -> str:
    """Ask a one-shot (non-streaming) question about a dataset.

    mode="deep" routes to the heavy model on Wile; any other mode uses
    the fast model on Roadrunner.  Returns the LLM's answer text.
    """
    from app.agents.providers_v2 import OllamaProvider, Node

    async with async_session_factory() as db:
        meta, rows_text, total = await _load_dataset_context(dataset_id, db)

    deep = mode == "deep"
    if deep:
        provider = OllamaProvider(settings.DEFAULT_HEAVY_MODEL, Node.WILE)
    else:
        provider = OllamaProvider(settings.DEFAULT_FAST_MODEL, Node.ROADRUNNER)
    token_budget = 4096 if deep else 2048

    prompt = _build_prompt(question, meta, rows_text, total)
    result = await provider.generate(
        prompt,
        system=QUERY_SYSTEM_PROMPT,
        max_tokens=token_budget,
        temperature=0.3,
    )
    return result.get("response", "No response generated.")
+
+
async def query_dataset_stream(
    dataset_id: str,
    question: str,
    mode: str = "quick",
) -> AsyncIterator[str]:
    """Streaming query: yields SSE-formatted events.

    Event order: 'status', 'metadata', 'status', then one 'token' event
    per generated token (or a single 'error' on failure), and finally a
    'done' event with token count, elapsed ms and model/node provenance.
    """
    from app.agents.providers_v2 import OllamaProvider, Node

    start = time.monotonic()

    # Send initial metadata event
    yield f"data: {json.dumps({'type': 'status', 'message': 'Loading dataset...'})}\n\n"

    async with async_session_factory() as db:
        meta, rows_text, total = await _load_dataset_context(dataset_id, db)

    yield f"data: {json.dumps({'type': 'metadata', 'dataset': meta})}\n\n"
    yield f"data: {json.dumps({'type': 'status', 'message': f'Querying LLM ({mode} mode)...'})}\n\n"

    prompt = _build_prompt(question, meta, rows_text, total)

    # Route: "deep" -> heavy model on Wile; anything else -> fast model on Roadrunner.
    if mode == "deep":
        provider = OllamaProvider(settings.DEFAULT_HEAVY_MODEL, Node.WILE)
        max_tokens = 4096
        model_name = settings.DEFAULT_HEAVY_MODEL
        node_name = "wile"
    else:
        provider = OllamaProvider(settings.DEFAULT_FAST_MODEL, Node.ROADRUNNER)
        max_tokens = 2048
        model_name = settings.DEFAULT_FAST_MODEL
        node_name = "roadrunner"

    # Stream tokens
    token_count = 0
    try:
        async for token in provider.generate_stream(
            prompt,
            system=QUERY_SYSTEM_PROMPT,
            max_tokens=max_tokens,
            temperature=0.3,
        ):
            token_count += 1
            yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
    except Exception as e:
        # Surface the failure to the client as an SSE event instead of
        # dropping the stream; 'done' is still emitted afterwards.
        logger.error(f"Streaming error: {e}")
        yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"

    elapsed_ms = int((time.monotonic() - start) * 1000)
    yield f"data: {json.dumps({'type': 'done', 'tokens': token_count, 'elapsed_ms': elapsed_ms, 'model': model_name, 'node': node_name})}\n\n"
+
+
+def _build_prompt(question: str, meta: dict, rows_text: str, total: int) -> str:
+ """Construct the full prompt with data context."""
+ parts = [
+ f"## Dataset: {meta['name']}",
+ f"- Source: {meta.get('source_tool', 'unknown')}",
+ f"- Artifact type: {meta.get('artifact_type', 'unknown')}",
+ f"- Total rows: {total}",
+ f"- Columns: {', '.join(meta.get('columns', []))}",
+ f"- Showing {meta['sample_rows_shown']} sample rows below",
+ "",
+ "## Sample Data",
+ "```",
+ rows_text,
+ "```",
+ "",
+ f"## Analyst Question",
+ question,
+ ]
+ return "\n".join(parts)
\ No newline at end of file
diff --git a/backend/app/services/host_inventory.py b/backend/app/services/host_inventory.py
new file mode 100644
index 0000000..32f0146
--- /dev/null
+++ b/backend/app/services/host_inventory.py
@@ -0,0 +1,290 @@
+"""Host Inventory Service - builds a deduplicated host-centric network view.
+
+Scans all datasets in a hunt to identify unique hosts, their IPs, OS,
+logged-in users, and network connections between them.
+"""
+
+import re
+import logging
+from collections import defaultdict
+from typing import Any
+
+from sqlalchemy import select, func
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db.models import Dataset, DatasetRow
+
+logger = logging.getLogger(__name__)
+
# --- Column-name patterns (Velociraptor + generic forensic tools) ---

# Columns that carry an agent/endpoint identifier rather than a display name.
_HOST_ID_RE = re.compile(
    r'^(client_?id|clientid|agent_?id|endpoint_?id|host_?id|sensor_?id)$', re.I)
# Hostname / FQDN style columns.
_FQDN_RE = re.compile(
    r'^(fqdn|fully_?qualified|computer_?name|hostname|host_?name|host|'
    r'system_?name|machine_?name|nodename|workstation)$', re.I)
# Account-name columns.
_USERNAME_RE = re.compile(
    r'^(user|username|user_?name|logon_?name|account_?name|owner|'
    r'logged_?in_?user|sam_?account_?name|samaccountname)$', re.I)
# Local / remote endpoint columns of a connection row.
_LOCAL_IP_RE = re.compile(
    r'^(laddr\.?ip|laddr|local_?addr(ess)?|src_?ip|source_?ip)$', re.I)
_REMOTE_IP_RE = re.compile(
    r'^(raddr\.?ip|raddr|remote_?addr(ess)?|dst_?ip|dest_?ip)$', re.I)
_REMOTE_PORT_RE = re.compile(
    r'^(raddr\.?port|rport|remote_?port|dst_?port|dest_?port)$', re.I)
# Operating-system columns.
_OS_RE = re.compile(
    r'^(os|operating_?system|os_?version|os_?name|platform|os_?type|os_?build)$', re.I)
# Loose dotted-quad shape check (octet range is not validated by the regex itself).
_IP_VALID_RE = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$')

# Placeholder / wildcard / loopback addresses that never identify a host.
_IGNORE_IPS = frozenset({
    '0.0.0.0', '::', '::1', '127.0.0.1', '', '-', '*', 'None', 'null',
})
# Windows pseudo-domains whose accounts are never human users.
_SYSTEM_DOMAINS = frozenset({
    'NT AUTHORITY', 'NT SERVICE', 'FONT DRIVER HOST', 'WINDOW MANAGER',
})
# Well-known service/system account names to exclude from user lists.
_SYSTEM_USERS = frozenset({
    'SYSTEM', 'LOCAL SERVICE', 'NETWORK SERVICE',
    'UMFD-0', 'UMFD-1', 'DWM-1', 'DWM-2', 'DWM-3',
})
+
+
def _is_valid_ip(v: str) -> bool:
    """Return True for a usable dotted-quad IPv4 address.

    Rejects empty/placeholder values (_IGNORE_IPS) and — unlike the bare
    _IP_VALID_RE regex, which only checks digit counts — dotted quads
    with octets above 255 (e.g. "999.1.1.1").
    """
    if not v or v in _IGNORE_IPS:
        return False
    if not _IP_VALID_RE.match(v):
        return False
    # Enforce the 0-255 octet range the \d{1,3} regex cannot express.
    return all(int(octet) <= 255 for octet in v.split('.'))
+
+
+def _clean(v: Any) -> str:
+ s = str(v or '').strip()
+ return s if s and s not in ('-', 'None', 'null', '') else ''
+
+
# Account names that belong to the OS, not a human.
_SYSTEM_USER_RE = re.compile(
    r'^(SYSTEM|LOCAL SERVICE|NETWORK SERVICE|DWM-\d+|UMFD-\d+)$', re.I)


def _extract_username(raw: str) -> str:
    """Normalise an account name: strip DOMAIN\\ prefixes, drop system accounts.

    Returns '' for empty input, well-known service/system accounts, and
    system-domain entries (NT AUTHORITY, WINDOW MANAGER, ...) whose local
    part is empty or itself a system account.
    """
    if not raw:
        return ''

    name = raw.strip()
    if '\\' not in name:
        return '' if _SYSTEM_USER_RE.match(name) else name

    domain, _, account = name.rpartition('\\')
    account = account.strip()
    if domain.strip().upper() in _SYSTEM_DOMAINS and (
            not account or _SYSTEM_USER_RE.match(account)):
        return ''
    if _SYSTEM_USER_RE.match(account):
        return ''
    return account or ''
+
+
+def _infer_os(fqdn: str) -> str:
+ u = fqdn.upper()
+ if 'W10-' in u or 'WIN10' in u:
+ return 'Windows 10'
+ if 'W11-' in u or 'WIN11' in u:
+ return 'Windows 11'
+ if 'W7-' in u or 'WIN7' in u:
+ return 'Windows 7'
+ if 'SRV' in u or 'SERVER' in u or 'DC-' in u:
+ return 'Windows Server'
+ if any(k in u for k in ('LINUX', 'UBUNTU', 'CENTOS', 'RHEL', 'DEBIAN')):
+ return 'Linux'
+ if 'MAC' in u or 'DARWIN' in u:
+ return 'macOS'
+ return 'Windows'
+
+
def _identify_columns(ds: Dataset) -> dict:
    """Map a dataset's columns to semantic roles (host id, fqdn, user, IPs, OS).

    Matches both raw column names (regexes above) and the dataset's
    normalized-column mapping.  Returns role -> list of matching columns.
    """
    norm = ds.normalized_columns or {}
    schema = ds.column_schema or {}
    raw_cols = list(schema.keys()) if schema else list(norm.keys())

    result = {
        'host_id': [], 'fqdn': [], 'username': [],
        'local_ip': [], 'remote_ip': [], 'remote_port': [], 'os': [],
    }

    for col in raw_cols:
        canonical = (norm.get(col) or '').lower()
        lower = col.lower()

        # A "hostname"-normalized column counts as a host *id* only when the
        # raw name is not literally a hostname-style column.
        if _HOST_ID_RE.match(lower) or (canonical == 'hostname' and lower not in ('hostname', 'host_name', 'host')):
            result['host_id'].append(col)

        if _FQDN_RE.match(lower) or canonical == 'fqdn':
            result['fqdn'].append(col)

        if _USERNAME_RE.match(lower) or canonical in ('username', 'user'):
            result['username'].append(col)

        # NOTE: elif keeps a column from being both local and remote IP;
        # the local-IP match deliberately takes precedence.
        if _LOCAL_IP_RE.match(lower):
            result['local_ip'].append(col)
        elif _REMOTE_IP_RE.match(lower):
            result['remote_ip'].append(col)

        if _REMOTE_PORT_RE.match(lower):
            result['remote_port'].append(col)

        if _OS_RE.match(lower) or canonical == 'os':
            result['os'].append(col)

    return result
+
+
async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
    """Build a deduplicated host inventory from all datasets in a hunt.

    Returns dict with 'hosts', 'connections', and 'stats'.
    Each host has: id, hostname, fqdn, client_id, ips, os, users, datasets, row_count.
    Hosts are keyed by FQDN (falling back to client id); connections are
    undirected, deduplicated host/IP pairs with an occurrence count.
    """
    ds_result = await db.execute(
        select(Dataset).where(Dataset.hunt_id == hunt_id)
    )
    all_datasets = ds_result.scalars().all()

    if not all_datasets:
        return {"hosts": [], "connections": [], "stats": {
            "total_hosts": 0, "total_datasets_scanned": 0,
            "total_rows_scanned": 0,
        }}

    hosts: dict[str, dict] = {}  # fqdn -> host record
    ip_to_host: dict[str, str] = {}  # local-ip -> fqdn
    connections: dict[tuple, int] = defaultdict(int)
    total_rows = 0
    ds_with_hosts = 0

    for ds in all_datasets:
        cols = _identify_columns(ds)
        # Skip datasets that have no host-identifying columns at all.
        if not cols['fqdn'] and not cols['host_id']:
            continue
        ds_with_hosts += 1

        # Page through rows to bound memory on large datasets.
        batch_size = 5000
        offset = 0
        while True:
            rr = await db.execute(
                select(DatasetRow)
                .where(DatasetRow.dataset_id == ds.id)
                .order_by(DatasetRow.row_index)
                .offset(offset).limit(batch_size)
            )
            rows = rr.scalars().all()
            if not rows:
                break

            for ro in rows:
                data = ro.data or {}
                total_rows += 1

                # First non-empty FQDN / client-id column wins for this row.
                fqdn = ''
                for c in cols['fqdn']:
                    fqdn = _clean(data.get(c))
                    if fqdn:
                        break
                client_id = ''
                for c in cols['host_id']:
                    client_id = _clean(data.get(c))
                    if client_id:
                        break

                if not fqdn and not client_id:
                    continue

                host_key = fqdn or client_id

                if host_key not in hosts:
                    # Short hostname = first label of the FQDN when dotted.
                    short = fqdn.split('.')[0] if fqdn and '.' in fqdn else fqdn
                    hosts[host_key] = {
                        'id': host_key,
                        'hostname': short or client_id,
                        'fqdn': fqdn,
                        'client_id': client_id,
                        'ips': set(),
                        'os': '',
                        'users': set(),
                        'datasets': set(),
                        'row_count': 0,
                    }

                h = hosts[host_key]
                h['datasets'].add(ds.name)
                h['row_count'] += 1
                if client_id and not h['client_id']:
                    h['client_id'] = client_id

                for c in cols['username']:
                    u = _extract_username(_clean(data.get(c)))
                    if u:
                        h['users'].add(u)

                # Local IPs double as the reverse index for resolving
                # connection targets back to known hosts.
                for c in cols['local_ip']:
                    ip = _clean(data.get(c))
                    if _is_valid_ip(ip):
                        h['ips'].add(ip)
                        ip_to_host[ip] = host_key

                for c in cols['os']:
                    ov = _clean(data.get(c))
                    if ov and not h['os']:
                        h['os'] = ov

                # Count outbound connections keyed by (host, remote ip, port).
                for c in cols['remote_ip']:
                    rip = _clean(data.get(c))
                    if _is_valid_ip(rip):
                        rport = ''
                        for pc in cols['remote_port']:
                            rport = _clean(data.get(pc))
                            if rport:
                                break
                        connections[(host_key, rip, rport)] += 1

            offset += batch_size
            if len(rows) < batch_size:
                break

    # Post-process hosts
    for h in hosts.values():
        if not h['os'] and h['fqdn']:
            h['os'] = _infer_os(h['fqdn'])
        # Sets -> sorted lists so the payload is JSON-serialisable and stable.
        h['ips'] = sorted(h['ips'])
        h['users'] = sorted(h['users'])
        h['datasets'] = sorted(h['datasets'])

    # Build connections, resolving IPs to host keys
    conn_list = []
    seen = set()
    for (src, dst_ip, dst_port), cnt in connections.items():
        if dst_ip in _IGNORE_IPS:
            continue
        dst_host = ip_to_host.get(dst_ip, '')
        if dst_host == src:
            continue  # self-connection
        # Deduplicate the undirected pair (A<->B reported once).
        key = tuple(sorted([src, dst_host or dst_ip]))
        if key in seen:
            continue
        seen.add(key)
        conn_list.append({
            'source': src,
            'target': dst_host or dst_ip,
            'target_ip': dst_ip,
            'port': dst_port,
            'count': cnt,
        })

    host_list = sorted(hosts.values(), key=lambda x: x['row_count'], reverse=True)

    return {
        "hosts": host_list,
        "connections": conn_list,
        "stats": {
            "total_hosts": len(host_list),
            "total_datasets_scanned": len(all_datasets),
            "datasets_with_hosts": ds_with_hosts,
            "total_rows_scanned": total_rows,
            "hosts_with_ips": sum(1 for h in host_list if h['ips']),
            "hosts_with_users": sum(1 for h in host_list if h['users']),
        },
    }
\ No newline at end of file
diff --git a/backend/app/services/host_profiler.py b/backend/app/services/host_profiler.py
new file mode 100644
index 0000000..c4f8bb1
--- /dev/null
+++ b/backend/app/services/host_profiler.py
@@ -0,0 +1,198 @@
+"""Host profiler - per-host deep threat analysis via Wile heavy models."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+
+import httpx
+from sqlalchemy import select
+
+from app.config import settings
+from app.db.engine import async_session
+from app.db.models import Dataset, DatasetRow, HostProfile, TriageResult
+
+logger = logging.getLogger(__name__)
+
# Heavy analysis model (70B-class) and the Wile node's generate endpoint.
HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
WILE_URL = f"{settings.wile_url}/api/generate"
+
+
async def _get_triage_summary(db, dataset_id: str) -> str:
    """Summarize the top high-risk triage results for one dataset as bullet lines."""
    stmt = (
        select(TriageResult)
        .where(TriageResult.dataset_id == dataset_id)
        .where(TriageResult.risk_score >= 3.0)
        .order_by(TriageResult.risk_score.desc())
        .limit(10)
    )
    hits = (await db.execute(stmt)).scalars().all()
    if not hits:
        return "No significant triage findings."
    return "\n".join(
        f"- Rows {t.row_start}-{t.row_end}: risk={t.risk_score:.1f} "
        f"verdict={t.verdict} findings={json.dumps(t.findings, default=str)[:300]}"
        for t in hits
    )
+
+
async def _collect_host_data(db, hunt_id: str, hostname: str, fqdn: str | None = None) -> dict:
    """Collect rows that mention the host (plus triage context) across all hunt datasets."""
    datasets = (
        await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
    ).scalars().all()

    needle = hostname.lower()
    fqdn_needle = fqdn.lower() if fqdn else None

    artifacts: dict[str, list[dict]] = {}
    triage_sections: list[str] = []

    for ds in datasets:
        art = getattr(ds, "artifact_type", None) or "Unknown"
        row_objs = (
            await db.execute(
                select(DatasetRow).where(DatasetRow.dataset_id == ds.id).limit(500)
            )
        ).scalars().all()

        hits: list[dict] = []
        for row in row_objs:
            payload = row.normalized_data or row.data
            # Substring match against whichever host-ish field is populated.
            host_field = str(
                payload.get("hostname", "") or payload.get("Fqdn", "")
                or payload.get("ClientId", "") or payload.get("client_id", "")
            ).lower()
            if needle in host_field or (fqdn_needle and fqdn_needle in host_field):
                hits.append(payload)

        if hits:
            artifacts[art] = hits[:50]
            summary = await _get_triage_summary(db, ds.id)
            triage_sections.append(f"\n### {art} ({len(hits)} rows)\n{summary}")

    return {
        "artifacts": artifacts,
        "triage_summary": "\n".join(triage_sections) or "No triage data.",
        "artifact_count": sum(len(v) for v in artifacts.values()),
    }
+
+
async def profile_host(
    hunt_id: str, hostname: str, fqdn: str | None = None, client_id: str | None = None,
) -> None:
    """Deep-profile one host with the Wile heavy model and persist a HostProfile.

    Gathers the host's rows and prior triage findings across every dataset
    in the hunt, asks the heavy model for a structured JSON assessment, and
    stores the parsed result. On any failure an error stub profile
    (risk_level="unknown") is stored so the host is not silently skipped.
    """
    logger.info("Profiling host %s in hunt %s", hostname, hunt_id)

    async with async_session() as db:
        host_data = await _collect_host_data(db, hunt_id, hostname, fqdn)
        if host_data["artifact_count"] == 0:
            # Nothing in the hunt references this host - no profile to build.
            logger.info("No data found for host %s, skipping", hostname)
            return

        system_prompt = (
            "You are a senior threat hunting analyst performing deep host analysis.\n"
            "You receive consolidated forensic artifacts and prior triage results for a single host.\n\n"
            "Provide a comprehensive host threat profile as JSON:\n"
            "- risk_score: 0.0 (clean) to 10.0 (actively compromised)\n"
            "- risk_level: low/medium/high/critical\n"
            "- suspicious_findings: list of specific concerns\n"
            "- mitre_techniques: list of MITRE ATT&CK technique IDs\n"
            "- timeline_summary: brief timeline of suspicious activity\n"
            "- analysis: detailed narrative assessment\n\n"
            "Consider: cross-artifact correlation, attack patterns, LOLBins, anomalies.\n"
            "Respond with valid JSON only."
        )

        # Trim to at most 20 rows per artifact and 150 chars per value so the
        # prompt stays within the model's context window.
        artifact_summary = {}
        for art_type, rows in host_data["artifacts"].items():
            artifact_summary[art_type] = [
                {k: str(v)[:150] for k, v in row.items() if v} for row in rows[:20]
            ]

        prompt = (
            f"Host: {hostname}\nFQDN: {fqdn or 'unknown'}\n\n"
            f"## Prior Triage Results\n{host_data['triage_summary']}\n\n"
            f"## Artifact Data ({host_data['artifact_count']} total rows)\n"
            f"{json.dumps(artifact_summary, indent=1, default=str)[:8000]}\n\n"
            "Provide your comprehensive host threat profile as JSON."
        )

        try:
            async with httpx.AsyncClient(timeout=300.0) as client:
                resp = await client.post(
                    WILE_URL,
                    json={
                        "model": HEAVY_MODEL,
                        "prompt": prompt,
                        "system": system_prompt,
                        "stream": False,
                        "options": {"temperature": 0.3, "num_predict": 4096},
                    },
                )
                resp.raise_for_status()
                llm_text = resp.json().get("response", "")

            # Imported here to avoid a circular import at module load time.
            from app.services.triage import _parse_llm_response
            parsed = _parse_llm_response(llm_text)

            profile = HostProfile(
                hunt_id=hunt_id,
                hostname=hostname,
                fqdn=fqdn,
                client_id=client_id,
                risk_score=float(parsed.get("risk_score", 0.0)),
                risk_level=parsed.get("risk_level", "low"),
                artifact_summary={a: len(r) for a, r in host_data["artifacts"].items()},
                timeline_summary=parsed.get("timeline_summary", ""),
                suspicious_findings=parsed.get("suspicious_findings", []),
                mitre_techniques=parsed.get("mitre_techniques", []),
                llm_analysis=parsed.get("analysis", llm_text[:5000]),
                model_used=HEAVY_MODEL,
                node_used="wile",
            )
            db.add(profile)
            await db.commit()
            logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level)

        except Exception as e:
            logger.error("Failed to profile host %s: %s", hostname, e)
            # A failed flush/commit leaves the session unusable; reset it
            # before persisting the error stub. Also record client_id here,
            # matching the success path.
            await db.rollback()
            profile = HostProfile(
                hunt_id=hunt_id, hostname=hostname, fqdn=fqdn, client_id=client_id,
                risk_score=0.0, risk_level="unknown",
                llm_analysis=f"Error: {e}",
                model_used=HEAVY_MODEL, node_used="wile",
            )
            db.add(profile)
            await db.commit()
+
+
async def profile_all_hosts(hunt_id: str) -> None:
    """Discover every unique hostname in the hunt's datasets and profile each one."""
    logger.info("Starting host profiling for hunt %s", hunt_id)

    async with async_session() as db:
        datasets = (
            await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
        ).scalars().all()

        # hostname -> best-known fqdn (first occurrence wins)
        discovered: dict[str, str | None] = {}
        for ds in datasets:
            rows = (
                await db.execute(
                    select(DatasetRow).where(DatasetRow.dataset_id == ds.id).limit(2000)
                )
            ).scalars().all()
            for row in rows:
                payload = row.normalized_data or row.data
                raw = payload.get("hostname") or payload.get("Fqdn") or payload.get("Hostname")
                if raw and str(raw).strip():
                    name = str(raw).strip()
                    if name not in discovered:
                        discovered[name] = payload.get("fqdn") or payload.get("Fqdn")

        logger.info("Discovered %d unique hosts in hunt %s", len(discovered), hunt_id)

        sem = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY)

        async def _bounded(name: str, fqdn: str | None):
            # Limit how many heavy-model profiles run concurrently.
            async with sem:
                await profile_host(hunt_id, name, fqdn)

        await asyncio.gather(
            *(_bounded(n, f) for n, f in discovered.items()),
            return_exceptions=True,
        )
        logger.info("Host profiling complete for hunt %s (%d hosts)", hunt_id, len(discovered))
\ No newline at end of file
diff --git a/backend/app/services/ioc_extractor.py b/backend/app/services/ioc_extractor.py
new file mode 100644
index 0000000..889bfa7
--- /dev/null
+++ b/backend/app/services/ioc_extractor.py
@@ -0,0 +1,210 @@
+"""IOC extraction service - extract indicators of compromise from dataset rows.
+
+Identifies: IPv4/IPv6 addresses, domain names, MD5/SHA1/SHA256 hashes,
+email addresses, URLs, and file paths that look suspicious.
+"""
+
+import re
+import logging
+from collections import defaultdict
+from typing import Optional
+
+from sqlalchemy import select, func
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db.models import Dataset, DatasetRow
+
+logger = logging.getLogger(__name__)
+
+# Patterns
+
# Dotted-quad IPv4 with strict 0-255 octets.
_IPV4 = re.compile(
    r'\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b'
)
# Full (uncompressed) IPv6 only - eight hextets; '::'-compressed forms are not matched.
_IPV6 = re.compile(r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b')
# Domains limited to an explicit TLD allow-list (common TLDs plus abuse-prone ones
# like .tk/.top/.xyz) to keep false positives down.
_DOMAIN = re.compile(
    r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)'
    r'+(?:com|net|org|io|info|biz|co|us|uk|de|ru|cn|cc|tk|xyz|top|'
    r'online|site|club|win|work|download|stream|gdn|bid|review|racing|'
    r'loan|date|faith|accountant|cricket|science|trade|party|men)\b',
    re.IGNORECASE,
)
# Hash lengths disambiguate the digest type; \b guards stop partial matches.
_MD5 = re.compile(r'\b[0-9a-fA-F]{32}\b')
_SHA1 = re.compile(r'\b[0-9a-fA-F]{40}\b')
_SHA256 = re.compile(r'\b[0-9a-fA-F]{64}\b')
_EMAIL = re.compile(r'\b[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}\b')
_URL = re.compile(r'https?://[^\s<>"\']+', re.IGNORECASE)

# Private / reserved IPs to skip.
# Prefix test only: RFC1918 (10/8, 172.16-31, 192.168/16), loopback, 0/8, 255-leading.
_PRIVATE_NETS = re.compile(
    r'^(10\.|172\.(1[6-9]|2\d|3[01])\.|192\.168\.|127\.|0\.|255\.)'
)

# IOC type name -> compiled pattern; iterated by extract_iocs_from_text.
PATTERNS = {
    'ipv4': _IPV4,
    'ipv6': _IPV6,
    'domain': _DOMAIN,
    'md5': _MD5,
    'sha1': _SHA1,
    'sha256': _SHA256,
    'email': _EMAIL,
    'url': _URL,
}
+
+
+def _is_private_ip(ip: str) -> bool:
+ return bool(_PRIVATE_NETS.match(ip))
+
+
def extract_iocs_from_text(text: str, skip_private: bool = True) -> dict[str, set[str]]:
    """Extract all IOC types from a block of text."""
    found: dict[str, set[str]] = defaultdict(set)
    for kind, rx in PATTERNS.items():
        for raw in rx.findall(text):
            candidate = raw.strip()
            # URLs keep their case; everything else is normalized to lowercase.
            if kind != 'url':
                candidate = candidate.lower()
            # Optionally drop private/reserved IPv4 addresses.
            if kind == 'ipv4' and skip_private and _is_private_ip(candidate):
                continue
            found[kind].add(candidate)
    return found
+
+
async def extract_iocs_from_dataset(
    dataset_id: str,
    db: AsyncSession,
    max_rows: int = 5000,
    skip_private: bool = True,
) -> dict[str, list[str]]:
    """Extract IOCs from up to ``max_rows`` rows of a dataset.

    Rows are read in batches ordered by row_index; every value of each row
    is flattened into one string and matched against all IOC patterns.

    Returns {ioc_type: [sorted unique values]}.
    """
    all_iocs: dict[str, set[str]] = defaultdict(set)
    offset = 0
    batch_size = 500

    while offset < max_rows:
        # Cap the final batch so we never read past the caller's row budget
        # (the old fixed limit could overshoot when max_rows wasn't a
        # multiple of batch_size).
        limit = min(batch_size, max_rows - offset)
        result = await db.execute(
            select(DatasetRow.data)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(offset)
            .limit(limit)
        )
        rows = result.scalars().all()
        if not rows:
            break

        for data in rows:
            # Flatten all values to a single string for scanning
            text = ' '.join(str(v) for v in data.values()) if isinstance(data, dict) else str(data)
            batch_iocs = extract_iocs_from_text(text, skip_private)
            for ioc_type, values in batch_iocs.items():
                all_iocs[ioc_type].update(values)

        # Advance by what we actually received, not the nominal batch size.
        offset += len(rows)

    # Convert sets to sorted lists; drop empty categories
    return {k: sorted(v) for k, v in all_iocs.items() if v}
+
+
async def extract_host_groups(
    hunt_id: str,
    db: AsyncSession,
) -> list[dict]:
    """Group all data by hostname across datasets in a hunt.

    Returns a list of host group dicts with dataset count, total rows,
    artifact types, and time range, sorted by total_rows descending.
    """
    # Get all datasets for this hunt
    result = await db.execute(
        select(Dataset).where(Dataset.hunt_id == hunt_id)
    )
    ds_list = result.scalars().all()
    if not ds_list:
        return []

    # Known host columns (check normalized data first, then raw)
    HOST_COLS = [
        'hostname', 'host', 'computer_name', 'computername', 'system',
        'machine', 'device_name', 'devicename', 'endpoint',
        'ClientId', 'Fqdn', 'client_id', 'fqdn',
    ]

    hosts: dict[str, dict] = {}

    for ds in ds_list:
        # Sample first few rows to find host column
        sample_result = await db.execute(
            select(DatasetRow.data, DatasetRow.normalized_data)
            .where(DatasetRow.dataset_id == ds.id)
            .limit(5)
        )
        samples = sample_result.all()
        if not samples:
            continue

        # Find which host column exists; first HOST_COLS entry with a
        # non-empty value in any sample row wins for the whole dataset.
        host_col = None
        for row_data, norm_data in samples:
            check = norm_data if norm_data else row_data
            if not isinstance(check, dict):
                continue
            for col in HOST_COLS:
                if col in check and check[col]:
                    host_col = col
                    break
            if host_col:
                break

        if not host_col:
            # No recognizable host column - skip this dataset entirely.
            continue

        # Count rows per host in this dataset
        # NOTE(review): loads every row of the dataset into memory; may need
        # batching for very large datasets - confirm expected sizes.
        all_rows_result = await db.execute(
            select(DatasetRow.data, DatasetRow.normalized_data)
            .where(DatasetRow.dataset_id == ds.id)
        )
        all_rows = all_rows_result.all()
        for row_data, norm_data in all_rows:
            check = norm_data if norm_data else row_data
            if not isinstance(check, dict):
                continue
            host_val = check.get(host_col, '')
            # Non-string host values (e.g. numeric IDs) are ignored.
            if not host_val or not isinstance(host_val, str):
                continue
            host_val = host_val.strip()
            if not host_val:
                continue

            if host_val not in hosts:
                hosts[host_val] = {
                    'hostname': host_val,
                    'dataset_ids': set(),
                    'total_rows': 0,
                    'artifact_types': set(),
                    'first_seen': None,
                    'last_seen': None,
                }
            hosts[host_val]['dataset_ids'].add(ds.id)
            hosts[host_val]['total_rows'] += 1
            if ds.artifact_type:
                hosts[host_val]['artifact_types'].add(ds.artifact_type)

    # Convert to output format, busiest hosts first
    result_list = []
    for h in sorted(hosts.values(), key=lambda x: x['total_rows'], reverse=True):
        result_list.append({
            'hostname': h['hostname'],
            'dataset_count': len(h['dataset_ids']),
            'total_rows': h['total_rows'],
            'artifact_types': sorted(h['artifact_types']),
            'first_seen': None,  # TODO: extract from timestamp columns
            'last_seen': None,
            'risk_score': None,  # TODO: link to host profiles
        })

    return result_list
\ No newline at end of file
diff --git a/backend/app/services/job_queue.py b/backend/app/services/job_queue.py
new file mode 100644
index 0000000..f218cdf
--- /dev/null
+++ b/backend/app/services/job_queue.py
@@ -0,0 +1,316 @@
+"""Async job queue for background AI tasks.
+
+Manages triage, profiling, report generation, anomaly detection,
+and data queries as trackable jobs with status, progress, and
+cancellation support.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import time
+import uuid
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Callable, Coroutine, Optional
+
+logger = logging.getLogger(__name__)
+
+
class JobStatus(str, Enum):
    """Lifecycle states of a background job."""
    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
+
+
class JobType(str, Enum):
    """Kinds of background work the queue knows how to dispatch."""
    TRIAGE = "triage"
    HOST_PROFILE = "host_profile"
    REPORT = "report"
    ANOMALY = "anomaly"
    QUERY = "query"
+
+
@dataclass
class Job:
    """A single tracked background task with progress and cancellation state."""

    id: str
    job_type: JobType
    status: JobStatus = JobStatus.QUEUED
    progress: float = 0.0  # 0-100
    message: str = ""
    result: Any = None
    error: str | None = None
    created_at: float = field(default_factory=time.time)
    started_at: float | None = None
    completed_at: float | None = None
    params: dict = field(default_factory=dict)
    # Per-job cancellation flag; excluded from repr to keep logs readable.
    _cancel_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)

    @property
    def elapsed_ms(self) -> int:
        # Wall-clock run time in ms: queue time counts until the job starts,
        # and "now" is used while the job is still running.
        end = self.completed_at or time.time()
        start = self.started_at or self.created_at
        return int((end - start) * 1000)

    def to_dict(self) -> dict:
        """JSON-safe snapshot of the job for API responses."""
        return {
            "id": self.id,
            "job_type": self.job_type.value,
            "status": self.status.value,
            "progress": round(self.progress, 1),
            "message": self.message,
            "error": self.error,
            "created_at": self.created_at,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "elapsed_ms": self.elapsed_ms,
            "params": self.params,
        }

    @property
    def is_cancelled(self) -> bool:
        # Handlers should poll this and return early when set.
        return self._cancel_event.is_set()

    def cancel(self):
        """Mark the job cancelled; running handlers must observe is_cancelled."""
        self._cancel_event.set()
        self.status = JobStatus.CANCELLED
        self.completed_at = time.time()
        self.message = "Cancelled by user"
+
+
class JobQueue:
    """In-memory async job queue with concurrency control.

    Jobs are tracked by ID and can be listed, polled, or cancelled.
    A configurable number of workers process jobs from the queue.
    All state is only touched from the event loop thread, so no locking
    is used around the plain dict/list containers.
    """

    def __init__(self, max_workers: int = 3):
        self._jobs: dict[str, Job] = {}
        self._queue: asyncio.Queue[str] = asyncio.Queue()
        self._max_workers = max_workers
        self._workers: list[asyncio.Task] = []
        self._handlers: dict[JobType, Callable] = {}
        self._started = False

    def register_handler(
        self,
        job_type: JobType,
        handler: Callable[[Job], Coroutine],
    ):
        """Register an async handler for a job type.

        Handler signature: async def handler(job: Job) -> Any
        The handler can update job.progress and job.message during execution.
        It should check job.is_cancelled periodically and return early.
        """
        self._handlers[job_type] = handler
        logger.info(f"Registered handler for {job_type.value}")

    async def start(self):
        """Start worker tasks. Idempotent: a second call is a no-op."""
        if self._started:
            return
        self._started = True
        for i in range(self._max_workers):
            task = asyncio.create_task(self._worker(i))
            self._workers.append(task)
        logger.info(f"Job queue started with {self._max_workers} workers")

    async def stop(self):
        """Stop all workers."""
        self._started = False
        for w in self._workers:
            w.cancel()
        # return_exceptions swallows the CancelledError each worker raises.
        await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        logger.info("Job queue stopped")

    def submit(self, job_type: JobType, **params) -> Job:
        """Submit a new job. Returns the Job object immediately."""
        job = Job(
            id=str(uuid.uuid4()),
            job_type=job_type,
            params=params,
        )
        self._jobs[job.id] = job
        # put_nowait never raises here: the queue is unbounded.
        self._queue.put_nowait(job.id)
        logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
        return job

    def get_job(self, job_id: str) -> Job | None:
        """Look up a job by ID; None if unknown (or already cleaned up)."""
        return self._jobs.get(job_id)

    def cancel_job(self, job_id: str) -> bool:
        """Request cancellation. Returns False for unknown or finished jobs."""
        job = self._jobs.get(job_id)
        if not job:
            return False
        if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
            return False
        job.cancel()
        return True

    def list_jobs(
        self,
        status: JobStatus | None = None,
        job_type: JobType | None = None,
        limit: int = 50,
    ) -> list[dict]:
        """List jobs, newest first, optionally filtered by status/type."""
        jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
        if status:
            jobs = [j for j in jobs if j.status == status]
        if job_type:
            jobs = [j for j in jobs if j.job_type == job_type]
        return [j.to_dict() for j in jobs[:limit]]

    def get_stats(self) -> dict:
        """Get queue statistics."""
        by_status = {}
        for j in self._jobs.values():
            by_status[j.status.value] = by_status.get(j.status.value, 0) + 1
        return {
            "total": len(self._jobs),
            "queued": self._queue.qsize(),
            "by_status": by_status,
            "workers": self._max_workers,
            "active_workers": sum(
                1 for j in self._jobs.values() if j.status == JobStatus.RUNNING
            ),
        }

    def cleanup(self, max_age_seconds: float = 3600):
        """Remove old completed/failed/cancelled jobs."""
        now = time.time()
        to_remove = [
            jid for jid, j in self._jobs.items()
            if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
            and (now - j.created_at) > max_age_seconds
        ]
        for jid in to_remove:
            del self._jobs[jid]
        if to_remove:
            logger.info(f"Cleaned up {len(to_remove)} old jobs")

    async def _worker(self, worker_id: int):
        """Worker loop: pull jobs from queue and execute handlers."""
        logger.info(f"Worker {worker_id} started")
        while self._started:
            try:
                # The timeout lets the loop re-check self._started periodically.
                job_id = await asyncio.wait_for(self._queue.get(), timeout=5.0)
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break

            job = self._jobs.get(job_id)
            if not job or job.is_cancelled:
                # Job was cancelled (or cleaned up) while still queued.
                continue

            handler = self._handlers.get(job.job_type)
            if not handler:
                job.status = JobStatus.FAILED
                job.error = f"No handler for {job.job_type.value}"
                job.completed_at = time.time()
                logger.error(f"No handler for job type {job.job_type.value}")
                continue

            job.status = JobStatus.RUNNING
            job.started_at = time.time()
            job.message = "Running..."
            logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")

            try:
                result = await handler(job)
                # A job cancelled mid-run keeps its CANCELLED status/result.
                if not job.is_cancelled:
                    job.status = JobStatus.COMPLETED
                    job.progress = 100.0
                    job.result = result
                    job.message = "Completed"
                    job.completed_at = time.time()
                    logger.info(
                        f"Worker {worker_id}: completed {job.id} "
                        f"in {job.elapsed_ms}ms"
                    )
            except Exception as e:
                if not job.is_cancelled:
                    job.status = JobStatus.FAILED
                    job.error = str(e)
                    job.message = f"Failed: {e}"
                    job.completed_at = time.time()
                    logger.error(
                        f"Worker {worker_id}: failed {job.id}: {e}",
                        exc_info=True,
                    )
+ )
+
+
# Singleton + job handlers

# Process-wide shared queue; API routes submit work through this instance.
job_queue = JobQueue(max_workers=3)
+
+
async def _handle_triage(job: Job):
    """Run auto-triage for the dataset named in the job params."""
    from app.services.triage import triage_dataset

    dataset_id = job.params.get("dataset_id")
    job.message = f"Triaging dataset {dataset_id}"
    outcome = await triage_dataset(dataset_id)
    count = len(outcome) if outcome else 0
    return {"count": count}
+
+
async def _handle_host_profile(job: Job):
    """Profile one host, or every host in a hunt when no hostname is given."""
    from app.services.host_profiler import profile_all_hosts, profile_host

    hunt_id = job.params.get("hunt_id")
    hostname = job.params.get("hostname")
    if not hostname:
        job.message = f"Profiling all hosts in hunt {hunt_id}"
        await profile_all_hosts(hunt_id)
        return {"hunt_id": hunt_id}
    job.message = f"Profiling host {hostname}"
    await profile_host(hunt_id, hostname)
    return {"hostname": hostname}
+
+
async def _handle_report(job: Job):
    """Generate the debate-style hunt report."""
    from app.services.report_generator import generate_report

    hunt_id = job.params.get("hunt_id")
    job.message = f"Generating report for hunt {hunt_id}"
    report = await generate_report(hunt_id)
    report_id = report.id if report else None
    return {"report_id": report_id}
+
+
async def _handle_anomaly(job: Job):
    """Run statistical anomaly detection over a dataset."""
    from app.services.anomaly_detector import detect_anomalies

    params = job.params
    dataset_id = params.get("dataset_id")
    job.message = f"Detecting anomalies in dataset {dataset_id}"
    hits = await detect_anomalies(
        dataset_id,
        k=params.get("k", 3),
        outlier_threshold=params.get("threshold", 0.35),
    )
    return {"count": len(hits) if hits else 0}
+
+
async def _handle_query(job: Job):
    """Answer a natural-language question about a dataset (non-streaming)."""
    from app.services.data_query import query_dataset

    dataset_id = job.params.get("dataset_id")
    job.message = f"Querying dataset {dataset_id}"
    answer = await query_dataset(
        dataset_id,
        job.params.get("question", ""),
        job.params.get("mode", "quick"),
    )
    return {"answer": answer}
+
+
def register_all_handlers():
    """Wire every job type to its handler on the shared queue."""
    handlers = {
        JobType.TRIAGE: _handle_triage,
        JobType.HOST_PROFILE: _handle_host_profile,
        JobType.REPORT: _handle_report,
        JobType.ANOMALY: _handle_anomaly,
        JobType.QUERY: _handle_query,
    }
    for job_type, handler in handlers.items():
        job_queue.register_handler(job_type, handler)
\ No newline at end of file
diff --git a/backend/app/services/load_balancer.py b/backend/app/services/load_balancer.py
new file mode 100644
index 0000000..3ccc57c
--- /dev/null
+++ b/backend/app/services/load_balancer.py
@@ -0,0 +1,193 @@
+"""Smart load balancer for Wile & Roadrunner LLM nodes.
+
+Tracks active jobs per node, health status, and routes new work
+to the least-busy healthy node. Periodically pings both nodes
+to maintain an up-to-date health map.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+import time
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Optional
+
+from app.config import settings
+
+logger = logging.getLogger(__name__)
+
+
class NodeId(str, Enum):
    """Identifiers for the two physical inference nodes."""
    WILE = "wile"
    ROADRUNNER = "roadrunner"
+
+
class WorkloadTier(str, Enum):
    """What kind of workload is this? Determines which nodes may serve it."""
    HEAVY = "heavy"  # 70B models, deep analysis, reports
    FAST = "fast"  # 7-14B models, triage, quick queries
    EMBEDDING = "embed"  # bge-m3 embeddings
    ANY = "any"  # no preference; any healthy node
+
+
@dataclass
class NodeStatus:
    """Mutable health and load bookkeeping for a single LLM node."""

    node_id: NodeId
    url: str
    healthy: bool = True
    last_check: float = 0.0
    active_jobs: int = 0
    total_completed: int = 0
    total_errors: int = 0
    avg_latency_ms: float = 0.0
    _latencies: list[float] = field(default_factory=list)

    def record_start(self):
        """Note that a job was dispatched to this node."""
        self.active_jobs += 1

    def record_completion(self, latency_ms: float):
        """Note a successful job and fold its latency into the rolling average."""
        self.active_jobs = max(0, self.active_jobs - 1)
        self.total_completed += 1
        window = self._latencies
        window.append(latency_ms)
        # Average is computed over at most the 50 most recent samples.
        if len(window) > 50:
            del window[:-50]
        self.avg_latency_ms = sum(window) / len(window)

    def record_error(self):
        """Note a failed job (latency not recorded)."""
        self.active_jobs = max(0, self.active_jobs - 1)
        self.total_errors += 1
+
+
class LoadBalancer:
    """Routes LLM work to the least-busy healthy node.

    Node capabilities:
    - Wile: Heavy models (70B), code models (32B)
    - Roadrunner: Fast models (7-14B), embeddings (bge-m3), vision
    """

    # Which nodes can handle which tiers; list order is the fallback order
    # when no node is healthy.
    TIER_NODES = {
        WorkloadTier.HEAVY: [NodeId.WILE],
        WorkloadTier.FAST: [NodeId.ROADRUNNER, NodeId.WILE],
        WorkloadTier.EMBEDDING: [NodeId.ROADRUNNER],
        WorkloadTier.ANY: [NodeId.ROADRUNNER, NodeId.WILE],
    }

    def __init__(self):
        # Both nodes start optimistically healthy until the first check.
        self._nodes: dict[NodeId, NodeStatus] = {
            NodeId.WILE: NodeStatus(
                node_id=NodeId.WILE,
                url=f"http://{settings.WILE_HOST}:{settings.WILE_OLLAMA_PORT}",
            ),
            NodeId.ROADRUNNER: NodeStatus(
                node_id=NodeId.ROADRUNNER,
                url=f"http://{settings.ROADRUNNER_HOST}:{settings.ROADRUNNER_OLLAMA_PORT}",
            ),
        }
        # NOTE(review): this lock is currently never acquired anywhere in the
        # class - confirm whether it can be removed or should guard _nodes.
        self._lock = asyncio.Lock()
        self._health_task: Optional[asyncio.Task] = None

    async def start_health_loop(self, interval: float = 30.0):
        """Start background health-check loop. Idempotent while running."""
        if self._health_task and not self._health_task.done():
            return
        self._health_task = asyncio.create_task(self._health_loop(interval))
        logger.info("Load balancer health loop started (%.0fs interval)", interval)

    async def stop_health_loop(self):
        """Cancel the health loop and wait for it to finish."""
        if self._health_task:
            self._health_task.cancel()
            try:
                await self._health_task
            except asyncio.CancelledError:
                pass
            self._health_task = None

    async def _health_loop(self, interval: float):
        # Runs until cancelled; individual check failures are logged, not fatal.
        while True:
            try:
                await self.check_health()
            except Exception as e:
                logger.warning(f"Health check error: {e}")
            await asyncio.sleep(interval)

    async def check_health(self):
        """Ping both nodes and update status."""
        import httpx
        async with httpx.AsyncClient(timeout=5) as client:
            for nid, status in self._nodes.items():
                try:
                    # /api/tags is a cheap Ollama endpoint used as a liveness probe.
                    resp = await client.get(f"{status.url}/api/tags")
                    status.healthy = resp.status_code == 200
                except Exception:
                    status.healthy = False
                status.last_check = time.time()
                logger.debug(
                    f"Health: {nid.value} = {'OK' if status.healthy else 'DOWN'} "
                    f"(active={status.active_jobs})"
                )

    def select_node(self, tier: WorkloadTier = WorkloadTier.ANY) -> NodeId:
        """Select the best node for a workload tier.

        Strategy: among healthy nodes that support the tier,
        pick the one with fewest active jobs.
        Falls back to any node if none healthy.
        """
        candidates = self.TIER_NODES.get(tier, [NodeId.ROADRUNNER, NodeId.WILE])

        # Filter to healthy candidates
        healthy = [
            nid for nid in candidates
            if self._nodes[nid].healthy
        ]

        if not healthy:
            # Better to try a possibly-down node than to fail outright.
            logger.warning(f"No healthy nodes for tier {tier.value}, using first candidate")
            healthy = candidates

        # Pick least busy
        best = min(healthy, key=lambda nid: self._nodes[nid].active_jobs)
        return best

    def acquire(self, tier: WorkloadTier = WorkloadTier.ANY) -> NodeId:
        """Select node and mark a job started. Pair with release()."""
        node = self.select_node(tier)
        self._nodes[node].record_start()
        logger.info(
            f"LB: dispatched {tier.value} -> {node.value} "
            f"(active={self._nodes[node].active_jobs})"
        )
        return node

    def release(self, node: NodeId, latency_ms: float = 0, error: bool = False):
        """Mark a job completed on a node."""
        status = self._nodes.get(node)
        if not status:
            return
        if error:
            status.record_error()
        else:
            status.record_completion(latency_ms)

    def get_status(self) -> dict:
        """Get current load balancer status keyed by node name."""
        return {
            nid.value: {
                "healthy": s.healthy,
                "active_jobs": s.active_jobs,
                "total_completed": s.total_completed,
                "total_errors": s.total_errors,
                "avg_latency_ms": round(s.avg_latency_ms, 1),
                "last_check": s.last_check,
            }
            for nid, s in self._nodes.items()
        }
+
+
# Process-wide singleton: one shared LoadBalancer for all services.
lb = LoadBalancer()
\ No newline at end of file
diff --git a/backend/app/services/report_generator.py b/backend/app/services/report_generator.py
new file mode 100644
index 0000000..ef0d57e
--- /dev/null
+++ b/backend/app/services/report_generator.py
@@ -0,0 +1,198 @@
+"""Report generator - debate-powered hunt report generation using Wile + Roadrunner."""
+
+from __future__ import annotations
+
+import json
+import logging
+import time
+
+import httpx
+from sqlalchemy import select
+
+from app.config import settings
+from app.db.engine import async_session
+from app.db.models import (
+ Dataset, HostProfile, HuntReport, TriageResult,
+)
+from app.services.triage import _parse_llm_response
+
+logger = logging.getLogger(__name__)
+
# Ollama-style generate endpoints for the two inference nodes.
WILE_URL = f"{settings.wile_url}/api/generate"
ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate"
# Heavy model comes from config; the fast reviewer model is pinned here.
HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M"
+
+
async def _llm_call(url: str, model: str, system: str, prompt: str, timeout: float = 300.0) -> str:
    """POST one non-streaming generate request and return the response text."""
    payload = {
        "model": model,
        "prompt": prompt,
        "system": system,
        "stream": False,
        "options": {"temperature": 0.3, "num_predict": 8192},
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(url, json=payload)
        resp.raise_for_status()
        return resp.json().get("response", "")
+
+
async def _gather_evidence(db, hunt_id: str) -> dict:
    """Assemble datasets, high-risk triage hits, and host profiles for the report prompt."""
    datasets = (
        await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
    ).scalars().all()

    dataset_summary = []
    all_triage = []
    for ds in datasets:
        art = getattr(ds, "artifact_type", "Unknown")
        dataset_summary.append({
            "name": ds.name,
            "artifact_type": art,
            "row_count": ds.row_count or 0,
        })

        # Top triage hits (risk >= 3.0) for this dataset, worst first.
        triage_rows = (
            await db.execute(
                select(TriageResult)
                .where(TriageResult.dataset_id == ds.id)
                .where(TriageResult.risk_score >= 3.0)
                .order_by(TriageResult.risk_score.desc())
                .limit(15)
            )
        ).scalars().all()
        for t in triage_rows:
            all_triage.append({
                "dataset": ds.name,
                "artifact_type": art,
                "rows": f"{t.row_start}-{t.row_end}",
                "risk_score": t.risk_score,
                "verdict": t.verdict,
                "findings": t.findings[:5] if t.findings else [],
                "indicators": t.suspicious_indicators[:5] if t.suspicious_indicators else [],
                "mitre": t.mitre_techniques or [],
            })

    profiles = (
        await db.execute(
            select(HostProfile)
            .where(HostProfile.hunt_id == hunt_id)
            .order_by(HostProfile.risk_score.desc())
        )
    ).scalars().all()
    host_summaries = [
        {
            "hostname": p.hostname,
            "risk_score": p.risk_score,
            "risk_level": p.risk_level,
            "findings": p.suspicious_findings[:5] if p.suspicious_findings else [],
            "mitre": p.mitre_techniques or [],
            "timeline": (p.timeline_summary or "")[:300],
        }
        for p in profiles
    ]

    return {
        "datasets": dataset_summary,
        "triage_findings": all_triage[:30],
        "host_profiles": host_summaries,
        "total_datasets": len(datasets),
        "total_rows": sum(d["row_count"] for d in dataset_summary),
        "high_risk_hosts": len([h for h in host_summaries if h["risk_score"] >= 7.0]),
    }
+
+
async def generate_report(hunt_id: str) -> HuntReport:
    """Generate a three-phase debate report for a hunt and return the record.

    Phase 1: Wile (heavy model) drafts the assessment.
    Phase 2: Roadrunner (fast model) critiques the draft.
    Phase 3: Wile synthesizes a final structured JSON report.

    The HuntReport row is created up-front with status="generating" and
    updated to "complete" or "error". The row is returned so callers (e.g.
    the job-queue handler, which reads `report.id`) can reference it - the
    old `-> None` signature made that handler always report a null id.
    """
    logger.info("Generating report for hunt %s", hunt_id)
    start = time.monotonic()

    async with async_session() as db:
        report = HuntReport(
            hunt_id=hunt_id,
            status="generating",
            models_used=[HEAVY_MODEL, FAST_MODEL],
        )
        db.add(report)
        await db.commit()
        await db.refresh(report)
        report_id = report.id

        try:
            evidence = await _gather_evidence(db, hunt_id)
            evidence_text = json.dumps(evidence, indent=1, default=str)[:12000]

            # Phase 1: Wile initial analysis
            logger.info("Report phase 1: Wile initial analysis")
            phase1 = await _llm_call(
                WILE_URL, HEAVY_MODEL,
                system=(
                    "You are a senior threat intelligence analyst writing a hunt report.\n"
                    "Analyze all evidence and produce a structured threat assessment.\n"
                    "Include: executive summary, detailed findings per host, MITRE mapping,\n"
                    "IOC table, risk rankings, and actionable recommendations.\n"
                    "Use markdown formatting. Be thorough and specific."
                ),
                prompt=f"Hunt evidence:\n{evidence_text}\n\nProduce your initial threat assessment.",
            )

            # Phase 2: Roadrunner critical review
            logger.info("Report phase 2: Roadrunner critical review")
            phase2 = await _llm_call(
                ROADRUNNER_URL, FAST_MODEL,
                system=(
                    "You are a critical reviewer of threat hunt reports.\n"
                    "Review the initial assessment and identify:\n"
                    "- Missing correlations or overlooked indicators\n"
                    "- False positive risks or overblown findings\n"
                    "- Additional MITRE techniques that should be mapped\n"
                    "- Gaps in recommendations\n"
                    "Be specific and constructive. Respond in markdown."
                ),
                prompt=f"Evidence:\n{evidence_text[:4000]}\n\nInitial Assessment:\n{phase1[:6000]}\n\nProvide your critical review.",
                timeout=120.0,
            )

            # Phase 3: Wile final synthesis
            logger.info("Report phase 3: Wile final synthesis")
            synthesis_prompt = (
                f"Original evidence:\n{evidence_text[:6000]}\n\n"
                f"Initial assessment:\n{phase1[:5000]}\n\n"
                f"Critical review:\n{phase2[:3000]}\n\n"
                "Produce the FINAL hunt report incorporating the review feedback.\n"
                "Return JSON with these keys:\n"
                "- executive_summary: 2-3 paragraph executive summary\n"
                "- findings: list of {title, severity, description, evidence, mitre_ids}\n"
                "- recommendations: list of {priority, action, rationale}\n"
                "- mitre_mapping: dict of technique_id -> {name, description, evidence}\n"
                "- ioc_table: list of {type, value, context, confidence}\n"
                "- host_risk_summary: list of {hostname, risk_score, risk_level, key_findings}\n"
                "Respond with valid JSON only."
            )
            phase3_text = await _llm_call(
                WILE_URL, HEAVY_MODEL,
                system="You are producing the final, definitive threat hunt report. Incorporate all feedback. Respond with valid JSON only.",
                prompt=synthesis_prompt,
            )

            parsed = _parse_llm_response(phase3_text)
            elapsed_ms = int((time.monotonic() - start) * 1000)

            full_report = f"# Threat Hunt Report\n\n{phase1}\n\n---\n## Review Notes\n{phase2}\n\n---\n## Final Synthesis\n{phase3_text}"

            report.status = "complete"
            report.exec_summary = parsed.get("executive_summary", phase1[:2000])
            report.full_report = full_report
            report.findings = parsed.get("findings", [])
            report.recommendations = parsed.get("recommendations", [])
            report.mitre_mapping = parsed.get("mitre_mapping", {})
            report.ioc_table = parsed.get("ioc_table", [])
            report.host_risk_summary = parsed.get("host_risk_summary", [])
            report.generation_time_ms = elapsed_ms
            await db.commit()

            logger.info("Report %s complete in %dms", report_id, elapsed_ms)
            return report

        except Exception as e:
            logger.error("Report generation failed for hunt %s: %s", hunt_id, e)
            # If the failure happened during a flush/commit the session is
            # unusable until rolled back; the report row itself was already
            # persisted above, so it survives the rollback.
            await db.rollback()
            report.status = "error"
            report.exec_summary = f"Report generation failed: {e}"
            report.generation_time_ms = int((time.monotonic() - start) * 1000)
            await db.commit()
            return report
\ No newline at end of file
diff --git a/backend/app/services/triage.py b/backend/app/services/triage.py
new file mode 100644
index 0000000..09be225
--- /dev/null
+++ b/backend/app/services/triage.py
@@ -0,0 +1,170 @@
+"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""
+
+from __future__ import annotations
+
+import json
+import logging
+import re
+
+import httpx
+from sqlalchemy import func, select
+
+from app.config import settings
+from app.db.engine import async_session
+from app.db.models import Dataset, DatasetRow, TriageResult
+
+logger = logging.getLogger(__name__)
+
# Fast, quantized model used for first-pass triage (speed over depth).
DEFAULT_FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M"
# Text-generation endpoint on the Roadrunner inference node.
ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate"

# Per-artifact analyst guidance injected into the triage system prompt.
# Keys match Dataset.artifact_type values; unknown types fall back to a
# generic "analyze for anything suspicious" instruction in triage_dataset().
ARTIFACT_FOCUS = {
    "Windows.System.Pslist": "Look for: suspicious parent-child, LOLBins, unsigned, injection indicators, abnormal paths.",
    "Windows.Network.Netstat": "Look for: C2 beaconing, unusual ports, connections to rare IPs, non-browser high-port listeners.",
    "Windows.System.Services": "Look for: services in temp dirs, misspelled names, unsigned ServiceDll, unusual start modes.",
    "Windows.Forensics.Prefetch": "Look for: recon tools, lateral movement tools, rarely-run executables with high run counts.",
    "Windows.EventLogs.EvtxHunter": "Look for: logon type 10/3 anomalies, service installs, PowerShell script blocks, clearing.",
    "Windows.Sys.Autoruns": "Look for: recently added entries, entries in temp/user dirs, encoded commands, suspicious DLLs.",
    "Windows.Registry.Finder": "Look for: run keys, image file execution options, hidden services, encoded payloads.",
    "Windows.Search.FileFinder": "Look for: files in unusual locations, recently modified system files, known tool names.",
}
+
+
+def _parse_llm_response(text: str) -> dict:
+ text = text.strip()
+ fence = re.search(r"`(?:json)?\s*\n?(.*?)\n?\s*`", text, re.DOTALL)
+ if fence:
+ text = fence.group(1).strip()
+ try:
+ return json.loads(text)
+ except json.JSONDecodeError:
+ brace = text.find("{")
+ bracket = text.rfind("}")
+ if brace != -1 and bracket != -1 and bracket > brace:
+ try:
+ return json.loads(text[brace : bracket + 1])
+ except json.JSONDecodeError:
+ pass
+ return {"raw_response": text[:3000]}
+
+
async def triage_dataset(dataset_id: str) -> None:
    """Run fast first-pass LLM triage over a dataset, batch by batch.

    Pages through DatasetRow records in row_number order, sends each batch
    to the Roadrunner node with artifact-specific guidance, and persists one
    TriageResult per batch. A failed batch is recorded with verdict "error"
    rather than aborting the run. Commits after every batch so partial
    progress survives a crash or restart.
    """
    logger.info("Starting triage for dataset %s", dataset_id)

    async with async_session() as db:
        ds_result = await db.execute(
            select(Dataset).where(Dataset.id == dataset_id)
        )
        dataset = ds_result.scalar_one_or_none()
        if not dataset:
            logger.error("Dataset %s not found", dataset_id)
            return

        # Artifact-specific prompt guidance; generic fallback for unknown types.
        artifact_type = getattr(dataset, "artifact_type", None) or "Unknown"
        focus = ARTIFACT_FOCUS.get(artifact_type, "Analyze for any suspicious indicators.")

        count_result = await db.execute(
            select(func.count()).where(DatasetRow.dataset_id == dataset_id)
        )
        total_rows = count_result.scalar() or 0

        batch_size = settings.TRIAGE_BATCH_SIZE
        suspicious_count = 0  # rows in batches scoring at/above the escalation threshold
        offset = 0

        while offset < total_rows:
            # Stop early once the suspicious-row cap is hit to bound LLM cost.
            if suspicious_count >= settings.TRIAGE_MAX_SUSPICIOUS_ROWS:
                logger.info("Reached suspicious row cap for dataset %s", dataset_id)
                break

            rows_result = await db.execute(
                select(DatasetRow)
                .where(DatasetRow.dataset_id == dataset_id)
                .order_by(DatasetRow.row_number)
                .offset(offset)
                .limit(batch_size)
            )
            rows = rows_result.scalars().all()
            if not rows:
                break

            # Compact each row for the prompt: prefer normalized data, drop
            # falsy values, truncate long fields to keep the payload small.
            batch_data = []
            for r in rows:
                data = r.normalized_data or r.data
                compact = {k: str(v)[:200] for k, v in data.items() if v}
                batch_data.append(compact)

            system_prompt = f"""You are a cybersecurity triage analyst. Analyze this batch of {artifact_type} forensic data.
{focus}

Return JSON with:
- risk_score: 0.0 (benign) to 10.0 (critical threat)
- verdict: "clean", "suspicious", "malicious", or "inconclusive"
- findings: list of key observations
- suspicious_indicators: list of specific IOCs or anomalies
- mitre_techniques: list of MITRE ATT&CK IDs if applicable

Be precise. Only flag genuinely suspicious items. Respond with valid JSON only."""

            # Hard-cap the serialized batch at 6000 chars to bound prompt size.
            prompt = f"Rows {offset+1}-{offset+len(rows)} of {total_rows}:\n{json.dumps(batch_data, default=str)[:6000]}"

            try:
                async with httpx.AsyncClient(timeout=120.0) as client:
                    resp = await client.post(
                        ROADRUNNER_URL,
                        json={
                            "model": DEFAULT_FAST_MODEL,
                            "prompt": prompt,
                            "system": system_prompt,
                            "stream": False,
                            "options": {"temperature": 0.2, "num_predict": 2048},
                        },
                    )
                    resp.raise_for_status()
                    result = resp.json()
                    llm_text = result.get("response", "")

                parsed = _parse_llm_response(llm_text)
                # A malformed risk_score raises here and is caught below as a batch error.
                risk = float(parsed.get("risk_score", 0.0))

                triage = TriageResult(
                    dataset_id=dataset_id,
                    row_start=offset,
                    row_end=offset + len(rows) - 1,
                    risk_score=risk,
                    verdict=parsed.get("verdict", "inconclusive"),
                    findings=parsed.get("findings", []),
                    suspicious_indicators=parsed.get("suspicious_indicators", []),
                    mitre_techniques=parsed.get("mitre_techniques", []),
                    model_used=DEFAULT_FAST_MODEL,
                    node_used="roadrunner",
                )
                db.add(triage)
                await db.commit()

                # NOTE(review): the WHOLE batch counts toward the cap, not just
                # the individually flagged rows -- confirm this is intended.
                if risk >= settings.TRIAGE_ESCALATION_THRESHOLD:
                    suspicious_count += len(rows)

                logger.debug(
                    "Triage batch %d-%d: risk=%.1f verdict=%s",
                    offset, offset + len(rows) - 1, risk, triage.verdict,
                )

            except Exception as e:
                # Record the failure as a TriageResult so the gap is visible,
                # then continue with the next batch.
                logger.error("Triage batch %d failed: %s", offset, e)
                triage = TriageResult(
                    dataset_id=dataset_id,
                    row_start=offset,
                    row_end=offset + len(rows) - 1,
                    risk_score=0.0,
                    verdict="error",
                    findings=[f"Error: {e}"],
                    model_used=DEFAULT_FAST_MODEL,
                    node_used="roadrunner",
                )
                db.add(triage)
                await db.commit()

            offset += batch_size

    logger.info("Triage complete for dataset %s", dataset_id)
\ No newline at end of file
diff --git a/backend/run.py b/backend/run.py
index c8fa0b6..1bb8002 100644
--- a/backend/run.py
+++ b/backend/run.py
@@ -13,5 +13,5 @@ if __name__ == "__main__":
"app.main:app",
host="0.0.0.0",
port=8000,
- reload=True,
- )
+ reload=False,
+ )
\ No newline at end of file
diff --git a/frontend/nginx.conf b/frontend/nginx.conf
index 959445f..432f927 100644
--- a/frontend/nginx.conf
+++ b/frontend/nginx.conf
@@ -15,10 +15,10 @@ server {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
- proxy_read_timeout 120s;
+ proxy_read_timeout 300s;
}
- # SPA fallback — serve index.html for all non-file routes
+ # SPA fallback: serve index.html for all non-file routes
location / {
try_files $uri $uri/ /index.html;
}
@@ -28,4 +28,4 @@ server {
expires 1y;
add_header Cache-Control "public, immutable";
}
-}
+}
\ No newline at end of file
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index 4ec96b1..bf3df36 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -1,5 +1,5 @@
/**
- * ThreatHunt — MUI-powered analyst-assist platform.
+ * ThreatHunt -- MUI-powered analyst-assist platform.
*/
import React, { useState, useCallback } from 'react';
@@ -18,6 +18,7 @@ import ScienceIcon from '@mui/icons-material/Science';
import CompareArrowsIcon from '@mui/icons-material/CompareArrows';
import GppMaybeIcon from '@mui/icons-material/GppMaybe';
import HubIcon from '@mui/icons-material/Hub';
+import AssessmentIcon from '@mui/icons-material/Assessment';
import { SnackbarProvider } from 'notistack';
import theme from './theme';
@@ -32,6 +33,7 @@ import HypothesisTracker from './components/HypothesisTracker';
import CorrelationView from './components/CorrelationView';
import AUPScanner from './components/AUPScanner';
import NetworkMap from './components/NetworkMap';
+import AnalysisDashboard from './components/AnalysisDashboard';
const DRAWER_WIDTH = 240;
@@ -42,13 +44,14 @@ const NAV: NavItem[] = [
{ label: 'Hunts', path: '/hunts', icon: },
{ label: 'Datasets', path: '/datasets', icon: },
{ label: 'Upload', path: '/upload', icon: },
+ { label: 'AI Analysis', path: '/analysis', icon: },
{ label: 'Agent', path: '/agent', icon: },
{ label: 'Enrichment', path: '/enrichment', icon: },
{ label: 'Annotations', path: '/annotations', icon: },
{ label: 'Hypotheses', path: '/hypotheses', icon: },
{ label: 'Correlation', path: '/correlation', icon: },
- { label: 'Network Map', path: '/network', icon: },
- { label: 'AUP Scanner', path: '/aup', icon: },
+ { label: 'Network Map', path: '/network', icon: },
+ { label: 'AUP Scanner', path: '/aup', icon: },
];
function Shell() {
@@ -109,6 +112,7 @@ function Shell() {
} />
} />
} />
+ } />
} />
} />
} />
@@ -135,4 +139,4 @@ function App() {
);
}
-export default App;
+export default App;
\ No newline at end of file
diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts
index d83cabd..37476ac 100644
--- a/frontend/src/api/client.ts
+++ b/frontend/src/api/client.ts
@@ -1,11 +1,11 @@
/* ====================================================================
- ThreatHunt API Client — mirrors every backend endpoint.
+ ThreatHunt API Client -- mirrors every backend endpoint.
All requests go through the CRA proxy (see package.json "proxy").
==================================================================== */
const BASE = ''; // proxied to http://localhost:8000 by CRA
-// ── Helpers ──────────────────────────────────────────────────────────
+// -- Helpers --
let authToken: string | null = localStorage.getItem('th_token');
@@ -36,7 +36,7 @@ async function api(
return res.text() as unknown as T;
}
-// ── Auth ─────────────────────────────────────────────────────────────
+// -- Auth --
export interface UserPayload {
id: string; username: string; email: string;
@@ -63,7 +63,7 @@ export const auth = {
me: () => api('/api/auth/me'),
};
-// ── Hunts ────────────────────────────────────────────────────────────
+// -- Hunts --
export interface Hunt {
id: string; name: string; description: string | null; status: string;
@@ -82,7 +82,7 @@ export const hunts = {
delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),
};
-// ── Datasets ─────────────────────────────────────────────────────────
+// -- Datasets --
export interface DatasetSummary {
id: string; name: string; filename: string; source_tool: string | null;
@@ -92,6 +92,8 @@ export interface DatasetSummary {
file_size_bytes: number; encoding: string | null; delimiter: string | null;
time_range_start: string | null; time_range_end: string | null;
hunt_id: string | null; created_at: string;
+ processing_status?: string; artifact_type?: string | null;
+ error_message?: string | null; file_path?: string | null;
}
export interface UploadResult {
@@ -155,7 +157,7 @@ export const datasets = {
delete: (id: string) => api(`/api/datasets/${id}`, { method: 'DELETE' }),
};
-// ── Agent ────────────────────────────────────────────────────────────
+// -- Agent --
export interface AssistRequest {
query: string;
@@ -198,7 +200,7 @@ export const agent = {
},
};
-// ── Annotations ──────────────────────────────────────────────────────
+// -- Annotations --
export interface AnnotationData {
id: string; row_id: number | null; dataset_id: string | null;
@@ -224,7 +226,7 @@ export const annotations = {
delete: (id: string) => api(`/api/annotations/${id}`, { method: 'DELETE' }),
};
-// ── Hypotheses ───────────────────────────────────────────────────────
+// -- Hypotheses --
export interface HypothesisData {
id: string; hunt_id: string | null; title: string; description: string | null;
@@ -249,7 +251,7 @@ export const hypotheses = {
delete: (id: string) => api(`/api/hypotheses/${id}`, { method: 'DELETE' }),
};
-// ── Enrichment ───────────────────────────────────────────────────────
+// -- Enrichment --
export interface EnrichmentResult {
ioc_value: string; ioc_type: string; source: string; verdict: string;
@@ -274,7 +276,7 @@ export const enrichment = {
status: () => api>('/api/enrichment/status'),
};
-// ── Correlation ──────────────────────────────────────────────────────
+// -- Correlation --
export interface CorrelationResult {
hunt_ids: string[]; summary: string; total_correlations: number;
@@ -292,7 +294,7 @@ export const correlation = {
api<{ ioc_value: string; occurrences: any[]; total: number }>(`/api/correlation/ioc/${encodeURIComponent(ioc_value)}`),
};
-// ── Reports ──────────────────────────────────────────────────────────
+// -- Reports --
export const reports = {
json: (huntId: string) =>
@@ -305,13 +307,13 @@ export const reports = {
api>(`/api/reports/hunt/${huntId}/summary`),
};
-// ── Root / misc ──────────────────────────────────────────────────────
+// -- Root / misc --
export const misc = {
root: () => api<{ name: string; version: string; status: string }>('/'),
};
-// ── AUP Keywords ─────────────────────────────────────────────────────
+// -- AUP Keywords --
export interface KeywordOut {
id: number; theme_id: string; value: string; is_regex: boolean; created_at: string;
@@ -368,3 +370,216 @@ export const keywords = {
quickScan: (datasetId: string) =>
api(`/api/keywords/scan/quick?dataset_id=${encodeURIComponent(datasetId)}`),
};
+
+
+// -- Analysis (Phase 2+) --
+
+export interface TriageResultData {
+ id: string; dataset_id: string; row_start: number; row_end: number;
+ risk_score: number; verdict: string;
+ findings: any[] | null; suspicious_indicators: any[] | null;
+ mitre_techniques: any[] | null;
+ model_used: string | null; node_used: string | null;
+}
+
+export interface HostProfileData {
+ id: string; hunt_id: string; hostname: string; fqdn: string | null;
+ risk_score: number; risk_level: string;
+ artifact_summary: Record | null;
+ timeline_summary: string | null;
+ suspicious_findings: any[] | null;
+ mitre_techniques: any[] | null;
+ llm_analysis: string | null;
+ model_used: string | null;
+}
+
+export interface HuntReportData {
+ id: string; hunt_id: string; status: string;
+ exec_summary: string | null; full_report: string | null;
+ findings: any[] | null; recommendations: any[] | null;
+ mitre_mapping: Record | null;
+ ioc_table: any[] | null; host_risk_summary: any[] | null;
+ models_used: any[] | null; generation_time_ms: number | null;
+}
+
+export interface AnomalyResultData {
+ id: string; dataset_id: string; row_id: number | null;
+ anomaly_score: number; distance_from_centroid: number | null;
+ cluster_id: number | null; is_outlier: boolean;
+ explanation: string | null;
+}
+
+export interface HostGroupData {
+ hostname: string;
+ dataset_count: number;
+ total_rows: number;
+ artifact_types: string[];
+ first_seen: string | null;
+ last_seen: string | null;
+ risk_score: number | null;
+}
+
+// -- Job queue types (Phase 10) --
+
+export interface JobData {
+ id: string;
+ job_type: string;
+ status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled';
+ progress: number;
+ message: string;
+ error: string | null;
+ created_at: number;
+ started_at: number | null;
+ completed_at: number | null;
+ elapsed_ms: number;
+ params: Record;
+}
+
+export interface JobStats {
+ total: number;
+ queued: number;
+ by_status: Record;
+ workers: number;
+ active_workers: number;
+}
+
+export interface LBNodeStatus {
+ healthy: boolean;
+ active_jobs: number;
+ total_completed: number;
+ total_errors: number;
+ avg_latency_ms: number;
+ last_check: number;
+}
+
+export const analysis = {
+ // Triage
+ triageResults: (datasetId: string, minRisk = 0) =>
+ api(`/api/analysis/triage/${datasetId}?min_risk=${minRisk}`),
+ triggerTriage: (datasetId: string) =>
+ api<{ status: string; dataset_id: string }>(`/api/analysis/triage/${datasetId}`, { method: 'POST' }),
+
+ // Host profiles
+ hostProfiles: (huntId: string, minRisk = 0) =>
+ api(`/api/analysis/profiles/${huntId}?min_risk=${minRisk}`),
+ triggerAllProfiles: (huntId: string) =>
+ api<{ status: string; hunt_id: string }>(`/api/analysis/profiles/${huntId}`, { method: 'POST' }),
+ triggerHostProfile: (huntId: string, hostname: string) =>
+ api<{ status: string }>(`/api/analysis/profiles/${huntId}/${encodeURIComponent(hostname)}`, { method: 'POST' }),
+
+ // Reports
+ listReports: (huntId: string) =>
+ api(`/api/analysis/reports/${huntId}`),
+ getReport: (huntId: string, reportId: string) =>
+ api(`/api/analysis/reports/${huntId}/${reportId}`),
+ generateReport: (huntId: string) =>
+ api<{ status: string; hunt_id: string }>(`/api/analysis/reports/${huntId}/generate`, { method: 'POST' }),
+
+ // Anomaly detection
+ anomalies: (datasetId: string, outliersOnly = false) =>
+ api(`/api/analysis/anomalies/${datasetId}?outliers_only=${outliersOnly}`),
+ triggerAnomalyDetection: (datasetId: string, k = 3, threshold = 0.35) =>
+ api<{ status: string; dataset_id: string }>(
+ `/api/analysis/anomalies/${datasetId}?k=${k}&threshold=${threshold}`, { method: 'POST' },
+ ),
+
+ // IOC extraction
+ extractIocs: (datasetId: string) =>
+ api<{ dataset_id: string; iocs: Record; total: number }>(
+ `/api/analysis/iocs/${datasetId}`,
+ ),
+
+ // Host grouping
+ hostGroups: (huntId: string) =>
+ api<{ hunt_id: string; hosts: HostGroupData[] }>(
+ `/api/analysis/hosts/${huntId}`,
+ ),
+
+ // Data query (Phase 9) - SSE streaming
+ queryStream: async (datasetId: string, question: string, mode: string = 'quick'): Promise => {
+ const headers: Record = { 'Content-Type': 'application/json' };
+ if (authToken) headers['Authorization'] = `Bearer ${authToken}`;
+ return fetch(`${BASE}/api/analysis/query/${datasetId}`, {
+ method: 'POST',
+ headers,
+ body: JSON.stringify({ question, mode }),
+ });
+ },
+
+ // Data query (Phase 9) - sync
+ querySync: (datasetId: string, question: string, mode: string = 'quick') =>
+ api<{ dataset_id: string; question: string; answer: string; mode: string }>(
+ `/api/analysis/query/${datasetId}/sync`, {
+ method: 'POST',
+ body: JSON.stringify({ question, mode }),
+ },
+ ),
+
+ // Job queue (Phase 10)
+ listJobs: (status?: string, jobType?: string, limit = 50) => {
+ const q = new URLSearchParams();
+ if (status) q.set('status', status);
+ if (jobType) q.set('job_type', jobType);
+ q.set('limit', String(limit));
+ return api<{ jobs: JobData[]; stats: JobStats }>(`/api/analysis/jobs?${q}`);
+ },
+ getJob: (jobId: string) =>
+ api(`/api/analysis/jobs/${jobId}`),
+ cancelJob: (jobId: string) =>
+ api<{ status: string; job_id: string }>(`/api/analysis/jobs/${jobId}`, { method: 'DELETE' }),
+ submitJob: (jobType: string, params: Record = {}) =>
+ api<{ job_id: string; status: string; job_type: string }>(
+ `/api/analysis/jobs/submit/${jobType}`, {
+ method: 'POST',
+ body: JSON.stringify(params),
+ },
+ ),
+
+ // Load balancer (Phase 10)
+ lbStatus: () =>
+ api>('/api/analysis/lb/status'),
+ lbCheck: () =>
+ api>('/api/analysis/lb/check', { method: 'POST' }),
+};
+
+// -- Network Topology --
+
+export interface InventoryHost {
+ id: string;
+ hostname: string;
+ fqdn: string;
+ client_id: string;
+ ips: string[];
+ os: string;
+ users: string[];
+ datasets: string[];
+ row_count: number;
+}
+
+export interface InventoryConnection {
+ source: string;
+ target: string;
+ target_ip: string;
+ port: string;
+ count: number;
+}
+
+export interface InventoryStats {
+ total_hosts: number;
+ total_datasets_scanned: number;
+ datasets_with_hosts: number;
+ total_rows_scanned: number;
+ hosts_with_ips: number;
+ hosts_with_users: number;
+}
+
+export interface HostInventory {
+ hosts: InventoryHost[];
+ connections: InventoryConnection[];
+ stats: InventoryStats;
+}
+
+export const network = {
+ hostInventory: (huntId: string) =>
+ api(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}`),
+};
\ No newline at end of file
diff --git a/frontend/src/components/AnalysisDashboard.tsx b/frontend/src/components/AnalysisDashboard.tsx
new file mode 100644
index 0000000..f79253d
--- /dev/null
+++ b/frontend/src/components/AnalysisDashboard.tsx
@@ -0,0 +1,818 @@
+/**
+ * AnalysisDashboard -- 6-tab view covering the full AI analysis pipeline:
+ * 1. Triage results
+ * 2. Host profiles
+ * 3. Reports
+ * 4. Anomalies
+ * 5. Ask Data (natural language query with SSE streaming) -- Phase 9
+ * 6. Jobs & Load Balancer status -- Phase 10
+ */
+
+import React, { useState, useEffect, useCallback, useRef } from 'react';
+import {
+ Box, Typography, Tabs, Tab, Paper, Button, Chip, Stack, CircularProgress,
+ Table, TableBody, TableCell, TableContainer, TableHead, TableRow,
+ Accordion, AccordionSummary, AccordionDetails, Alert, Select, MenuItem,
+ FormControl, InputLabel, LinearProgress, Tooltip, IconButton, Divider,
+ Card, CardContent, CardActions, Grid, TextField, ToggleButton,
+ ToggleButtonGroup,
+} from '@mui/material';
+import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
+import PlayArrowIcon from '@mui/icons-material/PlayArrow';
+import RefreshIcon from '@mui/icons-material/Refresh';
+import AssessmentIcon from '@mui/icons-material/Assessment';
+import SecurityIcon from '@mui/icons-material/Security';
+import PersonIcon from '@mui/icons-material/Person';
+import WarningAmberIcon from '@mui/icons-material/WarningAmber';
+import ShieldIcon from '@mui/icons-material/Shield';
+import BubbleChartIcon from '@mui/icons-material/BubbleChart';
+import QuestionAnswerIcon from '@mui/icons-material/QuestionAnswer';
+import WorkIcon from '@mui/icons-material/Work';
+import SendIcon from '@mui/icons-material/Send';
+import StopIcon from '@mui/icons-material/Stop';
+import DeleteIcon from '@mui/icons-material/Delete';
+import CheckCircleIcon from '@mui/icons-material/CheckCircle';
+import ErrorIcon from '@mui/icons-material/Error';
+import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty';
+import CancelIcon from '@mui/icons-material/Cancel';
+import { useSnackbar } from 'notistack';
+import {
+ analysis, hunts, datasets,
+ type Hunt, type DatasetSummary,
+ type TriageResultData, type HostProfileData, type HuntReportData,
+ type AnomalyResultData, type JobData, type JobStats, type LBNodeStatus,
+} from '../api/client';
+
+/* helpers */
+
+function riskColor(score: number): 'error' | 'warning' | 'info' | 'success' | 'default' {
+ if (score >= 8) return 'error';
+ if (score >= 5) return 'warning';
+ if (score >= 2) return 'info';
+ return 'success';
+}
+
+function riskLabel(level: string): 'error' | 'warning' | 'info' | 'success' | 'default' {
+ if (level === 'critical' || level === 'high') return 'error';
+ if (level === 'medium') return 'warning';
+ if (level === 'low') return 'success';
+ return 'default';
+}
+
+function fmtMs(ms: number): string {
+ if (ms < 1000) return `${ms}ms`;
+ return `${(ms / 1000).toFixed(1)}s`;
+}
+
+function fmtTime(ts: number): string {
+ if (!ts) return '--';
+ return new Date(ts * 1000).toLocaleTimeString();
+}
+
+const statusIcon = (s: string) => {
+ switch (s) {
+ case 'completed': return ;
+ case 'failed': return ;
+ case 'running': return ;
+ case 'queued': return ;
+ case 'cancelled': return ;
+ default: return null;
+ }
+};
+
+/* TabPanel */
+
+function TabPanel({ children, value, index }: { children: React.ReactNode; value: number; index: number }) {
+ return value === index ? {children} : null;
+}
+
+/* Main component */
+
+export default function AnalysisDashboard() {
+ const { enqueueSnackbar } = useSnackbar();
+ const [tab, setTab] = useState(0);
+
+ // Selectors
+ const [huntList, setHuntList] = useState([]);
+ const [dsList, setDsList] = useState([]);
+ const [huntId, setHuntId] = useState('');
+ const [dsId, setDsId] = useState('');
+
+ // Data tabs 0-3
+ const [triageResults, setTriageResults] = useState([]);
+ const [profiles, setProfiles] = useState([]);
+ const [reports, setReports] = useState([]);
+ const [anomalies, setAnomalies] = useState([]);
+
+ // Loading states
+ const [loadingTriage, setLoadingTriage] = useState(false);
+ const [loadingProfiles, setLoadingProfiles] = useState(false);
+ const [loadingReports, setLoadingReports] = useState(false);
+ const [loadingAnomalies, setLoadingAnomalies] = useState(false);
+ const [triggering, setTriggering] = useState(false);
+
+ // Phase 9: Ask Data
+ const [queryText, setQueryText] = useState('');
+ const [queryMode, setQueryMode] = useState('quick');
+ const [queryAnswer, setQueryAnswer] = useState('');
+ const [queryStreaming, setQueryStreaming] = useState(false);
+ const [queryMeta, setQueryMeta] = useState | null>(null);
+ const [queryDone, setQueryDone] = useState | null>(null);
+ const abortRef = useRef(null);
+ const answerRef = useRef(null);
+
+ // Phase 10: Jobs
+ const [jobs, setJobs] = useState([]);
+ const [jobStats, setJobStats] = useState(null);
+ const [lbStatus, setLbStatus] = useState | null>(null);
+ const [loadingJobs, setLoadingJobs] = useState(false);
+
+ // Load hunts and datasets
+ useEffect(() => {
+ hunts.list(0, 200).then(r => setHuntList(r.hunts)).catch(() => {});
+ datasets.list(0, 200).then(r => setDsList(r.datasets)).catch(() => {});
+ }, []);
+
+ useEffect(() => {
+ if (huntList.length > 0 && !huntId) setHuntId(huntList[0].id);
+ }, [huntList, huntId]);
+
+ useEffect(() => {
+ if (dsList.length > 0 && !dsId) setDsId(dsList[0].id);
+ }, [dsList, dsId]);
+
+ /* Fetch triage results */
+ const fetchTriage = useCallback(async () => {
+ if (!dsId) return;
+ setLoadingTriage(true);
+ try {
+ const data = await analysis.triageResults(dsId);
+ setTriageResults(data);
+ } catch (e: any) {
+ enqueueSnackbar(`Triage fetch failed: ${e.message}`, { variant: 'error' });
+ } finally { setLoadingTriage(false); }
+ }, [dsId, enqueueSnackbar]);
+
+ const fetchProfiles = useCallback(async () => {
+ if (!huntId) return;
+ setLoadingProfiles(true);
+ try {
+ const data = await analysis.hostProfiles(huntId);
+ setProfiles(data);
+ } catch (e: any) {
+ enqueueSnackbar(`Profiles fetch failed: ${e.message}`, { variant: 'error' });
+ } finally { setLoadingProfiles(false); }
+ }, [huntId, enqueueSnackbar]);
+
+ const fetchReports = useCallback(async () => {
+ if (!huntId) return;
+ setLoadingReports(true);
+ try {
+ const data = await analysis.listReports(huntId);
+ setReports(data);
+ } catch (e: any) {
+ enqueueSnackbar(`Reports fetch failed: ${e.message}`, { variant: 'error' });
+ } finally { setLoadingReports(false); }
+ }, [huntId, enqueueSnackbar]);
+
+ const fetchAnomalies = useCallback(async () => {
+ if (!dsId) return;
+ setLoadingAnomalies(true);
+ try {
+ const data = await analysis.anomalies(dsId);
+ setAnomalies(data);
+ } catch (e: any) {
+ enqueueSnackbar('Anomaly fetch failed: ' + e.message, { variant: 'error' });
+ } finally { setLoadingAnomalies(false); }
+ }, [dsId, enqueueSnackbar]);
+
+ const fetchJobs = useCallback(async () => {
+ setLoadingJobs(true);
+ try {
+ const data = await analysis.listJobs();
+ setJobs(data.jobs);
+ setJobStats(data.stats);
+ } catch (e: any) {
+ enqueueSnackbar('Jobs fetch failed: ' + e.message, { variant: 'error' });
+ } finally { setLoadingJobs(false); }
+ }, [enqueueSnackbar]);
+
+ const fetchLbStatus = useCallback(async () => {
+ try {
+ const data = await analysis.lbStatus();
+ setLbStatus(data);
+ } catch {}
+ }, []);
+
+ // Load data when selectors change
+ useEffect(() => { if (dsId) fetchTriage(); }, [dsId, fetchTriage]);
+ useEffect(() => { if (huntId) { fetchProfiles(); fetchReports(); } }, [huntId, fetchProfiles, fetchReports]);
+
+ // Auto-refresh jobs when on jobs tab
+ useEffect(() => {
+ if (tab === 5) {
+ fetchJobs();
+ fetchLbStatus();
+ const iv = setInterval(() => { fetchJobs(); fetchLbStatus(); }, 5000);
+ return () => clearInterval(iv);
+ }
+ }, [tab, fetchJobs, fetchLbStatus]);
+
+ /* Trigger actions */
+ const doTriggerTriage = useCallback(async () => {
+ if (!dsId) return;
+ setTriggering(true);
+ try {
+ await analysis.triggerTriage(dsId);
+ enqueueSnackbar('Triage started', { variant: 'info' });
+ setTimeout(fetchTriage, 5000);
+ } catch (e: any) {
+ enqueueSnackbar(`Triage trigger failed: ${e.message}`, { variant: 'error' });
+ } finally { setTriggering(false); }
+ }, [dsId, enqueueSnackbar, fetchTriage]);
+
+ const doTriggerProfiles = useCallback(async () => {
+ if (!huntId) return;
+ setTriggering(true);
+ try {
+ await analysis.triggerAllProfiles(huntId);
+ enqueueSnackbar('Host profiling started', { variant: 'info' });
+ setTimeout(fetchProfiles, 10000);
+ } catch (e: any) {
+ enqueueSnackbar(`Profile trigger failed: ${e.message}`, { variant: 'error' });
+ } finally { setTriggering(false); }
+ }, [huntId, enqueueSnackbar, fetchProfiles]);
+
+ const doGenerateReport = useCallback(async () => {
+ if (!huntId) return;
+ setTriggering(true);
+ try {
+ await analysis.generateReport(huntId);
+ enqueueSnackbar('Report generation started', { variant: 'info' });
+ setTimeout(fetchReports, 15000);
+ } catch (e: any) {
+ enqueueSnackbar(`Report generation failed: ${e.message}`, { variant: 'error' });
+ } finally { setTriggering(false); }
+ }, [huntId, enqueueSnackbar, fetchReports]);
+
+ const doTriggerAnomalies = useCallback(async () => {
+ if (!dsId) return;
+ setTriggering(true);
+ try {
+ await analysis.triggerAnomalyDetection(dsId);
+ enqueueSnackbar('Anomaly detection started', { variant: 'info' });
+ setTimeout(fetchAnomalies, 20000);
+ } catch (e: any) {
+ enqueueSnackbar('Anomaly trigger failed: ' + e.message, { variant: 'error' });
+ } finally { setTriggering(false); }
+ }, [dsId, enqueueSnackbar, fetchAnomalies]);
+
+ /* Phase 9: Streaming data query */
+ const doQuery = useCallback(async () => {
+ if (!dsId || !queryText.trim()) return;
+ setQueryStreaming(true);
+ setQueryAnswer('');
+ setQueryMeta(null);
+ setQueryDone(null);
+
+ const controller = new AbortController();
+ abortRef.current = controller;
+
+ try {
+ const resp = await analysis.queryStream(dsId, queryText.trim(), queryMode);
+ if (!resp.body) throw new Error('No response body');
+
+ const reader = resp.body.getReader();
+ const decoder = new TextDecoder();
+ let buf = '';
+
+ while (true) {
+ const { done, value } = await reader.read();
+ if (done) break;
+ buf += decoder.decode(value, { stream: true });
+
+ const lines = buf.split('\n');
+ buf = lines.pop() || '';
+
+ for (const line of lines) {
+ if (!line.startsWith('data: ')) continue;
+ try {
+ const evt = JSON.parse(line.slice(6));
+ switch (evt.type) {
+ case 'token':
+ setQueryAnswer(prev => prev + evt.content);
+ if (answerRef.current) {
+ answerRef.current.scrollTop = answerRef.current.scrollHeight;
+ }
+ break;
+ case 'metadata':
+ setQueryMeta(evt.dataset);
+ break;
+ case 'done':
+ setQueryDone(evt);
+ break;
+ case 'error':
+ enqueueSnackbar(`Query error: ${evt.message}`, { variant: 'error' });
+ break;
+ }
+ } catch {}
+ }
+ }
+ } catch (e: any) {
+ if (e.name !== 'AbortError') {
+ enqueueSnackbar('Query failed: ' + e.message, { variant: 'error' });
+ }
+ } finally {
+ setQueryStreaming(false);
+ abortRef.current = null;
+ }
+ }, [dsId, queryText, queryMode, enqueueSnackbar]);
+
+ const stopQuery = useCallback(() => {
+ if (abortRef.current) {
+ abortRef.current.abort();
+ setQueryStreaming(false);
+ }
+ }, []);
+
+ /* Phase 10: Cancel job */
+ const doCancelJob = useCallback(async (jobId: string) => {
+ try {
+ await analysis.cancelJob(jobId);
+ enqueueSnackbar('Job cancelled', { variant: 'info' });
+ fetchJobs();
+ } catch (e: any) {
+ enqueueSnackbar('Cancel failed: ' + e.message, { variant: 'error' });
+ }
+ }, [enqueueSnackbar, fetchJobs]);
+
+ return (
+
+
+
+ AI Analysis
+ {triggering && }
+
+
+ {/* Selectors */}
+
+
+
+ Hunt
+
+
+
+ Dataset
+
+
+
+
+
+ {/* Tabs */}
+ setTab(v)} variant="scrollable" scrollButtons="auto" sx={{ mb: 1 }}>
+ } iconPosition="start" label={`Triage (${triageResults.length})`} />
+ } iconPosition="start" label={`Host Profiles (${profiles.length})`} />
+ } iconPosition="start" label={`Reports (${reports.length})`} />
+ } iconPosition="start" label={`Anomalies (${anomalies.filter(a => a.is_outlier).length})`} />
+ } iconPosition="start" label="Ask Data" />
+ } iconPosition="start" label={`Jobs${jobStats ? ` (${jobStats.active_workers})` : ''}`} />
+
+
+ {/* Tab 0: Triage */}
+
+
+ } onClick={doTriggerTriage}
+ disabled={!dsId || triggering} size="small">Run Triage
+ } onClick={fetchTriage}
+ disabled={!dsId || loadingTriage} size="small">Refresh
+
+ {loadingTriage && }
+ {triageResults.length === 0 && !loadingTriage ? (
+ No triage results yet. Select a dataset and click "Run Triage".
+ ) : (
+
+
+
+
+ RowsRiskVerdict
+ FindingsMITREModel
+
+
+
+ {triageResults.map(tr => (
+
+ {tr.row_start}-{tr.row_end}
+
+
+
+ {tr.findings?.join('; ') || ''}
+
+
+
+ {tr.mitre_techniques?.map((t: string, i: number) => (
+
+ ))}
+
+
+ {tr.model_used || ''}
+
+ ))}
+
+
+
+ )}
+
+
+ {/* Tab 1: Host Profiles */}
+
+
+ } onClick={doTriggerProfiles}
+ disabled={!huntId || triggering} size="small">Profile All Hosts
+ } onClick={fetchProfiles}
+ disabled={!huntId || loadingProfiles} size="small">Refresh
+
+ {loadingProfiles && }
+ {profiles.length === 0 && !loadingProfiles ? (
+ No host profiles yet. Select a hunt and click "Profile All Hosts".
+ ) : (
+
+ {profiles.map(hp => (
+
+
+
+
+ {hp.hostname}
+
+
+ {hp.fqdn && {hp.fqdn}}
+
+ {hp.timeline_summary && (
+
+ {hp.timeline_summary.slice(0, 300)}{hp.timeline_summary.length > 300 ? '...' : ''}
+
+ )}
+ {hp.suspicious_findings && hp.suspicious_findings.length > 0 && (
+
+
+
+ {hp.suspicious_findings.length} suspicious finding(s)
+
+
+ )}
+ {hp.mitre_techniques && hp.mitre_techniques.length > 0 && (
+
+ {hp.mitre_techniques.map((t: string, i: number) => (
+
+ ))}
+
+ )}
+
+
+ Model: {hp.model_used || 'N/A'}
+
+
+
+ ))}
+
+ )}
+
+
+ {/* Tab 2: Reports */}
+
+
+ } onClick={doGenerateReport}
+ disabled={!huntId || triggering} size="small">Generate Report
+ } onClick={fetchReports}
+ disabled={!huntId || loadingReports} size="small">Refresh
+
+ {loadingReports && }
+ {reports.length === 0 && !loadingReports ? (
+ No reports yet. Select a hunt and click "Generate Report".
+ ) : (
+ reports.map(rpt => (
+
+ }>
+
+
+ Report - {rpt.status}
+ {rpt.generation_time_ms && (
+
+ )}
+
+
+
+ {rpt.exec_summary && (
+
+ Executive Summary
+ {rpt.exec_summary}
+
+ )}
+ {rpt.findings && rpt.findings.length > 0 && (
+
+ Findings
+
+ {rpt.findings.map((f: any, i: number) => (
+ -
+ {typeof f === 'string' ? f : JSON.stringify(f)}
+
+ ))}
+
+
+ )}
+ {rpt.recommendations && rpt.recommendations.length > 0 && (
+
+ Recommendations
+
+ {rpt.recommendations.map((r: any, i: number) => (
+ -
+ {typeof r === 'string' ? r : JSON.stringify(r)}
+
+ ))}
+
+
+ )}
+ {rpt.ioc_table && rpt.ioc_table.length > 0 && (
+
+ IOC Table
+
+
+
+
+ {Object.keys(rpt.ioc_table[0]).map(k => (
+ {k}
+ ))}
+
+
+
+ {rpt.ioc_table.map((row: any, i: number) => (
+
+ {Object.values(row).map((v: any, j: number) => (
+ {String(v)}
+ ))}
+
+ ))}
+
+
+
+
+ )}
+ {rpt.full_report && (
+
+ }>
+ Full Report
+
+
+
+ {rpt.full_report}
+
+
+
+ )}
+
+ {rpt.models_used?.map((m: string, i: number) => (
+
+ ))}
+
+
+
+ ))
+ )}
+
+
+ {/* Tab 3: Anomalies */}
+
+
+ } onClick={doTriggerAnomalies}
+ disabled={!dsId || triggering} size="small">Detect Anomalies
+ } onClick={fetchAnomalies}
+ disabled={!dsId || loadingAnomalies} size="small">Refresh
+
+ {loadingAnomalies && }
+ {anomalies.length === 0 && !loadingAnomalies ? (
+ No anomaly results yet. Select a dataset and click "Detect Anomalies".
+ ) : (
+ <>
+
+ {anomalies.filter(a => a.is_outlier).length} outlier(s) detected out of {anomalies.length} rows
+
+
+
+
+
+ RowScore
+ DistanceClusterOutlier
+
+
+
+ {anomalies.filter(a => a.is_outlier).concat(anomalies.filter(a => !a.is_outlier).slice(0, 20)).map((a, i) => (
+
+ {a.row_id ?? ''}
+
+ 0.5 ? 'error' : a.anomaly_score > 0.35 ? 'warning' : 'success'} />
+
+ {a.distance_from_centroid?.toFixed(4) ?? ''}
+
+
+ {a.is_outlier
+ ?
+ : }
+
+
+ ))}
+
+
+
+ >
+ )}
+
+
+ {/* Tab 4: Ask Data (Phase 9) */}
+
+
+
+ Ask a question about the selected dataset in plain English
+
+
+ setQueryText(e.target.value)}
+ onKeyDown={e => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); doQuery(); } }}
+ disabled={queryStreaming}
+ />
+ { if (v) setQueryMode(v); }}
+ >
+
+ Quick
+
+
+ Deep
+
+
+ {queryStreaming ? (
+
+ ) : (
+
+
+
+ )}
+
+
+
+ {queryMeta && (
+
+ Querying {queryMeta.name} ({queryMeta.row_count} rows,{' '}
+ {queryMeta.sample_rows_shown} sampled) | Mode: {queryMode}
+
+ )}
+
+ {queryStreaming && }
+
+ {queryAnswer && (
+
+ {queryAnswer}
+
+ )}
+
+ {queryDone && (
+
+
+
+
+
+
+ )}
+
+
+ {/* Tab 5: Jobs & Load Balancer (Phase 10) */}
+
+ {/* LB Status Cards */}
+ {lbStatus && (
+
+ {Object.entries(lbStatus).map(([name, st]) => (
+
+
+
+
+ {name}
+
+
+
+ Active: {st.active_jobs}
+ Done: {st.total_completed}
+ Errors: {st.total_errors}
+ Avg: {st.avg_latency_ms.toFixed(0)}ms
+
+
+
+
+ ))}
+
+ )}
+
+ {/* Job queue stats */}
+ {jobStats && (
+
+
+
+ {Object.entries(jobStats.by_status).map(([s, c]) => (
+
+ ))}
+
+ )}
+
+
+ } onClick={fetchJobs}
+ disabled={loadingJobs} size="small">Refresh
+
+
+ {loadingJobs && }
+
+ {jobs.length === 0 && !loadingJobs ? (
+ No jobs yet. Jobs appear here when you trigger triage, profiling, reports, anomaly detection, or data queries.
+ ) : (
+
+
+
+
+ Status
+ Type
+ Progress
+ Message
+ Time
+ Created
+ Actions
+
+
+
+ {jobs.map(j => (
+
+
+
+ {statusIcon(j.status)}
+ {j.status}
+
+
+
+
+ {j.status === 'running' ? (
+
+ ) : j.status === 'completed' ? (
+ 100%
+ ) : null}
+
+
+ {j.error || j.message}
+
+ {fmtMs(j.elapsed_ms)}
+ {fmtTime(j.created_at)}
+
+ {(j.status === 'queued' || j.status === 'running') && (
+ doCancelJob(j.id)}>
+
+
+ )}
+
+
+ ))}
+
+
+
+ )}
+
+
+ );
+}
\ No newline at end of file
diff --git a/frontend/src/components/DatasetViewer.tsx b/frontend/src/components/DatasetViewer.tsx
index 6ee666b..2e06ea2 100644
--- a/frontend/src/components/DatasetViewer.tsx
+++ b/frontend/src/components/DatasetViewer.tsx
@@ -144,6 +144,12 @@ export default function DatasetViewer() {
{selected.source_tool && }
+ {selected.artifact_type && }
+ {selected.processing_status && selected.processing_status !== 'ready' && (
+
+ )}
{selected.ioc_columns && Object.keys(selected.ioc_columns).length > 0 && (
)}
diff --git a/frontend/src/components/NetworkMap.tsx b/frontend/src/components/NetworkMap.tsx
index 478bbc9..7ede886 100644
--- a/frontend/src/components/NetworkMap.tsx
+++ b/frontend/src/components/NetworkMap.tsx
@@ -1,477 +1,578 @@
/**
- * NetworkMap — interactive hunt-scoped force-directed network graph.
+ * NetworkMap - Host-centric force-directed network graph.
*
- * • Select a hunt → loads only that hunt's datasets
- * • Nodes = unique IPs / hostnames / domains pulled from IOC columns
- * • Edges = "seen together in the same row" co-occurrence
- * • Click a node → popover showing hostname, IP, OS, dataset sources, connections
- * • Responsive canvas with ResizeObserver
- * • Zero extra npm dependencies
+ * Loads a deduplicated host inventory from the backend, showing one node
+ * per unique host with hostname, IPs, OS, logged-in users, and network
+ * connections. No more duplicated IOC nodes.
+ *
+ * Features:
+ * - Calls /api/network/host-inventory for clean, deduped host data
+ * - HiDPI / Retina canvas rendering
+ * - Radial-gradient nodes with neon glow effects
+ * - Curved edges with animated flow on active connections
+ * - Animated force-directed layout
+ * - Node drag with springy neighbor physics
+ * - Glassmorphism toolbar + floating legend overlay
+ * - Rich popover: hostname, IP, OS, users, datasets
+ * - Zero extra npm dependencies
*/
import React, { useEffect, useState, useRef, useCallback, useMemo } from 'react';
import {
- Box, Typography, Paper, Stack, Alert, Chip, Button, TextField,
+ Box, Typography, Paper, Stack, Alert, Chip, TextField,
LinearProgress, FormControl, InputLabel, Select, MenuItem,
- Popover, Divider, IconButton,
+ Popover, Divider, IconButton, Tooltip, Fade, useTheme,
} from '@mui/material';
import RefreshIcon from '@mui/icons-material/Refresh';
import CloseIcon from '@mui/icons-material/Close';
import ZoomInIcon from '@mui/icons-material/ZoomIn';
import ZoomOutIcon from '@mui/icons-material/ZoomOut';
import CenterFocusStrongIcon from '@mui/icons-material/CenterFocusStrong';
-import { datasets, hunts, type Hunt, type DatasetSummary } from '../api/client';
+import HubIcon from '@mui/icons-material/Hub';
+import SearchIcon from '@mui/icons-material/Search';
+import ComputerIcon from '@mui/icons-material/Computer';
+import {
+ hunts, network,
+ type Hunt, type InventoryHost, type InventoryConnection, type InventoryStats,
+} from '../api/client';
-// ── Graph primitives ─────────────────────────────────────────────────
+// == Graph primitives =====================================================
-type NodeType = 'ip' | 'hostname' | 'domain' | 'url';
-
-interface NodeMeta {
- hostnames: Set;
- ips: Set;
- os: Set;
- datasets: Set;
- type: NodeType;
-}
+type NodeType = 'host' | 'external_ip';
interface GNode {
id: string; label: string; x: number; y: number;
vx: number; vy: number; radius: number; color: string; count: number;
- meta: { hostnames: string[]; ips: string[]; os: string[]; datasets: string[]; type: NodeType };
+ pinned?: boolean;
+ meta: {
+ type: NodeType;
+ hostname: string;
+ fqdn: string;
+ client_id: string;
+ ips: string[];
+ os: string;
+ users: string[];
+ datasets: string[];
+ row_count: number;
+ };
}
interface GEdge { source: string; target: string; weight: number }
interface Graph { nodes: GNode[]; edges: GEdge[] }
const TYPE_COLORS: Record = {
- ip: '#3b82f6', hostname: '#22c55e', domain: '#eab308', url: '#8b5cf6',
+ host: '#60a5fa',
+ external_ip: '#fbbf24',
+};
+const GLOW_COLORS: Record = {
+ host: 'rgba(96,165,250,0.45)',
+ external_ip: 'rgba(251,191,36,0.35)',
};
-// ── Helpers: find context columns from dataset schema ────────────────
+// == Build graph from inventory ==========================================
-/** Best-effort detection of hostname, IP, and OS columns from raw column names + normalized mapping. */
-function findContextColumns(ds: DatasetSummary) {
- const norm = ds.normalized_columns || {};
- const schema = ds.column_schema || {};
- const rawCols = Object.keys(schema).length > 0 ? Object.keys(schema) : Object.keys(norm);
+function buildGraphFromInventory(
+ hosts: InventoryHost[], connections: InventoryConnection[],
+ canvasW: number, canvasH: number,
+): Graph {
+ const nodeMap = new Map();
- const hostCols: string[] = [];
- const ipCols: string[] = [];
- const osCols: string[] = [];
-
- for (const raw of rawCols) {
- const canonical = norm[raw] || '';
- const lower = raw.toLowerCase();
- // Hostname columns
- if (canonical === 'hostname' || /^(hostname|host|fqdn|computer_?name|system_?name|machinename)$/i.test(lower)) {
- hostCols.push(raw);
- }
- // IP columns
- if (['src_ip', 'dst_ip', 'ip_address'].includes(canonical) || /^(ip|ip_?address|src_?ip|dst_?ip|source_?ip|dest_?ip)$/i.test(lower)) {
- ipCols.push(raw);
- }
- // OS columns (best-effort — raw name scan + normalized canonical)
- if (canonical === 'os' || /^(os|operating_?system|os_?version|os_?name|platform|os_?type)$/i.test(lower)) {
- osCols.push(raw);
- }
- }
- return { hostCols, ipCols, osCols };
-}
-
-function cleanVal(v: any): string {
- const s = (v ?? '').toString().trim();
- return (s && s !== '-' && s !== '0.0.0.0' && s !== '::') ? s : '';
-}
-
-// ── Build graph with per-node metadata ───────────────────────────────
-
-interface RowBatch {
- rows: Record[];
- iocColumns: Record;
- dsName: string;
- ds: DatasetSummary;
-}
-
-function buildGraph(allBatches: RowBatch[], canvasW: number, canvasH: number): Graph {
- const countMap = new Map();
- const edgeMap = new Map();
- const metaMap = new Map();
-
- const getOrCreateMeta = (id: string, type: NodeType): NodeMeta => {
- let m = metaMap.get(id);
- if (!m) { m = { hostnames: new Set(), ips: new Set(), os: new Set(), datasets: new Set(), type }; metaMap.set(id, m); }
- return m;
- };
-
- for (const { rows, iocColumns, dsName, ds } of allBatches) {
- // IOC columns that produce graph nodes
- const iocEntries = Object.entries(iocColumns).filter(([, t]) => {
- const typ = Array.isArray(t) ? t[0] : t;
- return typ === 'ip' || typ === 'hostname' || typ === 'domain' || typ === 'url';
- }).map(([col, t]) => {
- const typ = (Array.isArray(t) ? t[0] : t) as NodeType;
- return { col, typ };
- });
-
- if (iocEntries.length === 0) continue;
-
- // Context columns for enrichment
- const ctx = findContextColumns(ds);
-
- for (const row of rows) {
- // Collect IOC values for this row (nodes + edges)
- const vals: { v: string; typ: NodeType }[] = [];
- for (const { col, typ } of iocEntries) {
- const v = cleanVal(row[col]);
- if (v) vals.push({ v, typ });
- }
- const unique = [...new Map(vals.map(x => [x.v, x])).values()];
-
- // Count occurrences
- for (const { v } of unique) countMap.set(v, (countMap.get(v) ?? 0) + 1);
-
- // Create edges (co-occurrence)
- for (let i = 0; i < unique.length; i++) {
- for (let j = i + 1; j < unique.length; j++) {
- const key = [unique[i].v, unique[j].v].sort().join('||');
- edgeMap.set(key, (edgeMap.get(key) ?? 0) + 1);
- }
- }
-
- // Extract context values from this row
- const rowHosts = ctx.hostCols.map(c => cleanVal(row[c])).filter(Boolean);
- const rowIps = ctx.ipCols.map(c => cleanVal(row[c])).filter(Boolean);
- const rowOs = ctx.osCols.map(c => cleanVal(row[c])).filter(Boolean);
-
- // Attach context to each node in this row
- for (const { v, typ } of unique) {
- const meta = getOrCreateMeta(v, typ);
- meta.datasets.add(dsName);
- for (const h of rowHosts) meta.hostnames.add(h);
- for (const ip of rowIps) meta.ips.add(ip);
- for (const o of rowOs) meta.os.add(o);
- }
- }
- }
-
- const nodes: GNode[] = [...countMap.entries()].map(([id, count]) => {
- const raw = metaMap.get(id);
- const type: NodeType = raw?.type || 'ip';
- return {
- id, label: id, count,
+ // Create host nodes
+ for (const h of hosts) {
+ const r = Math.max(8, Math.min(26, 6 + Math.sqrt(h.row_count / 100) * 3));
+ nodeMap.set(h.id, {
+ id: h.id,
+ label: h.hostname || h.fqdn || h.client_id,
x: canvasW / 2 + (Math.random() - 0.5) * canvasW * 0.75,
y: canvasH / 2 + (Math.random() - 0.5) * canvasH * 0.65,
- vx: 0, vy: 0,
- radius: Math.max(5, Math.min(18, 4 + Math.sqrt(count) * 1.6)),
- color: TYPE_COLORS[type],
+ vx: 0, vy: 0, radius: r,
+ color: TYPE_COLORS.host,
+ count: h.row_count,
meta: {
- hostnames: [...(raw?.hostnames ?? [])],
- ips: [...(raw?.ips ?? [])],
- os: [...(raw?.os ?? [])],
- datasets: [...(raw?.datasets ?? [])],
- type,
+ type: 'host' as NodeType,
+ hostname: h.hostname,
+ fqdn: h.fqdn,
+ client_id: h.client_id,
+ ips: h.ips,
+ os: h.os,
+ users: h.users,
+ datasets: h.datasets,
+ row_count: h.row_count,
},
- };
- });
+ });
+ }
- const edges: GEdge[] = [...edgeMap.entries()].map(([key, weight]) => {
- const [source, target] = key.split('||');
- return { source, target, weight };
- });
+ // Create edges + external IP nodes (for unresolved remote IPs)
+ const edges: GEdge[] = [];
+ for (const c of connections) {
+ if (!nodeMap.has(c.target)) {
+ nodeMap.set(c.target, {
+ id: c.target,
+ label: c.target_ip || c.target,
+ x: canvasW / 2 + (Math.random() - 0.5) * canvasW * 0.75,
+ y: canvasH / 2 + (Math.random() - 0.5) * canvasH * 0.65,
+ vx: 0, vy: 0, radius: 6,
+ color: TYPE_COLORS.external_ip,
+ count: c.count,
+ meta: {
+ type: 'external_ip' as NodeType,
+ hostname: '', fqdn: '', client_id: '',
+ ips: [c.target_ip || c.target],
+ os: '', users: [], datasets: [], row_count: 0,
+ },
+ });
+ }
+ edges.push({ source: c.source, target: c.target, weight: c.count });
+ }
- return { nodes, edges };
+ return { nodes: [...nodeMap.values()], edges };
}
-// ── Force simulation ─────────────────────────────────────────────────
+// == Simulation ===========================================================
-function simulate(graph: Graph, cx: number, cy: number, steps = 120) {
+function simulationStep(graph: Graph, cx: number, cy: number, alpha: number) {
const { nodes, edges } = graph;
const nodeMap = new Map(nodes.map(n => [n.id, n]));
- const k = 80;
- const repulsion = 6000;
- const damping = 0.85;
+ const k = 120;
+ const repulsion = 12000;
+ const damping = 0.82;
- for (let step = 0; step < steps; step++) {
- for (let i = 0; i < nodes.length; i++) {
- for (let j = i + 1; j < nodes.length; j++) {
- const a = nodes[i], b = nodes[j];
- const dx = b.x - a.x, dy = b.y - a.y;
- const dist = Math.max(1, Math.sqrt(dx * dx + dy * dy));
- const force = repulsion / (dist * dist);
- const fx = (dx / dist) * force, fy = (dy / dist) * force;
- a.vx -= fx; a.vy -= fy;
- b.vx += fx; b.vy += fy;
- }
- }
- for (const e of edges) {
- const a = nodeMap.get(e.source), b = nodeMap.get(e.target);
- if (!a || !b) continue;
+ for (let i = 0; i < nodes.length; i++) {
+ for (let j = i + 1; j < nodes.length; j++) {
+ const a = nodes[i], b = nodes[j];
+ if (a.pinned && b.pinned) continue;
const dx = b.x - a.x, dy = b.y - a.y;
const dist = Math.max(1, Math.sqrt(dx * dx + dy * dy));
- const force = (dist - k) * 0.05;
+ const force = (repulsion * alpha) / (dist * dist);
const fx = (dx / dist) * force, fy = (dy / dist) * force;
- a.vx += fx; a.vy += fy;
- b.vx -= fx; b.vy -= fy;
- }
- for (const n of nodes) {
- n.vx += (cx - n.x) * 0.001;
- n.vy += (cy - n.y) * 0.001;
- n.vx *= damping; n.vy *= damping;
- n.x += n.vx; n.y += n.vy;
+ if (!a.pinned) { a.vx -= fx; a.vy -= fy; }
+ if (!b.pinned) { b.vx += fx; b.vy += fy; }
}
}
+ for (const e of edges) {
+ const a = nodeMap.get(e.source), b = nodeMap.get(e.target);
+ if (!a || !b) continue;
+ const dx = b.x - a.x, dy = b.y - a.y;
+ const dist = Math.max(1, Math.sqrt(dx * dx + dy * dy));
+ const force = (dist - k) * 0.06 * alpha;
+ const fx = (dx / dist) * force, fy = (dy / dist) * force;
+ if (!a.pinned) { a.vx += fx; a.vy += fy; }
+ if (!b.pinned) { b.vx -= fx; b.vy -= fy; }
+ }
+ for (const n of nodes) {
+ if (n.pinned) continue;
+ n.vx += (cx - n.x) * 0.0012 * alpha;
+ n.vy += (cy - n.y) * 0.0012 * alpha;
+ n.vx *= damping; n.vy *= damping;
+ n.x += n.vx; n.y += n.vy;
+ }
}
-// ── Viewport (zoom / pan) ────────────────────────────────────────────
+function simulate(graph: Graph, cx: number, cy: number, steps = 150) {
+ for (let i = 0; i < steps; i++) {
+ const alpha = 1 - i / steps;
+ simulationStep(graph, cx, cy, Math.max(0.05, alpha));
+ }
+}
+
+// == Viewport =============================================================
interface Viewport { x: number; y: number; scale: number }
+const MIN_ZOOM = 0.08;
+const MAX_ZOOM = 10;
-const MIN_ZOOM = 0.1;
-const MAX_ZOOM = 8;
+// == Canvas renderer =====================================================
-// ── Canvas renderer ──────────────────────────────────────────────────
+const BG_COLOR = '#0a101e';
+const GRID_DOT_COLOR = 'rgba(148,163,184,0.04)';
+const GRID_SPACING = 32;
-function drawGraph(
- ctx: CanvasRenderingContext2D, graph: Graph,
- hovered: string | null, selected: string | null, search: string,
- vp: Viewport,
+function drawBackground(
+ ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,
) {
- const { nodes, edges } = graph;
- const nodeMap = new Map(nodes.map(n => [n.id, n]));
- const matchSet = new Set();
- if (search) {
- const lc = search.toLowerCase();
- for (const n of nodes) if (n.label.toLowerCase().includes(lc)) matchSet.add(n.id);
- }
-
- ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);
+ ctx.fillStyle = BG_COLOR;
+ ctx.fillRect(0, 0, w, h);
ctx.save();
- ctx.translate(vp.x, vp.y);
- ctx.scale(vp.scale, vp.scale);
+ ctx.translate(vp.x * dpr, vp.y * dpr);
+ ctx.scale(vp.scale * dpr, vp.scale * dpr);
+ const startX = -vp.x / vp.scale - GRID_SPACING;
+ const startY = -vp.y / vp.scale - GRID_SPACING;
+ const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
+ const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
+ ctx.fillStyle = GRID_DOT_COLOR;
+ for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
+ for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
+ ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
+ }
+ }
+ ctx.restore();
+ const vignette = ctx.createRadialGradient(w / 2, h / 2, w * 0.2, w / 2, h / 2, w * 0.7);
+ vignette.addColorStop(0, 'rgba(10,16,30,0)');
+ vignette.addColorStop(1, 'rgba(10,16,30,0.55)');
+ ctx.fillStyle = vignette;
+ ctx.fillRect(0, 0, w, h);
+}
- // Edges
- for (const e of edges) {
+function drawEdges(
+ ctx: CanvasRenderingContext2D, graph: Graph,
+ hovered: string | null, selected: string | null,
+ nodeMap: Map, animTime: number,
+) {
+ for (const e of graph.edges) {
const a = nodeMap.get(e.source), b = nodeMap.get(e.target);
if (!a || !b) continue;
const isActive = (hovered && (e.source === hovered || e.target === hovered))
|| (selected && (e.source === selected || e.target === selected));
- ctx.beginPath();
- ctx.strokeStyle = isActive ? 'rgba(96,165,250,0.7)' : 'rgba(100,116,139,0.25)';
- ctx.lineWidth = Math.min(4, 0.5 + e.weight * 0.3) / vp.scale;
- ctx.moveTo(a.x, a.y); ctx.lineTo(b.x, b.y); ctx.stroke();
- }
+ const mx = (a.x + b.x) / 2, my = (a.y + b.y) / 2;
+ const dx = b.x - a.x, dy = b.y - a.y;
+ const len = Math.sqrt(dx * dx + dy * dy);
+ const perpScale = Math.min(20, len * 0.08);
+ const cpx = mx + (-dy / (len || 1)) * perpScale;
+ const cpy = my + (dx / (len || 1)) * perpScale;
- // Nodes
- for (const n of nodes) {
- const highlighted = hovered === n.id || selected === n.id || (search && matchSet.has(n.id));
- ctx.beginPath();
- ctx.arc(n.x, n.y, n.radius, 0, Math.PI * 2);
- ctx.fillStyle = highlighted ? '#fff' : n.color;
- ctx.globalAlpha = (search && !matchSet.has(n.id)) ? 0.15 : 1;
- ctx.fill();
- ctx.globalAlpha = 1;
- if (highlighted) { ctx.strokeStyle = n.color; ctx.lineWidth = 2.5 / vp.scale; ctx.stroke(); }
- }
+ ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);
- // Labels — show more labels when zoomed in
- const labelThreshold = Math.max(1, Math.round(3 / vp.scale));
- const fontSize = Math.max(8, Math.round(11 / vp.scale));
- ctx.font = `${fontSize}px Inter, sans-serif`;
+ if (isActive) {
+ ctx.strokeStyle = 'rgba(96,165,250,0.8)';
+ ctx.lineWidth = Math.min(3.5, 1 + e.weight * 0.15);
+ ctx.setLineDash([6, 4]); ctx.lineDashOffset = -animTime * 0.03;
+ ctx.stroke(); ctx.setLineDash([]);
+ ctx.save();
+ ctx.shadowColor = 'rgba(96,165,250,0.5)'; ctx.shadowBlur = 8;
+ ctx.strokeStyle = 'rgba(96,165,250,0.3)';
+ ctx.lineWidth = Math.min(5, 2 + e.weight * 0.2);
+ ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);
+ ctx.stroke(); ctx.restore();
+ } else {
+ const alpha = Math.min(0.35, 0.08 + e.weight * 0.01);
+ ctx.strokeStyle = `rgba(100,116,139,${alpha})`;
+ ctx.lineWidth = Math.min(2.5, 0.4 + e.weight * 0.08);
+ ctx.stroke();
+ }
+ }
+}
+
+function drawNodes(
+ ctx: CanvasRenderingContext2D, graph: Graph,
+ hovered: string | null, selected: string | null,
+ search: string, matchSet: Set,
+) {
+ const dimmed = search.length > 0;
+ for (const n of graph.nodes) {
+ const isHighlight = hovered === n.id || selected === n.id || (search && matchSet.has(n.id));
+ const isDim = dimmed && !matchSet.has(n.id);
+ ctx.save();
+ ctx.globalAlpha = isDim ? 0.12 : 1;
+
+ if (isHighlight && !isDim) {
+ ctx.save();
+ ctx.shadowColor = GLOW_COLORS[n.meta.type] || 'rgba(96,165,250,0.4)';
+ ctx.shadowBlur = 18;
+ ctx.beginPath(); ctx.arc(n.x, n.y, n.radius + 4, 0, Math.PI * 2);
+ ctx.fillStyle = 'rgba(0,0,0,0)'; ctx.fill();
+ ctx.restore();
+ }
+
+ const grad = ctx.createRadialGradient(
+ n.x - n.radius * 0.3, n.y - n.radius * 0.3, n.radius * 0.1, n.x, n.y, n.radius,
+ );
+ if (isHighlight && !isDim) {
+ grad.addColorStop(0, '#ffffff');
+ grad.addColorStop(0.4, n.color);
+ grad.addColorStop(1, n.color);
+ } else {
+ grad.addColorStop(0, n.color + 'cc');
+ grad.addColorStop(0.5, n.color);
+ grad.addColorStop(1, n.color + '88');
+ }
+ ctx.beginPath(); ctx.arc(n.x, n.y, n.radius, 0, Math.PI * 2);
+ ctx.fillStyle = grad; ctx.fill();
+ ctx.strokeStyle = isHighlight ? '#ffffff' : (n.color + '55');
+ ctx.lineWidth = isHighlight ? 2 : 1;
+ ctx.stroke();
+
+ if (n.pinned) {
+ ctx.beginPath(); ctx.arc(n.x, n.y, 2.5, 0, Math.PI * 2);
+ ctx.fillStyle = '#ffffff'; ctx.fill();
+ }
+ ctx.restore();
+ }
+}
+
+function drawLabels(
+ ctx: CanvasRenderingContext2D, graph: Graph,
+ hovered: string | null, selected: string | null,
+ search: string, matchSet: Set, vp: Viewport,
+) {
+ const dimmed = search.length > 0;
+ const fontSize = Math.max(9, Math.round(12 / vp.scale));
+ ctx.font = `500 ${fontSize}px Inter, system-ui, sans-serif`;
ctx.textAlign = 'center';
- for (const n of nodes) {
- const show = hovered === n.id || selected === n.id
- || (search && matchSet.has(n.id)) || n.count >= labelThreshold;
- if (!show) continue;
- ctx.fillStyle = (search && !matchSet.has(n.id)) ? 'rgba(241,245,249,0.15)' : '#f1f5f9';
- ctx.fillText(n.label, n.x, n.y - n.radius - 5);
- }
+ ctx.textBaseline = 'bottom';
+ const sorted = [...graph.nodes].sort((a, b) => {
+ const aH = hovered === a.id || selected === a.id || matchSet.has(a.id) ? 1 : 0;
+ const bH = hovered === b.id || selected === b.id || matchSet.has(b.id) ? 1 : 0;
+ if (aH !== bH) return aH - bH;
+ return b.count - a.count;
+ });
+
+ for (const n of sorted) {
+ const isHighlight = hovered === n.id || selected === n.id || matchSet.has(n.id);
+ // Always show labels for hosts (since they're deduped and fewer)
+ const show = isHighlight || n.meta.type === 'host' || n.count >= 2;
+ if (!show) continue;
+ const isDim = dimmed && !matchSet.has(n.id);
+ if (isDim) continue;
+
+ // Two-line label: hostname + IP (if available)
+ const line1 = n.label;
+ const line2 = n.meta.ips.length > 0 ? n.meta.ips[0] : '';
+ const tw = Math.max(ctx.measureText(line1).width, line2 ? ctx.measureText(line2).width : 0);
+ const px = 5, py = 2;
+ const totalH = line2 ? fontSize * 2 + py * 2 : fontSize + py * 2;
+ const lx = n.x, ly = n.y - n.radius - 6;
+
+ const rx = lx - tw / 2 - px;
+ const ry = ly - totalH;
+ const rw = tw + px * 2;
+ const rh = totalH;
+ const cr = 4;
+
+ ctx.save();
+ ctx.globalAlpha = isHighlight ? 0.92 : 0.75;
+ ctx.fillStyle = 'rgba(10,16,30,0.80)';
+ ctx.beginPath();
+ ctx.moveTo(rx + cr, ry); ctx.lineTo(rx + rw - cr, ry);
+ ctx.arcTo(rx + rw, ry, rx + rw, ry + cr, cr);
+ ctx.lineTo(rx + rw, ry + rh - cr);
+ ctx.arcTo(rx + rw, ry + rh, rx + rw - cr, ry + rh, cr);
+ ctx.lineTo(rx + cr, ry + rh);
+ ctx.arcTo(rx, ry + rh, rx, ry + rh - cr, cr);
+ ctx.lineTo(rx, ry + cr);
+ ctx.arcTo(rx, ry, rx + cr, ry, cr);
+ ctx.closePath(); ctx.fill();
+ ctx.strokeStyle = isHighlight ? n.color + 'aa' : 'rgba(148,163,184,0.15)';
+ ctx.lineWidth = 0.8; ctx.stroke();
+ ctx.restore();
+
+ // Hostname line
+ ctx.fillStyle = isHighlight ? '#ffffff' : n.color;
+ ctx.globalAlpha = isHighlight ? 1 : 0.85;
+ ctx.fillText(line1, lx, ly - (line2 ? fontSize * 0.5 : 0));
+ // IP line (smaller, dimmer)
+ if (line2) {
+ ctx.fillStyle = 'rgba(148,163,184,0.6)';
+ ctx.fillText(line2, lx, ly + fontSize * 0.5);
+ }
+ ctx.globalAlpha = 1;
+ }
+}
+
+function drawGraph(
+ ctx: CanvasRenderingContext2D, graph: Graph,
+ hovered: string | null, selected: string | null, search: string,
+ vp: Viewport, animTime: number, dpr: number,
+) {
+ const w = ctx.canvas.width, h = ctx.canvas.height;
+ const nodeMap = new Map(graph.nodes.map(n => [n.id, n]));
+ const matchSet = new Set();
+ if (search) {
+ const lc = search.toLowerCase();
+ for (const n of graph.nodes) {
+ if (n.label.toLowerCase().includes(lc)
+ || n.meta.ips.some(ip => ip.includes(lc))
+ || n.meta.users.some(u => u.toLowerCase().includes(lc))
+ || n.meta.os.toLowerCase().includes(lc)
+ ) matchSet.add(n.id);
+ }
+ }
+ drawBackground(ctx, w, h, vp, dpr);
+ ctx.save();
+ ctx.translate(vp.x * dpr, vp.y * dpr);
+ ctx.scale(vp.scale * dpr, vp.scale * dpr);
+ drawEdges(ctx, graph, hovered, selected, nodeMap, animTime);
+ drawNodes(ctx, graph, hovered, selected, search, matchSet);
+ drawLabels(ctx, graph, hovered, selected, search, matchSet, vp);
ctx.restore();
}
-// ── Hit-test helper (viewport-aware) ─────────────────────────────────
+// == Hit-test =============================================================
function screenToWorld(
canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport,
): { wx: number; wy: number } {
const rect = canvas.getBoundingClientRect();
- const cssToCanvas_x = canvas.width / rect.width;
- const cssToCanvas_y = canvas.height / rect.height;
- const cx = (clientX - rect.left) * cssToCanvas_x;
- const cy = (clientY - rect.top) * cssToCanvas_y;
- return { wx: (cx - vp.x) / vp.scale, wy: (cy - vp.y) / vp.scale };
+ return { wx: (clientX - rect.left - vp.x) / vp.scale, wy: (clientY - rect.top - vp.y) / vp.scale };
}
function hitTest(
- graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number,
- vp: Viewport,
+ graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport,
): GNode | null {
const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp);
for (const n of graph.nodes) {
const dx = n.x - wx, dy = n.y - wy;
- if (dx * dx + dy * dy < (n.radius + 4) ** 2) return n;
+ if (dx * dx + dy * dy < (n.radius + 5) ** 2) return n;
}
return null;
}
-
-// ── Component ────────────────────────────────────────────────────────
+// == Component =============================================================
export default function NetworkMap() {
- // Hunt selector
+ const theme = useTheme();
+
const [huntList, setHuntList] = useState([]);
const [selectedHuntId, setSelectedHuntId] = useState('');
- // Graph state
const [loading, setLoading] = useState(false);
const [progress, setProgress] = useState('');
const [error, setError] = useState('');
const [graph, setGraph] = useState(null);
+ const [stats, setStats] = useState(null);
const [hovered, setHovered] = useState(null);
const [selectedNode, setSelectedNode] = useState(null);
const [search, setSearch] = useState('');
- const [dsCount, setDsCount] = useState(0);
- const [totalRows, setTotalRows] = useState(0);
- // Node type filters
- const [visibleTypes, setVisibleTypes] = useState>(
- new Set(['ip', 'hostname', 'domain', 'url']),
- );
-
- // Canvas sizing
const canvasRef = useRef(null);
const wrapperRef = useRef(null);
const [canvasSize, setCanvasSize] = useState({ w: 900, h: 600 });
- // Viewport (zoom / pan)
const vpRef = useRef({ x: 0, y: 0, scale: 1 });
- const [vpScale, setVpScale] = useState(1); // for UI display only
+ const [vpScale, setVpScale] = useState(1);
const isPanning = useRef(false);
const panStart = useRef({ x: 0, y: 0 });
+ const dragNode = useRef(null);
+
+ const animFrameRef = useRef(0);
+ const animTimeRef = useRef(0);
+ const simAlphaRef = useRef(0);
+ const isAnimatingRef = useRef(false);
+ const hoveredRef = useRef(null);
+ const selectedNodeRef = useRef(null);
+ const searchRef = useRef('');
+ const graphRef = useRef(null);
- // Popover anchor
const [popoverAnchor, setPopoverAnchor] = useState<{ top: number; left: number } | null>(null);
- // ── Load hunts on mount ────────────────────────────────────────────
+ useEffect(() => { hoveredRef.current = hovered; }, [hovered]);
+ useEffect(() => { selectedNodeRef.current = selectedNode; }, [selectedNode]);
+ useEffect(() => { searchRef.current = search; }, [search]);
+
+ // Load hunts on mount
useEffect(() => {
hunts.list(0, 200).then(r => setHuntList(r.hunts)).catch(() => {});
}, []);
- // ── Resize observer ────────────────────────────────────────────────
+ // Resize observer
useEffect(() => {
const el = wrapperRef.current;
if (!el) return;
const ro = new ResizeObserver(entries => {
for (const entry of entries) {
const w = Math.round(entry.contentRect.width);
- if (w > 100) setCanvasSize({ w, h: Math.max(450, Math.round(w * 0.55)) });
+ if (w > 100) setCanvasSize({ w, h: Math.max(500, Math.round(w * 0.56)) });
}
});
ro.observe(el);
return () => ro.disconnect();
}, []);
- // ── Load graph for selected hunt ──────────────────────────────────
+ // HiDPI canvas sizing
+ useEffect(() => {
+ const canvas = canvasRef.current;
+ if (!canvas) return;
+ const dpr = window.devicePixelRatio || 1;
+ canvas.width = canvasSize.w * dpr;
+ canvas.height = canvasSize.h * dpr;
+ canvas.style.width = canvasSize.w + 'px';
+ canvas.style.height = canvasSize.h + 'px';
+ }, [canvasSize]);
+
+ // Load host inventory for selected hunt
const loadGraph = useCallback(async (huntId: string) => {
if (!huntId) return;
- setLoading(true); setError(''); setGraph(null);
+ setLoading(true); setError(''); setGraph(null); setStats(null);
setSelectedNode(null); setPopoverAnchor(null);
try {
- setProgress('Fetching datasets for hunt…');
- const dsRes = await datasets.list(0, 500, huntId);
- const dsList = dsRes.datasets;
- setDsCount(dsList.length);
+ setProgress('Building host inventory (scanning all datasets)\u2026');
+ const inv = await network.hostInventory(huntId);
+ setStats(inv.stats);
- if (dsList.length === 0) {
- setError('This hunt has no datasets. Upload CSV files to this hunt first.');
+ if (inv.hosts.length === 0) {
+ setError('No hosts found. Upload CSV files with host-identifying columns (ClientId, Fqdn, Hostname) to this hunt.');
setLoading(false); setProgress('');
return;
}
- const allBatches: RowBatch[] = [];
- let rowTotal = 0;
-
- for (let i = 0; i < dsList.length; i++) {
- const ds = dsList[i];
- setProgress(`Loading ${ds.name} (${i + 1}/${dsList.length})…`);
- try {
- const detail = await datasets.get(ds.id);
- const ioc = detail.ioc_columns || {};
- const hasIoc = Object.values(ioc).some(t => {
- const typ = Array.isArray(t) ? t[0] : t;
- return typ === 'ip' || typ === 'hostname' || typ === 'domain' || typ === 'url';
- });
- if (hasIoc) {
- const r = await datasets.rows(ds.id, 0, 5000);
- allBatches.push({ rows: r.rows, iocColumns: ioc, dsName: ds.name, ds: detail });
- rowTotal += r.rows.length;
- }
- } catch { /* skip failed datasets */ }
- }
-
- setTotalRows(rowTotal);
-
- if (allBatches.length === 0) {
- setError('No datasets in this hunt contain IP/hostname/domain IOC columns.');
- setLoading(false); setProgress('');
- return;
- }
-
- setProgress('Building graph…');
- const g = buildGraph(allBatches, canvasSize.w, canvasSize.h);
- if (g.nodes.length === 0) {
- setError('No network nodes found in the data.');
- } else {
- simulate(g, canvasSize.w / 2, canvasSize.h / 2);
- setGraph(g);
- }
+ setProgress(`Building graph for ${inv.stats.total_hosts} hosts\u2026`);
+ const g = buildGraphFromInventory(inv.hosts, inv.connections, canvasSize.w, canvasSize.h);
+ simulate(g, canvasSize.w / 2, canvasSize.h / 2, 30);
+ simAlphaRef.current = 0.8;
+ setGraph(g);
} catch (e: any) { setError(e.message); }
setLoading(false); setProgress('');
}, [canvasSize]);
- // When hunt changes, load graph
useEffect(() => {
if (selectedHuntId) loadGraph(selectedHuntId);
}, [selectedHuntId, loadGraph]);
- // Reset viewport when graph changes
useEffect(() => {
vpRef.current = { x: 0, y: 0, scale: 1 };
setVpScale(1);
}, [graph]);
- // Filtered graph — only visible node types + edges between them
- const filteredGraph = useMemo(() => {
- if (!graph) return null;
- const nodes = graph.nodes.filter(n => visibleTypes.has(n.meta.type));
- const nodeIds = new Set(nodes.map(n => n.id));
- const edges = graph.edges.filter(e => nodeIds.has(e.source) && nodeIds.has(e.target));
- return { nodes, edges };
- }, [graph, visibleTypes]);
+ useEffect(() => { graphRef.current = graph; }, [graph]);
- // Toggle a node type filter
- const toggleType = useCallback((t: NodeType) => {
- setVisibleTypes(prev => {
- const next = new Set(prev);
- if (next.has(t)) {
- // Don't allow all to be hidden
- if (next.size > 1) next.delete(t);
- } else {
- next.add(t);
+ // Animation loop
+ const startAnimLoop = useCallback(() => {
+ if (isAnimatingRef.current) return;
+ isAnimatingRef.current = true;
+ const tick = (ts: number) => {
+ animTimeRef.current = ts;
+ const canvas = canvasRef.current;
+ const g = graphRef.current;
+ if (!canvas || !g) { isAnimatingRef.current = false; return; }
+ const dpr = window.devicePixelRatio || 1;
+ const ctx = canvas.getContext('2d');
+ if (!ctx) { isAnimatingRef.current = false; return; }
+
+ if (simAlphaRef.current > 0.01) {
+ simulationStep(g, canvasSize.w / 2, canvasSize.h / 2, simAlphaRef.current);
+ simAlphaRef.current *= 0.97;
+ if (simAlphaRef.current < 0.01) simAlphaRef.current = 0;
}
- return next;
- });
- }, []);
+ drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? null, searchRef.current, vpRef.current, ts, dpr);
+
+ const needsAnim = simAlphaRef.current > 0.01
+ || hoveredRef.current !== null
+ || selectedNodeRef.current !== null
+ || dragNode.current !== null;
+ if (needsAnim) {
+ animFrameRef.current = requestAnimationFrame(tick);
+ } else {
+ isAnimatingRef.current = false;
+ }
+ };
+ animFrameRef.current = requestAnimationFrame(tick);
+ }, [canvasSize]);
+
+ useEffect(() => {
+ if (graph) startAnimLoop();
+ return () => { cancelAnimationFrame(animFrameRef.current); isAnimatingRef.current = false; };
+ }, [graph, startAnimLoop]);
+
+ useEffect(() => { startAnimLoop(); }, [hovered, selectedNode, search, startAnimLoop]);
- // Redraw helper — uses filteredGraph
const redraw = useCallback(() => {
- if (!filteredGraph || !canvasRef.current) return;
+ if (!graph || !canvasRef.current) return;
const ctx = canvasRef.current.getContext('2d');
- if (ctx) drawGraph(ctx, filteredGraph, hovered, selectedNode?.id ?? null, search, vpRef.current);
- }, [filteredGraph, hovered, selectedNode, search]);
+ const dpr = window.devicePixelRatio || 1;
+ if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? null, search, vpRef.current, animTimeRef.current, dpr);
+ }, [graph, hovered, selectedNode, search]);
- // Redraw on every render-affecting state change
- useEffect(() => { redraw(); }, [redraw]);
+ useEffect(() => { if (!isAnimatingRef.current) redraw(); }, [redraw]);
- // ── Mouse wheel → zoom ─────────────────────────────────────────────
+ // Mouse wheel -> zoom
useEffect(() => {
const canvas = canvasRef.current;
if (!canvas) return;
@@ -479,210 +580,329 @@ export default function NetworkMap() {
e.preventDefault();
const vp = vpRef.current;
const rect = canvas.getBoundingClientRect();
- const cssToCanvasX = canvas.width / rect.width;
- const cssToCanvasY = canvas.height / rect.height;
- // Mouse position in canvas pixel coords
- const mx = (e.clientX - rect.left) * cssToCanvasX;
- const my = (e.clientY - rect.top) * cssToCanvasY;
-
- const zoomFactor = e.deltaY < 0 ? 1.12 : 1 / 1.12;
+ const mx = e.clientX - rect.left, my = e.clientY - rect.top;
+ const zoomFactor = e.deltaY < 0 ? 1.15 : 1 / 1.15;
const newScale = Math.min(MAX_ZOOM, Math.max(MIN_ZOOM, vp.scale * zoomFactor));
- // Zoom toward cursor: adjust offset so world-point under cursor stays fixed
vp.x = mx - (mx - vp.x) * (newScale / vp.scale);
vp.y = my - (my - vp.y) * (newScale / vp.scale);
vp.scale = newScale;
setVpScale(newScale);
- // Immediate redraw (bypass React state for smoothness)
- const ctx = canvas.getContext('2d');
- if (ctx && filteredGraph) drawGraph(ctx, filteredGraph, hovered, selectedNode?.id ?? null, search, vp);
+ startAnimLoop();
};
canvas.addEventListener('wheel', onWheel, { passive: false });
return () => canvas.removeEventListener('wheel', onWheel);
- }, [filteredGraph, hovered, selectedNode, search]);
+ }, [graph, startAnimLoop]);
- // ── Mouse drag → pan ───────────────────────────────────────────────
+ // Mouse handlers
const onMouseDown = useCallback((e: React.MouseEvent) => {
- if (!filteredGraph || !canvasRef.current) return;
- const node = hitTest(filteredGraph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
- if (!node) {
- isPanning.current = true;
- panStart.current = { x: e.clientX, y: e.clientY };
- }
- }, [filteredGraph]);
+ if (!graph || !canvasRef.current) return;
+ const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
+ if (node) { dragNode.current = node; node.pinned = true; startAnimLoop(); }
+ else { isPanning.current = true; panStart.current = { x: e.clientX, y: e.clientY }; }
+ }, [graph, startAnimLoop]);
const onMouseMove = useCallback((e: React.MouseEvent) => {
- if (!filteredGraph || !canvasRef.current) return;
-
+ if (!graph || !canvasRef.current) return;
+ if (dragNode.current) {
+ const { wx, wy } = screenToWorld(canvasRef.current, e.clientX, e.clientY, vpRef.current);
+ dragNode.current.x = wx; dragNode.current.y = wy;
+ dragNode.current.vx = 0; dragNode.current.vy = 0;
+ if (simAlphaRef.current < 0.15) simAlphaRef.current = 0.15;
+ startAnimLoop(); return;
+ }
if (isPanning.current) {
const vp = vpRef.current;
- const rect = canvasRef.current.getBoundingClientRect();
- const cssToCanvasX = canvasRef.current.width / rect.width;
- const cssToCanvasY = canvasRef.current.height / rect.height;
- vp.x += (e.clientX - panStart.current.x) * cssToCanvasX;
- vp.y += (e.clientY - panStart.current.y) * cssToCanvasY;
+ vp.x += e.clientX - panStart.current.x;
+ vp.y += e.clientY - panStart.current.y;
panStart.current = { x: e.clientX, y: e.clientY };
- redraw();
- return;
+ redraw(); return;
}
-
- const node = hitTest(filteredGraph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
+ const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
setHovered(node?.id ?? null);
- }, [filteredGraph, redraw]);
+ }, [graph, redraw, startAnimLoop]);
- const onMouseUp = useCallback(() => {
- isPanning.current = false;
- }, []);
+ const onMouseUp = useCallback(() => { dragNode.current = null; isPanning.current = false; }, []);
- // ── Mouse click → select node + show popover ─────────────────────
const onClick = useCallback((e: React.MouseEvent) => {
- if (!filteredGraph || !canvasRef.current) return;
- const node = hitTest(filteredGraph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
- if (node) {
- setSelectedNode(node);
- setPopoverAnchor({ top: e.clientY, left: e.clientX });
- } else {
- setSelectedNode(null);
- setPopoverAnchor(null);
+ if (!graph || !canvasRef.current) return;
+ const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
+ if (node) { setSelectedNode(node); setPopoverAnchor({ top: e.clientY, left: e.clientX }); }
+ else { setSelectedNode(null); setPopoverAnchor(null); }
+ }, [graph]);
+
+ const onDoubleClick = useCallback((e: React.MouseEvent) => {
+ if (!graph || !canvasRef.current) return;
+ const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
+ if (node && node.pinned) {
+ node.pinned = false; simAlphaRef.current = Math.max(simAlphaRef.current, 0.3); startAnimLoop();
}
- }, [filteredGraph]);
+ }, [graph, startAnimLoop]);
const closePopover = () => { setSelectedNode(null); setPopoverAnchor(null); };
- // ── Zoom controls ──────────────────────────────────────────────────
const zoomBy = useCallback((factor: number) => {
const vp = vpRef.current;
const cw = canvasSize.w, ch = canvasSize.h;
const newScale = Math.min(MAX_ZOOM, Math.max(MIN_ZOOM, vp.scale * factor));
- // Zoom toward canvas center
vp.x = cw / 2 - (cw / 2 - vp.x) * (newScale / vp.scale);
vp.y = ch / 2 - (ch / 2 - vp.y) * (newScale / vp.scale);
- vp.scale = newScale;
- setVpScale(newScale);
- redraw();
+ vp.scale = newScale; setVpScale(newScale); redraw();
}, [canvasSize, redraw]);
const resetView = useCallback(() => {
- vpRef.current = { x: 0, y: 0, scale: 1 };
- setVpScale(1);
- redraw();
+ vpRef.current = { x: 0, y: 0, scale: 1 }; setVpScale(1); redraw();
}, [redraw]);
- // Count connections for selected node
- const connectionCount = selectedNode && filteredGraph
- ? filteredGraph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
+ const connectionCount = selectedNode && graph
+ ? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
: 0;
- // ── Render ─────────────────────────────────────────────────────────
+ const connectedNodes = useMemo(() => {
+ if (!selectedNode || !graph) return [];
+ const neighbors: { id: string; type: NodeType; weight: number }[] = [];
+ for (const e of graph.edges) {
+ if (e.source === selectedNode.id) {
+ const n = graph.nodes.find(x => x.id === e.target);
+ if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
+ } else if (e.target === selectedNode.id) {
+ const n = graph.nodes.find(x => x.id === e.source);
+ if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
+ }
+ }
+ return neighbors.sort((a, b) => b.weight - a.weight).slice(0, 12);
+ }, [selectedNode, graph]);
+
+ const hostCount = graph ? graph.nodes.filter(n => n.meta.type === 'host').length : 0;
+ const extCount = graph ? graph.nodes.filter(n => n.meta.type === 'external_ip').length : 0;
+
+ const getCursor = () => {
+ if (dragNode.current) return 'grabbing';
+ if (isPanning.current) return 'grabbing';
+ if (hovered) return 'pointer';
+ return 'grab';
+ };
+ // == Render ==============================================================
return (
- {/* Header row */}
-
- Network Map
-
-
- Hunt
-
-
- setSearch(e.target.value)} sx={{ width: 200 }} />
- }
- onClick={() => loadGraph(selectedHuntId)}
- disabled={loading || !selectedHuntId} size="small">
- Refresh
-
+ {/* Glassmorphism toolbar */}
+
+
+
+
+ Network Map
+
-
+
+
+
+
+ Hunt
+
+
+
+ setSearch(e.target.value)}
+ sx={{ width: 200, '& .MuiInputBase-input': { py: 0.8 } }}
+ slotProps={{
+ input: {
+ startAdornment: ,
+ },
+ }}
+ />
+
+
+
+ loadGraph(selectedHuntId)}
+ disabled={loading || !selectedHuntId}
+ size="small"
+ sx={{ bgcolor: 'rgba(96,165,250,0.1)', '&:hover': { bgcolor: 'rgba(96,165,250,0.2)' } }}
+ >
+
+
+
+
+
+
+ {/* Stats summary cards */}
+ {stats && !loading && (
+
+ {[
+ { label: 'Hosts', value: stats.total_hosts, color: TYPE_COLORS.host },
+ { label: 'With IPs', value: stats.hosts_with_ips, color: '#34d399' },
+ { label: 'With Users', value: stats.hosts_with_users, color: '#a78bfa' },
+ { label: 'Datasets Scanned', value: stats.total_datasets_scanned, color: '#fbbf24' },
+ { label: 'Rows Scanned', value: stats.total_rows_scanned.toLocaleString(), color: '#f87171' },
+ ].map(s => (
+
+
+ {s.label}
+
+
+ {s.value}
+
+
+ ))}
+
+ )}
{/* Loading indicator */}
- {loading && (
-
- {progress}
-
+
+
+
+
+ {progress}
+
+
+
- )}
+
{error && {error}}
- {/* Legend — clickable type filters */}
- {graph && filteredGraph && (
-
-
-
-
-
-
- {([['ip', 'IP'], ['hostname', 'Host'], ['domain', 'Domain'], ['url', 'URL']] as [NodeType, string][]).map(([type, label]) => {
- const active = visibleTypes.has(type);
- const count = graph.nodes.filter(n => n.meta.type === type).length;
- return (
- toggleType(type)}
- sx={{
- bgcolor: active ? TYPE_COLORS[type] : 'transparent',
- color: active ? '#fff' : TYPE_COLORS[type],
- border: `2px solid ${TYPE_COLORS[type]}`,
- fontWeight: 600,
- cursor: 'pointer',
- opacity: active ? 1 : 0.5,
- transition: 'all 0.15s ease',
- '&:hover': { opacity: 1 },
- }}
- />
- );
- })}
-
- )}
-
- {/* Canvas */}
- {filteredGraph && (
-
+ {/* Canvas area */}
+ {graph && (
+
)}
-
{/* Node detail popover */}
{selectedNode && (
-
-
- {selectedNode.label}
-
-
-
-
-
-
- {/* Hostnames */}
- Hostname
-
- {selectedNode.meta.hostnames.length > 0
- ? selectedNode.meta.hostnames.join(', ')
- : Unknown}
-
-
- {/* IPs */}
- IP Address
-
- {selectedNode.meta.ips.length > 0
- ? selectedNode.meta.ips.join(', ')
- : (selectedNode.meta.type === 'ip' ? selectedNode.label : Unknown)}
-
-
- {/* OS */}
- Operating System
-
- {selectedNode.meta.os.length > 0
- ? selectedNode.meta.os.join(', ')
- : Unknown}
-
-
-
-
- {/* Stats */}
-
-
-
-
-
- {/* Datasets */}
- {selectedNode.meta.datasets.length > 0 && (
-
- Seen in datasets
-
- {selectedNode.meta.datasets.map(d => (
-
- ))}
+
+
+
+
+
+
+ {selectedNode.meta.hostname || selectedNode.label}
+
+
-
- )}
+
+
+
+
+
+
+ {selectedNode.meta.fqdn && (
+
+
+ FQDN
+
+
+ {selectedNode.meta.fqdn}
+
+
+ )}
+
+
+
+ IP Address
+
+
+ {selectedNode.meta.ips.length > 0 ? selectedNode.meta.ips.join(', ') : No IP detected}
+
+
+
+
+
+ Operating System
+
+
+ {selectedNode.meta.os || Unknown}
+
+
+
+
+
+ Logged-In Users
+
+ {selectedNode.meta.users.length > 0 ? (
+
+ {selectedNode.meta.users.map(u => (
+
+ ))}
+
+ ) : (
+
+ No user data
+
+ )}
+
+
+ {selectedNode.meta.client_id && (
+
+
+ Client ID
+
+
+ {selectedNode.meta.client_id}
+
+
+ )}
+
+
+
+
+
+
+
+
+
+
+ {connectedNodes.length > 0 && (
+
+
+ Connected To
+
+
+ {connectedNodes.map(cn => (
+ 30 ? cn.id.slice(0, 28) + '\u2026' : cn.id} size="small"
+ sx={{
+ fontSize: 10, height: 22, fontFamily: 'monospace',
+ bgcolor: TYPE_COLORS[cn.type] + '15', color: TYPE_COLORS[cn.type],
+ border: `1px solid ${TYPE_COLORS[cn.type]}33`, cursor: 'pointer',
+ '&:hover': { bgcolor: TYPE_COLORS[cn.type] + '30' },
+ }}
+ onClick={() => { setSearch(cn.id); closePopover(); }}
+ />
+ ))}
+
+
+ )}
+
+ {selectedNode.meta.datasets.length > 0 && (
+
+
+ Seen In Datasets
+
+
+ {selectedNode.meta.datasets.map(d => (
+
+ ))}
+
+
+ )}
+
)}
{/* Empty states */}
{!selectedHuntId && !loading && (
-
-
+
+
+
Select a hunt to visualize its network
-
- Choose a hunt from the dropdown above. The map will display IP addresses,
- hostnames, and domains found across the hunt's datasets, with connections
- showing co-occurrence in the same log rows.
+
+ Choose a hunt from the dropdown above. The map builds a clean,
+ deduplicated host inventory showing each endpoint with its hostname,
+ IP address, OS, and logged-in users.
)}
{selectedHuntId && !graph && !loading && !error && (
-
+
- No network data to display. Upload datasets with IP/hostname columns to this hunt.
+ No host data to display. Upload datasets with host-identifying columns (ClientId, Fqdn, Hostname).
)}
);
-}
+}
\ No newline at end of file
diff --git a/update.md b/update.md
new file mode 100644
index 0000000..c278fe2
--- /dev/null
+++ b/update.md
@@ -0,0 +1,41 @@
+# ThreatHunt Update Log
+
+## 2026-02-20: Host-Centric Network Map & Analysis Platform
+
+### Network Map Overhaul
+- **Problem**: Network Map showed 409 misclassified "domain" nodes (mostly process names like svchost.exe) and 0 hosts. No deduplication — the same host was counted once per dataset.
+- **Root Cause**: IOC column detection misclassified `Fqdn` as "domain" instead of "hostname"; `Name` column (process names) wrongly tagged as "domain" IOC; `ClientId` was in `normalized_columns` as "hostname" but not in `ioc_columns`.
+- **Solution**: Created a new host-centric inventory system that scans all datasets, groups by `Fqdn`/`ClientId`, and extracts IPs, users, OS, and network connections.
+
+#### New Backend Files
+- `backend/app/services/host_inventory.py` — Deduplicated host inventory builder. Scans all datasets in a hunt, identifies unique hosts via regex-based column detection (`ClientId`, `Fqdn`, `User`/`Username`, `Laddr.IP`/`Raddr.IP`), groups rows, extracts metadata. Filters system accounts (DWM-*, UMFD-*, LOCAL SERVICE, NETWORK SERVICE). Infers OS from hostname patterns (W10-* → Windows 10). Builds network connection graph from netstat remote IPs.
+- `backend/app/api/routes/network.py` — `GET /api/network/host-inventory?hunt_id=X` endpoint returning `{hosts, connections, stats}`.
+- `backend/app/services/ioc_extractor.py` — IOC extraction service (IP, domain, hash, email, URL patterns).
+- `backend/app/services/anomaly_detector.py` — Statistical anomaly detection across datasets.
+- `backend/app/services/data_query.py` — Natural language to structured query translation.
+- `backend/app/services/load_balancer.py` — Round-robin load balancer for Ollama LLM nodes.
+- `backend/app/services/job_queue.py` — Async job queue for long-running analysis tasks.
+- `backend/app/api/routes/analysis.py` — 16 analysis endpoints (IOC extraction, anomaly detection, host profiling, triage, reports, job management).
+
+#### Modified Backend Files
+- `backend/app/main.py` — Added `network_router` and `analysis_router` includes.
+- `backend/app/db/models.py` — Added 4 AI/analysis ORM models (`ProcessingJob`, `AnalysisResult`, `HostProfile`, `IOCEntry`).
+- `backend/app/db/engine.py` — Connection pool tuning for SQLite async.
+
+#### Frontend Changes
+- `frontend/src/components/NetworkMap.tsx` — Complete rewrite: host-centric force-directed graph using Canvas 2D. Two node types (Host / External IP). Shows hostname, IP, OS in labels. Click popover shows FQDN, IPs, OS, logged-in users, datasets, connections. Search across hostname/IP/user/OS. Stats cards showing host counts.
+- `frontend/src/components/AnalysisDashboard.tsx` — New 6-tab analysis dashboard (IOC Extraction, Anomaly Detection, Host Profiling, Query, Triage, Reports).
+- `frontend/src/api/client.ts` — Added `network.hostInventory()` method + `InventoryHost`, `InventoryConnection`, `InventoryStats` types. Added analysis API namespace with 16 endpoint methods.
+- `frontend/src/App.tsx` — Added Analysis Dashboard route and navigation.
+
+### Results (Radio Hunt — 20 Velociraptor datasets, 394K rows)
+
+| Metric | Before | After |
+|--------|--------|-------|
+| Nodes shown | 409 misclassified "domains" | **163 unique hosts** |
+| Hosts identified | 0 | **163** |
+| With IP addresses | N/A | **48** (172.17.x.x LAN) |
+| With logged-in users | N/A | **43** (real names only) |
+| OS detected | None | **Windows 10** (inferred from hostnames) |
+| Deduplication | None (same host × 20 datasets) | **Full** (by FQDN/ClientId) |
+| System account filtering | None | **DWM-*, UMFD-*, LOCAL/NETWORK SERVICE removed** |
\ No newline at end of file