diff --git a/backend/alembic/versions/a1b2c3d4e5f6_add_processing_status_ai_tables.py b/backend/alembic/versions/a1b2c3d4e5f6_add_processing_status_ai_tables.py
new file mode 100644
index 0000000..700524d
--- /dev/null
+++ b/backend/alembic/versions/a1b2c3d4e5f6_add_processing_status_ai_tables.py
@@ -0,0 +1,112 @@
+"""add processing_status and AI analysis tables
+
+Revision ID: a1b2c3d4e5f6
+Revises: 98ab619418bc
+Create Date: 2026-02-19 18:00:00.000000
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+revision: str = "a1b2c3d4e5f6"
+down_revision: Union[str, Sequence[str], None] = "98ab619418bc"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # Add columns to datasets table
+    with op.batch_alter_table("datasets") as batch_op:
+        batch_op.add_column(sa.Column("processing_status", sa.String(20), nullable=False, server_default="ready"))
+        batch_op.add_column(sa.Column("artifact_type", sa.String(128), nullable=True))
+        batch_op.add_column(sa.Column("error_message", sa.Text(), nullable=True))
+        batch_op.add_column(sa.Column("file_path", sa.String(512), nullable=True))
+        batch_op.create_index("ix_datasets_status", ["processing_status"])
+
+    # Create triage_results table
+    op.create_table(
+        "triage_results",
+        sa.Column("id", sa.String(32), primary_key=True),
+        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True),
+        sa.Column("row_start", sa.Integer(), nullable=False),
+        sa.Column("row_end", sa.Integer(), nullable=False),
+        sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"),
+        sa.Column("verdict", sa.String(20), nullable=False, server_default="pending"),
+        sa.Column("findings", sa.JSON(), nullable=True),
+        sa.Column("suspicious_indicators", sa.JSON(), nullable=True),
+        sa.Column("mitre_techniques", sa.JSON(), nullable=True),
+        sa.Column("model_used", sa.String(128),
nullable=True), + sa.Column("node_used", sa.String(64), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + # Create host_profiles table + op.create_table( + "host_profiles", + sa.Column("id", sa.String(32), primary_key=True), + sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True), + sa.Column("hostname", sa.String(256), nullable=False), + sa.Column("fqdn", sa.String(512), nullable=True), + sa.Column("client_id", sa.String(64), nullable=True), + sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"), + sa.Column("risk_level", sa.String(20), nullable=False, server_default="unknown"), + sa.Column("artifact_summary", sa.JSON(), nullable=True), + sa.Column("timeline_summary", sa.Text(), nullable=True), + sa.Column("suspicious_findings", sa.JSON(), nullable=True), + sa.Column("mitre_techniques", sa.JSON(), nullable=True), + sa.Column("llm_analysis", sa.Text(), nullable=True), + sa.Column("model_used", sa.String(128), nullable=True), + sa.Column("node_used", sa.String(64), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + # Create hunt_reports table + op.create_table( + "hunt_reports", + sa.Column("id", sa.String(32), primary_key=True), + sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True), + sa.Column("status", sa.String(20), nullable=False, server_default="pending"), + sa.Column("exec_summary", sa.Text(), nullable=True), + sa.Column("full_report", sa.Text(), nullable=True), + sa.Column("findings", sa.JSON(), nullable=True), + sa.Column("recommendations", sa.JSON(), nullable=True), + sa.Column("mitre_mapping", sa.JSON(), nullable=True), + sa.Column("ioc_table", sa.JSON(), nullable=True), + sa.Column("host_risk_summary", sa.JSON(), 
nullable=True), + sa.Column("models_used", sa.JSON(), nullable=True), + sa.Column("generation_time_ms", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + # Create anomaly_results table + op.create_table( + "anomaly_results", + sa.Column("id", sa.String(32), primary_key=True), + sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True), + sa.Column("row_id", sa.String(32), sa.ForeignKey("dataset_rows.id", ondelete="CASCADE"), nullable=True), + sa.Column("anomaly_score", sa.Float(), nullable=False, server_default="0.0"), + sa.Column("distance_from_centroid", sa.Float(), nullable=True), + sa.Column("cluster_id", sa.Integer(), nullable=True), + sa.Column("is_outlier", sa.Boolean(), nullable=False, server_default="0"), + sa.Column("explanation", sa.Text(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + +def downgrade() -> None: + op.drop_table("anomaly_results") + op.drop_table("hunt_reports") + op.drop_table("host_profiles") + op.drop_table("triage_results") + + with op.batch_alter_table("datasets") as batch_op: + batch_op.drop_index("ix_datasets_status") + batch_op.drop_column("file_path") + batch_op.drop_column("error_message") + batch_op.drop_column("artifact_type") + batch_op.drop_column("processing_status") \ No newline at end of file diff --git a/backend/app/api/routes/analysis.py b/backend/app/api/routes/analysis.py new file mode 100644 index 0000000..ff6f7fc --- /dev/null +++ b/backend/app/api/routes/analysis.py @@ -0,0 +1,402 @@ +"""Analysis API routes - triage, host profiles, reports, IOC extraction, +host grouping, anomaly detection, data query (SSE), and job management.""" + +from __future__ import annotations + +import logging +from typing import Optional + +from fastapi import 
APIRouter, BackgroundTasks, Depends, HTTPException, Query +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import get_db +from app.db.models import HostProfile, HuntReport, TriageResult + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/analysis", tags=["analysis"]) + + +# --- Response models --- + +class TriageResultResponse(BaseModel): + id: str + dataset_id: str + row_start: int + row_end: int + risk_score: float + verdict: str + findings: list | None = None + suspicious_indicators: list | None = None + mitre_techniques: list | None = None + model_used: str | None = None + node_used: str | None = None + + class Config: + from_attributes = True + + +class HostProfileResponse(BaseModel): + id: str + hunt_id: str + hostname: str + fqdn: str | None = None + risk_score: float + risk_level: str + artifact_summary: dict | None = None + timeline_summary: str | None = None + suspicious_findings: list | None = None + mitre_techniques: list | None = None + llm_analysis: str | None = None + model_used: str | None = None + + class Config: + from_attributes = True + + +class HuntReportResponse(BaseModel): + id: str + hunt_id: str + status: str + exec_summary: str | None = None + full_report: str | None = None + findings: list | None = None + recommendations: list | None = None + mitre_mapping: dict | None = None + ioc_table: list | None = None + host_risk_summary: list | None = None + models_used: list | None = None + generation_time_ms: int | None = None + + class Config: + from_attributes = True + + +class QueryRequest(BaseModel): + question: str + mode: str = "quick" # quick or deep + + +# --- Triage endpoints --- + +@router.get("/triage/{dataset_id}", response_model=list[TriageResultResponse]) +async def get_triage_results( + dataset_id: str, + min_risk: float = Query(0.0, ge=0.0, le=10.0), + db: AsyncSession = 
Depends(get_db), +): + result = await db.execute( + select(TriageResult) + .where(TriageResult.dataset_id == dataset_id) + .where(TriageResult.risk_score >= min_risk) + .order_by(TriageResult.risk_score.desc()) + ) + return result.scalars().all() + + +@router.post("/triage/{dataset_id}") +async def trigger_triage( + dataset_id: str, + background_tasks: BackgroundTasks, +): + async def _run(): + from app.services.triage import triage_dataset + await triage_dataset(dataset_id) + + background_tasks.add_task(_run) + return {"status": "triage_started", "dataset_id": dataset_id} + + +# --- Host profile endpoints --- + +@router.get("/profiles/{hunt_id}", response_model=list[HostProfileResponse]) +async def get_host_profiles( + hunt_id: str, + min_risk: float = Query(0.0, ge=0.0, le=10.0), + db: AsyncSession = Depends(get_db), +): + result = await db.execute( + select(HostProfile) + .where(HostProfile.hunt_id == hunt_id) + .where(HostProfile.risk_score >= min_risk) + .order_by(HostProfile.risk_score.desc()) + ) + return result.scalars().all() + + +@router.post("/profiles/{hunt_id}") +async def trigger_all_profiles( + hunt_id: str, + background_tasks: BackgroundTasks, +): + async def _run(): + from app.services.host_profiler import profile_all_hosts + await profile_all_hosts(hunt_id) + + background_tasks.add_task(_run) + return {"status": "profiling_started", "hunt_id": hunt_id} + + +@router.post("/profiles/{hunt_id}/{hostname}") +async def trigger_single_profile( + hunt_id: str, + hostname: str, + background_tasks: BackgroundTasks, +): + async def _run(): + from app.services.host_profiler import profile_host + await profile_host(hunt_id, hostname) + + background_tasks.add_task(_run) + return {"status": "profiling_started", "hunt_id": hunt_id, "hostname": hostname} + + +# --- Report endpoints --- + +@router.get("/reports/{hunt_id}", response_model=list[HuntReportResponse]) +async def list_reports( + hunt_id: str, + db: AsyncSession = Depends(get_db), +): + result = await 
db.execute( + select(HuntReport) + .where(HuntReport.hunt_id == hunt_id) + .order_by(HuntReport.created_at.desc()) + ) + return result.scalars().all() + + +@router.get("/reports/{hunt_id}/{report_id}", response_model=HuntReportResponse) +async def get_report( + hunt_id: str, + report_id: str, + db: AsyncSession = Depends(get_db), +): + result = await db.execute( + select(HuntReport) + .where(HuntReport.id == report_id) + .where(HuntReport.hunt_id == hunt_id) + ) + report = result.scalar_one_or_none() + if not report: + raise HTTPException(status_code=404, detail="Report not found") + return report + + +@router.post("/reports/{hunt_id}/generate") +async def trigger_report( + hunt_id: str, + background_tasks: BackgroundTasks, +): + async def _run(): + from app.services.report_generator import generate_report + await generate_report(hunt_id) + + background_tasks.add_task(_run) + return {"status": "report_generation_started", "hunt_id": hunt_id} + + +# --- IOC extraction endpoints --- + +@router.get("/iocs/{dataset_id}") +async def extract_iocs( + dataset_id: str, + max_rows: int = Query(5000, ge=1, le=50000), + db: AsyncSession = Depends(get_db), +): + """Extract IOCs (IPs, domains, hashes, etc.) 
from dataset rows.""" + from app.services.ioc_extractor import extract_iocs_from_dataset + iocs = await extract_iocs_from_dataset(dataset_id, db, max_rows=max_rows) + total = sum(len(v) for v in iocs.values()) + return {"dataset_id": dataset_id, "iocs": iocs, "total": total} + + +# --- Host grouping endpoints --- + +@router.get("/hosts/{hunt_id}") +async def get_host_groups( + hunt_id: str, + db: AsyncSession = Depends(get_db), +): + """Group data by hostname across all datasets in a hunt.""" + from app.services.ioc_extractor import extract_host_groups + groups = await extract_host_groups(hunt_id, db) + return {"hunt_id": hunt_id, "hosts": groups} + + +# --- Anomaly detection endpoints --- + +@router.get("/anomalies/{dataset_id}") +async def get_anomalies( + dataset_id: str, + outliers_only: bool = Query(False), + db: AsyncSession = Depends(get_db), +): + """Get anomaly detection results for a dataset.""" + from app.db.models import AnomalyResult + stmt = select(AnomalyResult).where(AnomalyResult.dataset_id == dataset_id) + if outliers_only: + stmt = stmt.where(AnomalyResult.is_outlier == True) + stmt = stmt.order_by(AnomalyResult.anomaly_score.desc()) + result = await db.execute(stmt) + rows = result.scalars().all() + return [ + { + "id": r.id, + "dataset_id": r.dataset_id, + "row_id": r.row_id, + "anomaly_score": r.anomaly_score, + "distance_from_centroid": r.distance_from_centroid, + "cluster_id": r.cluster_id, + "is_outlier": r.is_outlier, + "explanation": r.explanation, + } + for r in rows + ] + + +@router.post("/anomalies/{dataset_id}") +async def trigger_anomaly_detection( + dataset_id: str, + k: int = Query(3, ge=2, le=20), + threshold: float = Query(0.35, ge=0.1, le=0.9), + background_tasks: BackgroundTasks = None, +): + """Trigger embedding-based anomaly detection on a dataset.""" + async def _run(): + from app.services.anomaly_detector import detect_anomalies + await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold) + + if background_tasks: 
+ background_tasks.add_task(_run) + return {"status": "anomaly_detection_started", "dataset_id": dataset_id} + else: + from app.services.anomaly_detector import detect_anomalies + results = await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold) + return {"status": "complete", "dataset_id": dataset_id, "count": len(results)} + + +# --- Natural language data query (SSE streaming) --- + +@router.post("/query/{dataset_id}") +async def query_dataset_endpoint( + dataset_id: str, + body: QueryRequest, +): + """Ask a natural language question about a dataset. + + Returns an SSE stream with token-by-token LLM response. + Event types: status, metadata, token, error, done + """ + from app.services.data_query import query_dataset_stream + + return StreamingResponse( + query_dataset_stream(dataset_id, body.question, body.mode), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + +@router.post("/query/{dataset_id}/sync") +async def query_dataset_sync( + dataset_id: str, + body: QueryRequest, +): + """Non-streaming version of data query.""" + from app.services.data_query import query_dataset + + try: + answer = await query_dataset(dataset_id, body.question, body.mode) + return {"dataset_id": dataset_id, "question": body.question, "answer": answer, "mode": body.mode} + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + logger.error(f"Query failed: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + +# --- Job queue endpoints --- + +@router.get("/jobs") +async def list_jobs( + status: str | None = Query(None), + job_type: str | None = Query(None), + limit: int = Query(50, ge=1, le=200), +): + """List all tracked jobs.""" + from app.services.job_queue import job_queue, JobStatus, JobType + + s = JobStatus(status) if status else None + t = JobType(job_type) if job_type else None + jobs = 
job_queue.list_jobs(status=s, job_type=t, limit=limit)
+    stats = job_queue.get_stats()
+    return {"jobs": jobs, "stats": stats}
+
+
+@router.get("/jobs/{job_id}")
+async def get_job(job_id: str):
+    """Get status of a specific job."""
+    from app.services.job_queue import job_queue
+
+    job = job_queue.get_job(job_id)
+    if not job:
+        raise HTTPException(status_code=404, detail="Job not found")
+    return job.to_dict()
+
+
+@router.delete("/jobs/{job_id}")
+async def cancel_job(job_id: str):
+    """Cancel a running or queued job."""
+    from app.services.job_queue import job_queue
+
+    if job_queue.cancel_job(job_id):
+        return {"status": "cancelled", "job_id": job_id}
+    raise HTTPException(status_code=400, detail="Job cannot be cancelled (already complete or not found)")
+
+
+@router.post("/jobs/submit/{job_type}")
+async def submit_job(
+    job_type: str,
+    params: dict | None = None,
+):
+    """Submit a new job to the queue.
+
+    Job types: triage, host_profile, report, anomaly, query
+    Params vary by type (e.g., dataset_id, hunt_id, question, mode).
+    """
+    from app.services.job_queue import job_queue, JobType
+
+    try:
+        jt = JobType(job_type)
+    except ValueError:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Invalid job_type: {job_type}. Valid: {[t.value for t in JobType]}",
+        )
+
+    job = job_queue.submit(jt, **(params or {}))
+    return {"job_id": job.id, "status": job.status.value, "job_type": job_type}
+
+
+# --- Load balancer status ---
+
+@router.get("/lb/status")
+async def lb_status():
+    """Get load balancer status for both nodes."""
+    from app.services.load_balancer import lb
+    return lb.get_status()
+
+
+@router.post("/lb/check")
+async def lb_health_check():
+    """Force a health check of both nodes."""
+    from app.services.load_balancer import lb
+    await lb.check_health()
+    return lb.get_status()
\ No newline at end of file
diff --git a/backend/app/api/routes/network.py b/backend/app/api/routes/network.py
new file mode 100644
index 0000000..65d47ad
--- /dev/null
+++ b/backend/app/api/routes/network.py
@@ -0,0 +1,26 @@
+"""Network topology API - host inventory endpoint."""
+
+import logging
+
+from fastapi import APIRouter, Depends, HTTPException, Query
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.db import get_db
+from app.services.host_inventory import build_host_inventory
+
+logger = logging.getLogger(__name__)
+router = APIRouter(prefix="/api/network", tags=["network"])
+
+
+@router.get("/host-inventory")
+async def get_host_inventory(
+    hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
+    db: AsyncSession = Depends(get_db),
+):
+    """Build a deduplicated host inventory from all datasets in a hunt.
+
+    Returns unique hosts with hostname, IPs, OS, logged-in users, and
+    network connections derived from netstat/connection data.
+    """
+    result = await build_host_inventory(hunt_id, db)
+    return result
\ No newline at end of file
diff --git a/backend/app/db/engine.py b/backend/app/db/engine.py
index afd2968..dbfc86c 100644
--- a/backend/app/db/engine.py
+++ b/backend/app/db/engine.py
@@ -3,6 +3,7 @@
 Uses async SQLAlchemy with aiosqlite for local dev and
 asyncpg for production PostgreSQL.
""" +from sqlalchemy import event from sqlalchemy.ext.asyncio import ( AsyncSession, async_sessionmaker, @@ -12,12 +13,32 @@ from sqlalchemy.orm import DeclarativeBase from app.config import settings -engine = create_async_engine( - settings.DATABASE_URL, +_is_sqlite = settings.DATABASE_URL.startswith("sqlite") + +_engine_kwargs: dict = dict( echo=settings.DEBUG, future=True, ) +if _is_sqlite: + _engine_kwargs["connect_args"] = {"timeout": 30} + _engine_kwargs["pool_size"] = 1 + _engine_kwargs["max_overflow"] = 0 + +engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs) + + +@event.listens_for(engine.sync_engine, "connect") +def _set_sqlite_pragmas(dbapi_conn, connection_record): + """Enable WAL mode and tune busy-timeout for SQLite connections.""" + if _is_sqlite: + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA journal_mode=WAL") + cursor.execute("PRAGMA busy_timeout=5000") + cursor.execute("PRAGMA synchronous=NORMAL") + cursor.close() + + async_session_factory = async_sessionmaker( engine, class_=AsyncSession, @@ -51,4 +72,4 @@ async def init_db() -> None: async def dispose_db() -> None: """Dispose of the engine connection pool.""" - await engine.dispose() + await engine.dispose() \ No newline at end of file diff --git a/backend/app/db/models.py b/backend/app/db/models.py index f786b94..f2ab0ff 100644 --- a/backend/app/db/models.py +++ b/backend/app/db/models.py @@ -1,7 +1,7 @@ """SQLAlchemy ORM models for ThreatHunt. All persistent entities: datasets, hunts, conversations, annotations, -hypotheses, enrichment results, and users. +hypotheses, enrichment results, users, and AI analysis tables. 
""" import uuid @@ -32,8 +32,7 @@ def _new_id() -> str: return uuid.uuid4().hex -# ── Users ────────────────────────────────────────────────────────────── - +# -- Users --- class User(Base): __tablename__ = "users" @@ -42,17 +41,15 @@ class User(Base): username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True) email: Mapped[str] = mapped_column(String(256), unique=True, nullable=False) hashed_password: Mapped[str] = mapped_column(String(256), nullable=False) - role: Mapped[str] = mapped_column(String(16), default="analyst") # analyst | admin | viewer + role: Mapped[str] = mapped_column(String(16), default="analyst") is_active: Mapped[bool] = mapped_column(Boolean, default=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) - # relationships hunts: Mapped[list["Hunt"]] = relationship(back_populates="owner", lazy="selectin") annotations: Mapped[list["Annotation"]] = relationship(back_populates="author", lazy="selectin") -# ── Hunts ────────────────────────────────────────────────────────────── - +# -- Hunts --- class Hunt(Base): __tablename__ = "hunts" @@ -60,7 +57,7 @@ class Hunt(Base): id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id) name: Mapped[str] = mapped_column(String(256), nullable=False) description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) - status: Mapped[str] = mapped_column(String(32), default="active") # active | closed | archived + status: Mapped[str] = mapped_column(String(32), default="active") owner_id: Mapped[Optional[str]] = mapped_column( String(32), ForeignKey("users.id"), nullable=True ) @@ -69,15 +66,15 @@ class Hunt(Base): DateTime(timezone=True), default=_utcnow, onupdate=_utcnow ) - # relationships owner: Mapped[Optional["User"]] = relationship(back_populates="hunts", lazy="selectin") datasets: Mapped[list["Dataset"]] = relationship(back_populates="hunt", lazy="selectin") conversations: Mapped[list["Conversation"]] = 
relationship(back_populates="hunt", lazy="selectin") hypotheses: Mapped[list["Hypothesis"]] = relationship(back_populates="hunt", lazy="selectin") + host_profiles: Mapped[list["HostProfile"]] = relationship(back_populates="hunt", lazy="noload") + reports: Mapped[list["HuntReport"]] = relationship(back_populates="hunt", lazy="noload") -# ── Datasets ─────────────────────────────────────────────────────────── - +# -- Datasets --- class Dataset(Base): __tablename__ = "datasets" @@ -85,36 +82,44 @@ class Dataset(Base): id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id) name: Mapped[str] = mapped_column(String(256), nullable=False, index=True) filename: Mapped[str] = mapped_column(String(512), nullable=False) - source_tool: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # velociraptor, etc. + source_tool: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) row_count: Mapped[int] = mapped_column(Integer, default=0) column_schema: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) normalized_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) - ioc_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) # auto-detected IOC columns + ioc_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) file_size_bytes: Mapped[int] = mapped_column(Integer, default=0) encoding: Mapped[Optional[str]] = mapped_column(String(32), nullable=True) delimiter: Mapped[Optional[str]] = mapped_column(String(4), nullable=True) time_range_start: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) time_range_end: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + # New Phase 1-2 columns + processing_status: Mapped[str] = mapped_column(String(20), default="ready") + artifact_type: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) + error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + 
file_path: Mapped[Optional[str]] = mapped_column(String(512), nullable=True) + hunt_id: Mapped[Optional[str]] = mapped_column( String(32), ForeignKey("hunts.id"), nullable=True ) uploaded_by: Mapped[Optional[str]] = mapped_column(String(32), nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) - # relationships hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="datasets", lazy="selectin") rows: Mapped[list["DatasetRow"]] = relationship( back_populates="dataset", lazy="noload", cascade="all, delete-orphan" ) + triage_results: Mapped[list["TriageResult"]] = relationship( + back_populates="dataset", lazy="noload", cascade="all, delete-orphan" + ) __table_args__ = ( Index("ix_datasets_hunt", "hunt_id"), + Index("ix_datasets_status", "processing_status"), ) class DatasetRow(Base): - """Individual row from a CSV dataset, stored as JSON blob.""" __tablename__ = "dataset_rows" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) @@ -125,7 +130,6 @@ class DatasetRow(Base): data: Mapped[dict] = mapped_column(JSON, nullable=False) normalized_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) - # relationships dataset: Mapped["Dataset"] = relationship(back_populates="rows") annotations: Mapped[list["Annotation"]] = relationship( back_populates="row", lazy="noload" @@ -137,8 +141,7 @@ class DatasetRow(Base): ) -# ── Conversations ───────────────────────────────────────────────────── - +# -- Conversations --- class Conversation(Base): __tablename__ = "conversations" @@ -156,7 +159,6 @@ class Conversation(Base): DateTime(timezone=True), default=_utcnow, onupdate=_utcnow ) - # relationships hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="conversations", lazy="selectin") messages: Mapped[list["Message"]] = relationship( back_populates="conversation", lazy="selectin", cascade="all, delete-orphan", @@ -171,16 +173,15 @@ class Message(Base): conversation_id: Mapped[str] = 
mapped_column( String(32), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False ) - role: Mapped[str] = mapped_column(String(16), nullable=False) # user | agent | system + role: Mapped[str] = mapped_column(String(16), nullable=False) content: Mapped[str] = mapped_column(Text, nullable=False) model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) - node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # wile | roadrunner | cluster + node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) token_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) latency_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) response_meta: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) - # relationships conversation: Mapped["Conversation"] = relationship(back_populates="messages") __table_args__ = ( @@ -188,8 +189,7 @@ class Message(Base): ) -# ── Annotations ─────────────────────────────────────────────────────── - +# -- Annotations --- class Annotation(Base): __tablename__ = "annotations" @@ -205,19 +205,14 @@ class Annotation(Base): String(32), ForeignKey("users.id"), nullable=True ) text: Mapped[str] = mapped_column(Text, nullable=False) - severity: Mapped[str] = mapped_column( - String(16), default="info" - ) # info | low | medium | high | critical - tag: Mapped[Optional[str]] = mapped_column( - String(32), nullable=True - ) # suspicious | benign | needs-review + severity: Mapped[str] = mapped_column(String(16), default="info") + tag: Mapped[Optional[str]] = mapped_column(String(32), nullable=True) highlight_color: Mapped[Optional[str]] = mapped_column(String(16), nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), default=_utcnow, onupdate=_utcnow ) - # 
relationships row: Mapped[Optional["DatasetRow"]] = relationship(back_populates="annotations") author: Mapped[Optional["User"]] = relationship(back_populates="annotations") @@ -227,8 +222,7 @@ class Annotation(Base): ) -# ── Hypotheses ──────────────────────────────────────────────────────── - +# -- Hypotheses --- class Hypothesis(Base): __tablename__ = "hypotheses" @@ -240,9 +234,7 @@ class Hypothesis(Base): title: Mapped[str] = mapped_column(String(256), nullable=False) description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) mitre_technique: Mapped[Optional[str]] = mapped_column(String(32), nullable=True) - status: Mapped[str] = mapped_column( - String(16), default="draft" - ) # draft | active | confirmed | rejected + status: Mapped[str] = mapped_column(String(16), default="draft") evidence_row_ids: Mapped[Optional[list]] = mapped_column(JSON, nullable=True) evidence_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) @@ -250,7 +242,6 @@ class Hypothesis(Base): DateTime(timezone=True), default=_utcnow, onupdate=_utcnow ) - # relationships hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="hypotheses", lazy="selectin") __table_args__ = ( @@ -258,21 +249,16 @@ class Hypothesis(Base): ) -# ── Enrichment Results ──────────────────────────────────────────────── - +# -- Enrichment Results --- class EnrichmentResult(Base): __tablename__ = "enrichment_results" id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id) ioc_value: Mapped[str] = mapped_column(String(512), nullable=False, index=True) - ioc_type: Mapped[str] = mapped_column( - String(32), nullable=False - ) # ip | hash_md5 | hash_sha1 | hash_sha256 | domain | url - source: Mapped[str] = mapped_column(String(32), nullable=False) # virustotal | abuseipdb | shodan | ai - verdict: Mapped[Optional[str]] = mapped_column( - String(16), nullable=True - ) # clean | 
suspicious | malicious | unknown + ioc_type: Mapped[str] = mapped_column(String(32), nullable=False) + source: Mapped[str] = mapped_column(String(32), nullable=False) + verdict: Mapped[Optional[str]] = mapped_column(String(16), nullable=True) confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True) raw_result: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True) @@ -287,28 +273,24 @@ class EnrichmentResult(Base): ) -# ── AUP Keyword Themes & Keywords ──────────────────────────────────── - +# -- AUP Keyword Themes & Keywords --- class KeywordTheme(Base): - """A named category of keywords for AUP scanning (e.g. gambling, gaming).""" __tablename__ = "keyword_themes" id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id) name: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True) - color: Mapped[str] = mapped_column(String(16), default="#9e9e9e") # hex chip color + color: Mapped[str] = mapped_column(String(16), default="#9e9e9e") enabled: Mapped[bool] = mapped_column(Boolean, default=True) - is_builtin: Mapped[bool] = mapped_column(Boolean, default=False) # seed-provided + is_builtin: Mapped[bool] = mapped_column(Boolean, default=False) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) - # relationships keywords: Mapped[list["Keyword"]] = relationship( back_populates="theme", lazy="selectin", cascade="all, delete-orphan" ) class Keyword(Base): - """Individual keyword / pattern belonging to a theme.""" __tablename__ = "keywords" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) @@ -319,10 +301,102 @@ class Keyword(Base): is_regex: Mapped[bool] = mapped_column(Boolean, default=False) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) - # relationships theme: Mapped["KeywordTheme"] = relationship(back_populates="keywords") 
    __table_args__ = (
        Index("ix_keywords_theme", "theme_id"),
        Index("ix_keywords_value", "value"),
    )


# -- AI Analysis Tables (Phase 2) ---

class TriageResult(Base):
    """LLM triage verdict for one contiguous chunk of dataset rows."""

    __tablename__ = "triage_results"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    dataset_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True
    )
    # Row range of the dataset covered by this triage pass.
    row_start: Mapped[int] = mapped_column(Integer, nullable=False)
    row_end: Mapped[int] = mapped_column(Integer, nullable=False)
    risk_score: Mapped[float] = mapped_column(Float, default=0.0)
    verdict: Mapped[str] = mapped_column(String(20), default="pending")
    # Model-produced JSON payloads (lists); exact shape is model-defined.
    findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    suspicious_indicators: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    mitre_techniques: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    # Provenance: which model on which cluster node produced the verdict.
    model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    dataset: Mapped["Dataset"] = relationship(back_populates="triage_results")


class HostProfile(Base):
    """Aggregated per-host risk profile built from a hunt's datasets."""

    __tablename__ = "host_profiles"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    hunt_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True
    )
    hostname: Mapped[str] = mapped_column(String(256), nullable=False)
    fqdn: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
    client_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    risk_score: Mapped[float] = mapped_column(Float, default=0.0)
    risk_level: Mapped[str] = mapped_column(String(20), default="unknown")
    artifact_summary: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    timeline_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    suspicious_findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    mitre_techniques: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    # Free-text narrative produced by the heavy analysis model.
    llm_analysis: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    hunt: Mapped["Hunt"] = relationship(back_populates="host_profiles")


class HuntReport(Base):
    """Generated hunt report: summary, findings, IOC table, per-host risk."""

    __tablename__ = "hunt_reports"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    hunt_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True
    )
    status: Mapped[str] = mapped_column(String(20), default="pending")
    exec_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    full_report: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    recommendations: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    mitre_mapping: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    ioc_table: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    host_risk_summary: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    models_used: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    generation_time_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    hunt: Mapped["Hunt"] = relationship(back_populates="reports")


class AnomalyResult(Base):
    """Embedding-based anomaly score for a single dataset row."""

    __tablename__ = "anomaly_results"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    dataset_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True
    )
    # Nullable: an anomaly record may outlive / precede its row link.
    row_id: Mapped[Optional[int]] = mapped_column(
        Integer, ForeignKey("dataset_rows.id", ondelete="CASCADE"), nullable=True
    )
    anomaly_score: Mapped[float] = mapped_column(Float, default=0.0)
    distance_from_centroid: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    cluster_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    is_outlier: Mapped[bool] = mapped_column(Boolean, default=False)
    explanation: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
""" import logging +import os from contextlib import asynccontextmanager from fastapi import FastAPI @@ -21,6 +23,8 @@ from app.api.routes.correlation import router as correlation_router from app.api.routes.reports import router as reports_router from app.api.routes.auth import router as auth_router from app.api.routes.keywords import router as keywords_router +from app.api.routes.analysis import router as analysis_router +from app.api.routes.network import router as network_router logger = logging.getLogger(__name__) @@ -28,17 +32,45 @@ logger = logging.getLogger(__name__) @asynccontextmanager async def lifespan(app: FastAPI): """Startup / shutdown lifecycle.""" - logger.info("Starting ThreatHunt API …") + logger.info("Starting ThreatHunt API ...") await init_db() logger.info("Database initialised") + + # Ensure uploads directory exists + os.makedirs(settings.UPLOAD_DIR, exist_ok=True) + logger.info("Upload dir: %s", os.path.abspath(settings.UPLOAD_DIR)) + # Seed default AUP keyword themes from app.db import async_session_factory from app.services.keyword_defaults import seed_defaults async with async_session_factory() as seed_db: await seed_defaults(seed_db) logger.info("AUP keyword defaults checked") + + # Start job queue (Phase 10) + from app.services.job_queue import job_queue, register_all_handlers + register_all_handlers() + await job_queue.start() + logger.info("Job queue started (%d workers)", job_queue._max_workers) + + # Start load balancer health loop (Phase 10) + from app.services.load_balancer import lb + await lb.start_health_loop(interval=30.0) + logger.info("Load balancer health loop started") + yield - logger.info("Shutting down …") + + logger.info("Shutting down ...") + # Stop job queue + from app.services.job_queue import job_queue as jq + await jq.stop() + logger.info("Job queue stopped") + + # Stop load balancer + from app.services.load_balancer import lb as _lb + await _lb.stop_health_loop() + logger.info("Load balancer stopped") + from 
app.agents.providers_v2 import cleanup_client from app.services.enrichment import enrichment_engine await cleanup_client() @@ -46,15 +78,13 @@ async def lifespan(app: FastAPI): await dispose_db() -# Create FastAPI application app = FastAPI( title="ThreatHunt API", description="Analyst-assist threat hunting platform powered by Wile & Roadrunner LLM cluster", - version="0.3.0", + version=settings.APP_VERSION, lifespan=lifespan, ) -# Configure CORS app.add_middleware( CORSMiddleware, allow_origins=settings.cors_origins, @@ -74,11 +104,12 @@ app.include_router(enrichment_router) app.include_router(correlation_router) app.include_router(reports_router) app.include_router(keywords_router) +app.include_router(analysis_router) +app.include_router(network_router) @app.get("/", tags=["health"]) async def root(): - """API health check.""" return { "service": "ThreatHunt API", "version": settings.APP_VERSION, @@ -89,4 +120,4 @@ async def root(): "roadrunner": settings.roadrunner_url, "openwebui": settings.OPENWEBUI_URL, }, - } + } \ No newline at end of file diff --git a/backend/app/services/anomaly_detector.py b/backend/app/services/anomaly_detector.py new file mode 100644 index 0000000..69b7593 --- /dev/null +++ b/backend/app/services/anomaly_detector.py @@ -0,0 +1,199 @@ +"""Embedding-based anomaly detection using Roadrunner's bge-m3 model. + +Converts dataset rows to embeddings, clusters them, and flags outliers +that deviate significantly from the cluster centroids. Uses cosine +distance and simple k-means-like centroid computation. 
+""" + +import asyncio +import json +import logging +import math +from typing import Optional + +import httpx +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import settings +from app.db import async_session_factory +from app.db.models import AnomalyResult, Dataset, DatasetRow + +logger = logging.getLogger(__name__) + +EMBED_URL = f"{settings.roadrunner_url}/api/embed" +EMBED_MODEL = "bge-m3" +BATCH_SIZE = 32 # rows per embedding batch +MAX_ROWS = 2000 # cap for anomaly detection + +# --- math helpers (no numpy required) --- + +def _dot(a: list[float], b: list[float]) -> float: + return sum(x * y for x, y in zip(a, b)) + + +def _norm(v: list[float]) -> float: + return math.sqrt(sum(x * x for x in v)) + + +def _cosine_distance(a: list[float], b: list[float]) -> float: + na, nb = _norm(a), _norm(b) + if na == 0 or nb == 0: + return 1.0 + return 1.0 - _dot(a, b) / (na * nb) + + +def _mean_vector(vectors: list[list[float]]) -> list[float]: + if not vectors: + return [] + dim = len(vectors[0]) + n = len(vectors) + return [sum(v[i] for v in vectors) / n for i in range(dim)] + + +def _row_to_text(data: dict) -> str: + """Flatten a row dict to a single string for embedding.""" + parts = [] + for k, v in data.items(): + sv = str(v).strip() + if sv and sv.lower() not in ('none', 'null', ''): + parts.append(f"{k}={sv}") + return " | ".join(parts)[:2000] # cap length + + +async def _embed_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]: + """Get embeddings from Roadrunner's Ollama API.""" + resp = await client.post( + EMBED_URL, + json={"model": EMBED_MODEL, "input": texts}, + timeout=120.0, + ) + resp.raise_for_status() + data = resp.json() + # Ollama returns {"embeddings": [[...], ...]} + return data.get("embeddings", []) + + +def _simple_cluster( + embeddings: list[list[float]], + k: int = 3, + max_iter: int = 20, +) -> tuple[list[int], list[list[float]]]: + """Simple k-means clustering (no numpy 
dependency). + + Returns (assignments, centroids). + """ + n = len(embeddings) + if n <= k: + return list(range(n)), embeddings[:] + + # Init centroids: evenly spaced indices + step = max(n // k, 1) + centroids = [embeddings[i * step % n] for i in range(k)] + assignments = [0] * n + + for _ in range(max_iter): + # Assign to nearest centroid + new_assignments = [] + for emb in embeddings: + dists = [_cosine_distance(emb, c) for c in centroids] + new_assignments.append(dists.index(min(dists))) + + if new_assignments == assignments: + break + assignments = new_assignments + + # Recompute centroids + for ci in range(k): + members = [embeddings[j] for j in range(n) if assignments[j] == ci] + if members: + centroids[ci] = _mean_vector(members) + + return assignments, centroids + + +async def detect_anomalies( + dataset_id: str, + k: int = 3, + outlier_threshold: float = 0.35, +) -> list[dict]: + """Run embedding-based anomaly detection on a dataset. + + 1. Load rows 2. Embed via bge-m3 3. Cluster 4. Flag outliers. 
async def detect_anomalies(
    dataset_id: str,
    k: int = 3,
    outlier_threshold: float = 0.35,
) -> list[dict]:
    """Run embedding-based anomaly detection on a dataset.

    1. Load rows 2. Embed via bge-m3 3. Cluster 4. Flag outliers.

    Returns per-row anomaly dicts sorted by descending anomaly_score;
    results are also persisted as AnomalyResult rows.
    """
    async with async_session_factory() as db:
        # Load rows (capped at MAX_ROWS to bound embedding cost)
        result = await db.execute(
            select(DatasetRow.id, DatasetRow.row_index, DatasetRow.data)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .limit(MAX_ROWS)
        )
        rows = result.all()
        if not rows:
            logger.info("No rows for anomaly detection in dataset %s", dataset_id)
            return []

        # Keep ids/indices/texts in three parallel, index-aligned lists.
        row_ids = [r[0] for r in rows]
        row_indices = [r[1] for r in rows]
        texts = [_row_to_text(r[2]) for r in rows]

        logger.info("Anomaly detection: %d rows, embedding with %s", len(texts), EMBED_MODEL)

        # Embed in batches
        all_embeddings: list[list[float]] = []
        async with httpx.AsyncClient() as client:
            for i in range(0, len(texts), BATCH_SIZE):
                batch = texts[i : i + BATCH_SIZE]
                try:
                    embs = await _embed_batch(batch, client)
                    all_embeddings.extend(embs)
                except Exception as e:
                    logger.error("Embedding batch %d failed: %s", i, e)
                    # Fill with zeros so indices stay aligned
                    # NOTE(review): zero vectors score cosine distance 1.0 from any
                    # centroid, so rows from a failed batch will all be flagged as
                    # outliers — confirm this is the intended failure behavior.
                    all_embeddings.extend([[0.0] * 1024] * len(batch))

        if not all_embeddings or len(all_embeddings) != len(texts):
            logger.error("Embedding count mismatch")
            return []

        # Cluster (k may shrink for tiny datasets)
        actual_k = min(k, len(all_embeddings))
        assignments, centroids = _simple_cluster(all_embeddings, k=actual_k)

        # Compute distances from centroid; distance doubles as the anomaly score.
        anomalies: list[dict] = []
        for idx, (emb, cluster_id) in enumerate(zip(all_embeddings, assignments)):
            dist = _cosine_distance(emb, centroids[cluster_id])
            is_outlier = dist > outlier_threshold
            anomalies.append({
                "row_id": row_ids[idx],
                "row_index": row_indices[idx],
                "anomaly_score": round(dist, 4),
                "distance_from_centroid": round(dist, 4),
                "cluster_id": cluster_id,
                "is_outlier": is_outlier,
            })

        # Save to DB
        outlier_count = 0
        for a in anomalies:
            ar = AnomalyResult(
                dataset_id=dataset_id,
                row_id=a["row_id"],
                anomaly_score=a["anomaly_score"],
                distance_from_centroid=a["distance_from_centroid"],
                cluster_id=a["cluster_id"],
                is_outlier=a["is_outlier"],
            )
            db.add(ar)
            if a["is_outlier"]:
                outlier_count += 1

        await db.commit()
        logger.info(
            "Anomaly detection complete: %d rows, %d outliers (threshold=%.2f)",
            len(anomalies), outlier_count, outlier_threshold,
        )

        return sorted(anomalies, key=lambda x: x["anomaly_score"], reverse=True)


"""Artifact classifier - identify Velociraptor artifact types from CSV headers."""

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)
"Profile", "Launch String"}, "Windows.Sys.Autoruns"), + ({"Entry", "Category", "LaunchString"}, "Windows.Sys.Autoruns"), + ({"Name", "Record", "Type", "TTL"}, "Windows.Network.DNS"), + ({"QueryName", "QueryType", "QueryResults"}, "Windows.Network.DNS"), + ({"Path", "MD5", "SHA1", "SHA256"}, "Windows.Analysis.Hash"), + ({"Md5", "Sha256", "FullPath"}, "Windows.Analysis.Hash"), + ({"Name", "Actions", "NextRunTime", "Path"}, "Windows.System.TaskScheduler"), + ({"Name", "Uid", "Gid", "Description"}, "Windows.Sys.Users"), + ({"os_info.hostname", "os_info.system"}, "Server.Information.Client"), + ({"ClientId", "os_info.fqdn"}, "Server.Information.Client"), + ({"Pid", "Name", "Cmdline", "Exe"}, "Linux.Sys.Pslist"), + ({"Laddr", "Raddr", "Status", "FamilyString"}, "Linux.Network.Netstat"), + ({"Namespace", "ClassName", "PropertyName"}, "Windows.System.WMI"), + ({"RemoteAddress", "RemoteMACAddress", "InterfaceAlias"}, "Windows.Network.ArpCache"), + ({"URL", "Title", "VisitCount", "LastVisitTime"}, "Windows.Applications.BrowserHistory"), + ({"Url", "Title", "Visits"}, "Windows.Applications.BrowserHistory"), +] + +VELOCIRAPTOR_META = {"_Source", "ClientId", "FlowId", "Fqdn", "HuntId"} + +CATEGORY_MAP = { + "Pslist": "process", + "Netstat": "network", + "Services": "persistence", + "FileFinder": "filesystem", + "Prefetch": "execution", + "Registry": "persistence", + "EvtxHunter": "eventlog", + "EventLogs": "eventlog", + "Autoruns": "persistence", + "DNS": "network", + "Hash": "filesystem", + "TaskScheduler": "persistence", + "Users": "account", + "Client": "system", + "WMI": "persistence", + "ArpCache": "network", + "BrowserHistory": "application", +} + + +def classify_artifact(columns: list[str]) -> str: + col_set = set(columns) + for required, artifact_type in FINGERPRINTS: + if required.issubset(col_set): + return artifact_type + if VELOCIRAPTOR_META.intersection(col_set): + return "Velociraptor.Unknown" + return "Unknown" + + +def get_artifact_category(artifact_type: str) 
"""Natural-language data query service with SSE streaming.

Lets analysts ask questions about dataset rows in plain English.
Routes to fast model (Roadrunner) for quick queries, heavy model (Wile)
for deep analysis. Supports streaming via OllamaProvider.generate_stream().
"""

from __future__ import annotations

import asyncio
import json
import logging
import time
from typing import AsyncIterator

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.db import async_session_factory
from app.db.models import Dataset, DatasetRow

logger = logging.getLogger(__name__)

# Maximum rows to include in context window
MAX_CONTEXT_ROWS = 60
# Hard cap on a single rendered row line, in characters.
MAX_ROW_TEXT_CHARS = 300


def _rows_to_text(rows: list[dict], columns: list[str]) -> str:
    """Convert dataset rows to a compact text table for the LLM context.

    Caps at 20 columns, 80 chars per cell, MAX_ROW_TEXT_CHARS per line and
    MAX_CONTEXT_ROWS rows so the rendered table stays prompt-sized.
    """
    if not rows:
        return "(no rows)"
    # Header
    header = " | ".join(columns[:20])  # cap columns to avoid overflow
    lines = [header, "-" * min(len(header), 120)]
    for row in rows[:MAX_CONTEXT_ROWS]:
        vals = []
        for c in columns[:20]:
            v = str(row.get(c, ""))
            if len(v) > 80:
                v = v[:77] + "..."
            vals.append(v)
        line = " | ".join(vals)
        if len(line) > MAX_ROW_TEXT_CHARS:
            line = line[:MAX_ROW_TEXT_CHARS] + "..."
        lines.append(line)
    return "\n".join(lines)


QUERY_SYSTEM_PROMPT = """You are a cybersecurity data analyst assistant for ThreatHunt.
You have been given a sample of rows from a forensic artifact dataset (Velociraptor, etc.).

Your job:
- Answer the analyst's question about this data accurately and concisely
- Point out suspicious patterns, anomalies, or indicators of compromise
- Reference MITRE ATT&CK techniques when relevant
- Suggest follow-up queries or pivots
- If you cannot answer from the data provided, say so clearly

Rules:
- Be factual - only reference data you can see
- Use forensic terminology appropriate for SOC/DFIR analysts
- Format your answer with clear sections using markdown
- If the data seems benign, say so - do not fabricate threats"""


async def _load_dataset_context(
    dataset_id: str,
    db: AsyncSession,
    sample_size: int = MAX_CONTEXT_ROWS,
) -> tuple[dict, str, int]:
    """Load dataset metadata + sample rows for context.

    Returns (metadata_dict, rows_text, total_row_count).

    Raises:
        ValueError: if the dataset id does not exist.
    """
    ds = await db.get(Dataset, dataset_id)
    if not ds:
        raise ValueError(f"Dataset {dataset_id} not found")

    # Get total count
    count_q = await db.execute(
        select(func.count()).where(DatasetRow.dataset_id == dataset_id)
    )
    total = count_q.scalar() or 0

    # Sample rows - get first batch + some from the middle
    half = sample_size // 2
    result = await db.execute(
        select(DatasetRow)
        .where(DatasetRow.dataset_id == dataset_id)
        .order_by(DatasetRow.row_index)
        .limit(half)
    )
    first_rows = result.scalars().all()

    # If dataset is large, also sample from the middle; otherwise just take
    # the next contiguous rows so small datasets are fully represented.
    middle_rows = []
    if total > sample_size:
        mid_offset = total // 2
        result2 = await db.execute(
            select(DatasetRow)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(mid_offset)
            .limit(sample_size - half)
        )
        middle_rows = result2.scalars().all()
    else:
        result2 = await db.execute(
            select(DatasetRow)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(half)
            .limit(sample_size - half)
        )
        middle_rows = result2.scalars().all()

    all_rows = first_rows + middle_rows
    # Defensive: non-dict row payloads are rendered as empty rows.
    row_dicts = [r.data if isinstance(r.data, dict) else {} for r in all_rows]

    # Prefer the stored schema; fall back to the first sampled row's keys.
    columns = list(ds.column_schema.keys()) if ds.column_schema else []
    if not columns and row_dicts:
        columns = list(row_dicts[0].keys())

    rows_text = _rows_to_text(row_dicts, columns)

    metadata = {
        "name": ds.name,
        "filename": ds.filename,
        "source_tool": ds.source_tool,
        "artifact_type": getattr(ds, "artifact_type", None),
        "row_count": total,
        "columns": columns[:30],
        "sample_rows_shown": len(all_rows),
    }
    return metadata, rows_text, total
async def query_dataset(
    dataset_id: str,
    question: str,
    mode: str = "quick",
) -> str:
    """Non-streaming query: returns full answer text.

    mode "deep" routes to the heavy model on Wile; anything else uses the
    fast model on Roadrunner.
    """
    # Imported lazily to avoid a circular import at module load time
    # (NOTE(review): presumed reason — confirm against app.agents.providers_v2).
    from app.agents.providers_v2 import OllamaProvider, Node

    async with async_session_factory() as db:
        meta, rows_text, total = await _load_dataset_context(dataset_id, db)

    prompt = _build_prompt(question, meta, rows_text, total)

    if mode == "deep":
        provider = OllamaProvider(settings.DEFAULT_HEAVY_MODEL, Node.WILE)
        max_tokens = 4096
    else:
        provider = OllamaProvider(settings.DEFAULT_FAST_MODEL, Node.ROADRUNNER)
        max_tokens = 2048

    result = await provider.generate(
        prompt,
        system=QUERY_SYSTEM_PROMPT,
        max_tokens=max_tokens,
        temperature=0.3,
    )
    return result.get("response", "No response generated.")


async def query_dataset_stream(
    dataset_id: str,
    question: str,
    mode: str = "quick",
) -> AsyncIterator[str]:
    """Streaming query: yields SSE-formatted events.

    Event types emitted (as "data: {json}\\n\\n" frames): status, metadata,
    token, error, done.
    """
    from app.agents.providers_v2 import OllamaProvider, Node

    start = time.monotonic()

    # Send initial metadata event
    yield f"data: {json.dumps({'type': 'status', 'message': 'Loading dataset...'})}\n\n"

    async with async_session_factory() as db:
        meta, rows_text, total = await _load_dataset_context(dataset_id, db)

    yield f"data: {json.dumps({'type': 'metadata', 'dataset': meta})}\n\n"
    yield f"data: {json.dumps({'type': 'status', 'message': f'Querying LLM ({mode} mode)...'})}\n\n"

    prompt = _build_prompt(question, meta, rows_text, total)

    if mode == "deep":
        provider = OllamaProvider(settings.DEFAULT_HEAVY_MODEL, Node.WILE)
        max_tokens = 4096
        model_name = settings.DEFAULT_HEAVY_MODEL
        node_name = "wile"
    else:
        provider = OllamaProvider(settings.DEFAULT_FAST_MODEL, Node.ROADRUNNER)
        max_tokens = 2048
        model_name = settings.DEFAULT_FAST_MODEL
        node_name = "roadrunner"

    # Stream tokens; a provider failure is reported as an "error" event rather
    # than raised, so the SSE stream always terminates with a "done" frame.
    token_count = 0
    try:
        async for token in provider.generate_stream(
            prompt,
            system=QUERY_SYSTEM_PROMPT,
            max_tokens=max_tokens,
            temperature=0.3,
        ):
            token_count += 1
            yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
    except Exception as e:
        logger.error(f"Streaming error: {e}")
        yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"

    elapsed_ms = int((time.monotonic() - start) * 1000)
    yield f"data: {json.dumps({'type': 'done', 'tokens': token_count, 'elapsed_ms': elapsed_ms, 'model': model_name, 'node': node_name})}\n\n"


def _build_prompt(question: str, meta: dict, rows_text: str, total: int) -> str:
    """Construct the full prompt with data context.

    Layout: dataset header, fenced sample-data table, then the analyst question.
    """
    parts = [
        f"## Dataset: {meta['name']}",
        f"- Source: {meta.get('source_tool', 'unknown')}",
        f"- Artifact type: {meta.get('artifact_type', 'unknown')}",
        f"- Total rows: {total}",
        f"- Columns: {', '.join(meta.get('columns', []))}",
        f"- Showing {meta['sample_rows_shown']} sample rows below",
        "",
        "## Sample Data",
        "```",
        rows_text,
        "```",
        "",
        f"## Analyst Question",
        question,
    ]
    return "\n".join(parts)
+ +Scans all datasets in a hunt to identify unique hosts, their IPs, OS, +logged-in users, and network connections between them. +""" + +import re +import logging +from collections import defaultdict +from typing import Any + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db.models import Dataset, DatasetRow + +logger = logging.getLogger(__name__) + +# --- Column-name patterns (Velociraptor + generic forensic tools) --- + +_HOST_ID_RE = re.compile( + r'^(client_?id|clientid|agent_?id|endpoint_?id|host_?id|sensor_?id)$', re.I) +_FQDN_RE = re.compile( + r'^(fqdn|fully_?qualified|computer_?name|hostname|host_?name|host|' + r'system_?name|machine_?name|nodename|workstation)$', re.I) +_USERNAME_RE = re.compile( + r'^(user|username|user_?name|logon_?name|account_?name|owner|' + r'logged_?in_?user|sam_?account_?name|samaccountname)$', re.I) +_LOCAL_IP_RE = re.compile( + r'^(laddr\.?ip|laddr|local_?addr(ess)?|src_?ip|source_?ip)$', re.I) +_REMOTE_IP_RE = re.compile( + r'^(raddr\.?ip|raddr|remote_?addr(ess)?|dst_?ip|dest_?ip)$', re.I) +_REMOTE_PORT_RE = re.compile( + r'^(raddr\.?port|rport|remote_?port|dst_?port|dest_?port)$', re.I) +_OS_RE = re.compile( + r'^(os|operating_?system|os_?version|os_?name|platform|os_?type|os_?build)$', re.I) +_IP_VALID_RE = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$') + +_IGNORE_IPS = frozenset({ + '0.0.0.0', '::', '::1', '127.0.0.1', '', '-', '*', 'None', 'null', +}) +_SYSTEM_DOMAINS = frozenset({ + 'NT AUTHORITY', 'NT SERVICE', 'FONT DRIVER HOST', 'WINDOW MANAGER', +}) +_SYSTEM_USERS = frozenset({ + 'SYSTEM', 'LOCAL SERVICE', 'NETWORK SERVICE', + 'UMFD-0', 'UMFD-1', 'DWM-1', 'DWM-2', 'DWM-3', +}) + + +def _is_valid_ip(v: str) -> bool: + if not v or v in _IGNORE_IPS: + return False + return bool(_IP_VALID_RE.match(v)) + + +def _clean(v: Any) -> str: + s = str(v or '').strip() + return s if s and s not in ('-', 'None', 'null', '') else '' + + +_SYSTEM_USER_RE = re.compile( + 
r'^(SYSTEM|LOCAL SERVICE|NETWORK SERVICE|DWM-\d+|UMFD-\d+)$', re.I) + + +def _extract_username(raw: str) -> str: + """Clean username, stripping domain prefixes and filtering system accounts.""" + if not raw: + return '' + name = raw.strip() + if '\\' in name: + domain, _, name = name.rpartition('\\') + name = name.strip() + if domain.strip().upper() in _SYSTEM_DOMAINS: + if not name or _SYSTEM_USER_RE.match(name): + return '' + if _SYSTEM_USER_RE.match(name): + return '' + return name or '' + + +def _infer_os(fqdn: str) -> str: + u = fqdn.upper() + if 'W10-' in u or 'WIN10' in u: + return 'Windows 10' + if 'W11-' in u or 'WIN11' in u: + return 'Windows 11' + if 'W7-' in u or 'WIN7' in u: + return 'Windows 7' + if 'SRV' in u or 'SERVER' in u or 'DC-' in u: + return 'Windows Server' + if any(k in u for k in ('LINUX', 'UBUNTU', 'CENTOS', 'RHEL', 'DEBIAN')): + return 'Linux' + if 'MAC' in u or 'DARWIN' in u: + return 'macOS' + return 'Windows' + + +def _identify_columns(ds: Dataset) -> dict: + norm = ds.normalized_columns or {} + schema = ds.column_schema or {} + raw_cols = list(schema.keys()) if schema else list(norm.keys()) + + result = { + 'host_id': [], 'fqdn': [], 'username': [], + 'local_ip': [], 'remote_ip': [], 'remote_port': [], 'os': [], + } + + for col in raw_cols: + canonical = (norm.get(col) or '').lower() + lower = col.lower() + + if _HOST_ID_RE.match(lower) or (canonical == 'hostname' and lower not in ('hostname', 'host_name', 'host')): + result['host_id'].append(col) + + if _FQDN_RE.match(lower) or canonical == 'fqdn': + result['fqdn'].append(col) + + if _USERNAME_RE.match(lower) or canonical in ('username', 'user'): + result['username'].append(col) + + if _LOCAL_IP_RE.match(lower): + result['local_ip'].append(col) + elif _REMOTE_IP_RE.match(lower): + result['remote_ip'].append(col) + + if _REMOTE_PORT_RE.match(lower): + result['remote_port'].append(col) + + if _OS_RE.match(lower) or canonical == 'os': + result['os'].append(col) + + return result + + 
async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
    """Build a deduplicated host inventory from all datasets in a hunt.

    Returns dict with 'hosts', 'connections', and 'stats'.
    Each host has: id, hostname, fqdn, client_id, ips, os, users, datasets, row_count.
    """
    ds_result = await db.execute(
        select(Dataset).where(Dataset.hunt_id == hunt_id)
    )
    all_datasets = ds_result.scalars().all()

    if not all_datasets:
        return {"hosts": [], "connections": [], "stats": {
            "total_hosts": 0, "total_datasets_scanned": 0,
            "total_rows_scanned": 0,
        }}

    hosts: dict[str, dict] = {}      # fqdn -> host record
    ip_to_host: dict[str, str] = {}  # local-ip -> fqdn
    # (source host key, remote ip, remote port) -> occurrence count
    connections: dict[tuple, int] = defaultdict(int)
    total_rows = 0
    ds_with_hosts = 0

    for ds in all_datasets:
        cols = _identify_columns(ds)
        # Skip datasets with no way to attribute rows to a host.
        if not cols['fqdn'] and not cols['host_id']:
            continue
        ds_with_hosts += 1

        # Page through rows to keep memory bounded on large datasets.
        batch_size = 5000
        offset = 0
        while True:
            rr = await db.execute(
                select(DatasetRow)
                .where(DatasetRow.dataset_id == ds.id)
                .order_by(DatasetRow.row_index)
                .offset(offset).limit(batch_size)
            )
            rows = rr.scalars().all()
            if not rows:
                break

            for ro in rows:
                data = ro.data or {}
                total_rows += 1

                # First non-empty candidate column wins for fqdn / client_id.
                fqdn = ''
                for c in cols['fqdn']:
                    fqdn = _clean(data.get(c))
                    if fqdn:
                        break
                client_id = ''
                for c in cols['host_id']:
                    client_id = _clean(data.get(c))
                    if client_id:
                        break

                if not fqdn and not client_id:
                    continue

                # Prefer the FQDN as the dedup key; fall back to client id.
                host_key = fqdn or client_id

                if host_key not in hosts:
                    # Short hostname = first label of a dotted FQDN.
                    short = fqdn.split('.')[0] if fqdn and '.' in fqdn else fqdn
                    hosts[host_key] = {
                        'id': host_key,
                        'hostname': short or client_id,
                        'fqdn': fqdn,
                        'client_id': client_id,
                        'ips': set(),
                        'os': '',
                        'users': set(),
                        'datasets': set(),
                        'row_count': 0,
                    }

                h = hosts[host_key]
                h['datasets'].add(ds.name)
                h['row_count'] += 1
                if client_id and not h['client_id']:
                    h['client_id'] = client_id

                for c in cols['username']:
                    u = _extract_username(_clean(data.get(c)))
                    if u:
                        h['users'].add(u)

                # Local IPs both enrich the host and feed the IP->host map
                # used later to resolve connection targets.
                for c in cols['local_ip']:
                    ip = _clean(data.get(c))
                    if _is_valid_ip(ip):
                        h['ips'].add(ip)
                        ip_to_host[ip] = host_key

                # First non-empty OS value wins.
                for c in cols['os']:
                    ov = _clean(data.get(c))
                    if ov and not h['os']:
                        h['os'] = ov

                # Count outbound connections (first non-empty port column wins).
                for c in cols['remote_ip']:
                    rip = _clean(data.get(c))
                    if _is_valid_ip(rip):
                        rport = ''
                        for pc in cols['remote_port']:
                            rport = _clean(data.get(pc))
                            if rport:
                                break
                        connections[(host_key, rip, rport)] += 1

            offset += batch_size
            if len(rows) < batch_size:
                break

    # Post-process hosts: infer OS from the name when none was observed, and
    # convert sets to sorted lists for a stable, JSON-friendly shape.
    for h in hosts.values():
        if not h['os'] and h['fqdn']:
            h['os'] = _infer_os(h['fqdn'])
        h['ips'] = sorted(h['ips'])
        h['users'] = sorted(h['users'])
        h['datasets'] = sorted(h['datasets'])

    # Build connections, resolving IPs to host keys. Self-connections are
    # dropped and each unordered endpoint pair is emitted once (first seen
    # port/count wins for that pair).
    conn_list = []
    seen = set()
    for (src, dst_ip, dst_port), cnt in connections.items():
        if dst_ip in _IGNORE_IPS:
            continue
        dst_host = ip_to_host.get(dst_ip, '')
        if dst_host == src:
            continue
        key = tuple(sorted([src, dst_host or dst_ip]))
        if key in seen:
            continue
        seen.add(key)
        conn_list.append({
            'source': src,
            'target': dst_host or dst_ip,
            'target_ip': dst_ip,
            'port': dst_port,
            'count': cnt,
        })

    host_list = sorted(hosts.values(), key=lambda x: x['row_count'], reverse=True)

    return {
        "hosts": host_list,
        "connections": conn_list,
        "stats": {
            "total_hosts": len(host_list),
            "total_datasets_scanned": len(all_datasets),
            "datasets_with_hosts": ds_with_hosts,
            "total_rows_scanned": total_rows,
            "hosts_with_ips": sum(1 for h in host_list if h['ips']),
            "hosts_with_users": sum(1 for h in host_list if h['users']),
        },
    }


"""Host profiler - per-host deep threat analysis via Wile heavy models."""

from __future__ import annotations

import asyncio
import json
import logging

import httpx
from sqlalchemy import select

from app.config import settings
from app.db.engine import async_session
from app.db.models import Dataset, DatasetRow, HostProfile, TriageResult

logger = logging.getLogger(__name__)

HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
WILE_URL = f"{settings.wile_url}/api/generate"
+ lines = [] + for t in triages: + lines.append( + f"- Rows {t.row_start}-{t.row_end}: risk={t.risk_score:.1f} " + f"verdict={t.verdict} findings={json.dumps(t.findings, default=str)[:300]}" + ) + return "\n".join(lines) + + +async def _collect_host_data(db, hunt_id: str, hostname: str, fqdn: str | None = None) -> dict: + result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id)) + datasets = result.scalars().all() + + host_data: dict[str, list[dict]] = {} + triage_parts: list[str] = [] + + for ds in datasets: + artifact_type = getattr(ds, "artifact_type", None) or "Unknown" + rows_result = await db.execute( + select(DatasetRow).where(DatasetRow.dataset_id == ds.id).limit(500) + ) + rows = rows_result.scalars().all() + + matching = [] + for r in rows: + data = r.normalized_data or r.data + row_host = ( + data.get("hostname", "") or data.get("Fqdn", "") + or data.get("ClientId", "") or data.get("client_id", "") + ) + if hostname.lower() in str(row_host).lower(): + matching.append(data) + elif fqdn and fqdn.lower() in str(row_host).lower(): + matching.append(data) + + if matching: + host_data[artifact_type] = matching[:50] + triage_info = await _get_triage_summary(db, ds.id) + triage_parts.append(f"\n### {artifact_type} ({len(matching)} rows)\n{triage_info}") + + return { + "artifacts": host_data, + "triage_summary": "\n".join(triage_parts) or "No triage data.", + "artifact_count": sum(len(v) for v in host_data.values()), + } + + +async def profile_host( + hunt_id: str, hostname: str, fqdn: str | None = None, client_id: str | None = None, +) -> None: + logger.info("Profiling host %s in hunt %s", hostname, hunt_id) + + async with async_session() as db: + host_data = await _collect_host_data(db, hunt_id, hostname, fqdn) + if host_data["artifact_count"] == 0: + logger.info("No data found for host %s, skipping", hostname) + return + + system_prompt = ( + "You are a senior threat hunting analyst performing deep host analysis.\n" + "You receive 
consolidated forensic artifacts and prior triage results for a single host.\n\n" + "Provide a comprehensive host threat profile as JSON:\n" + "- risk_score: 0.0 (clean) to 10.0 (actively compromised)\n" + "- risk_level: low/medium/high/critical\n" + "- suspicious_findings: list of specific concerns\n" + "- mitre_techniques: list of MITRE ATT&CK technique IDs\n" + "- timeline_summary: brief timeline of suspicious activity\n" + "- analysis: detailed narrative assessment\n\n" + "Consider: cross-artifact correlation, attack patterns, LOLBins, anomalies.\n" + "Respond with valid JSON only." + ) + + artifact_summary = {} + for art_type, rows in host_data["artifacts"].items(): + artifact_summary[art_type] = [ + {k: str(v)[:150] for k, v in row.items() if v} for row in rows[:20] + ] + + prompt = ( + f"Host: {hostname}\nFQDN: {fqdn or 'unknown'}\n\n" + f"## Prior Triage Results\n{host_data['triage_summary']}\n\n" + f"## Artifact Data ({host_data['artifact_count']} total rows)\n" + f"{json.dumps(artifact_summary, indent=1, default=str)[:8000]}\n\n" + "Provide your comprehensive host threat profile as JSON." 
async def profile_all_hosts(hunt_id: str) -> None:
    """Discover hosts across a hunt's datasets and profile each one.

    Host discovery scans up to 2000 rows per dataset, collecting distinct
    hostname values. Profiling runs concurrently, bounded by
    settings.HOST_PROFILE_CONCURRENCY.
    """
    logger.info("Starting host profiling for hunt %s", hunt_id)

    async with async_session() as db:
        result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
        datasets = result.scalars().all()

        # hostname -> fqdn (or None); first occurrence wins.
        hostnames: dict[str, str | None] = {}
        for ds in datasets:
            rows_result = await db.execute(
                select(DatasetRow).where(DatasetRow.dataset_id == ds.id).limit(2000)
            )
            for r in rows_result.scalars().all():
                data = r.normalized_data or r.data
                host = data.get("hostname") or data.get("Fqdn") or data.get("Hostname")
                if host and str(host).strip():
                    h = str(host).strip()
                    if h not in hostnames:
                        hostnames[h] = data.get("fqdn") or data.get("Fqdn")

    logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id)

    # Bound concurrent LLM calls; each profile_host opens its own session.
    semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY)

    async def _bounded(hostname: str, fqdn: str | None):
        async with semaphore:
            # NOTE(review): profile_host also accepts client_id, which is
            # never forwarded here — stored profiles lack it. Confirm intended.
            await profile_host(hunt_id, hostname, fqdn)

    tasks = [_bounded(h, f) for h, f in hostnames.items()]
    # return_exceptions=True: one failed host must not abort the batch.
    await asyncio.gather(*tasks, return_exceptions=True)
    logger.info("Host profiling complete for hunt %s (%d hosts)", hunt_id, len(hostnames))
+""" + +import re +import logging +from collections import defaultdict +from typing import Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db.models import Dataset, DatasetRow + +logger = logging.getLogger(__name__) + +# Patterns + +_IPV4 = re.compile( + r'\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b' +) +_IPV6 = re.compile(r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b') +_DOMAIN = re.compile( + r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)' + r'+(?:com|net|org|io|info|biz|co|us|uk|de|ru|cn|cc|tk|xyz|top|' + r'online|site|club|win|work|download|stream|gdn|bid|review|racing|' + r'loan|date|faith|accountant|cricket|science|trade|party|men)\b', + re.IGNORECASE, +) +_MD5 = re.compile(r'\b[0-9a-fA-F]{32}\b') +_SHA1 = re.compile(r'\b[0-9a-fA-F]{40}\b') +_SHA256 = re.compile(r'\b[0-9a-fA-F]{64}\b') +_EMAIL = re.compile(r'\b[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}\b') +_URL = re.compile(r'https?://[^\s<>"\']+', re.IGNORECASE) + +# Private / reserved IPs to skip +_PRIVATE_NETS = re.compile( + r'^(10\.|172\.(1[6-9]|2\d|3[01])\.|192\.168\.|127\.|0\.|255\.)' +) + +PATTERNS = { + 'ipv4': _IPV4, + 'ipv6': _IPV6, + 'domain': _DOMAIN, + 'md5': _MD5, + 'sha1': _SHA1, + 'sha256': _SHA256, + 'email': _EMAIL, + 'url': _URL, +} + + +def _is_private_ip(ip: str) -> bool: + return bool(_PRIVATE_NETS.match(ip)) + + +def extract_iocs_from_text(text: str, skip_private: bool = True) -> dict[str, set[str]]: + """Extract all IOC types from a block of text.""" + result: dict[str, set[str]] = defaultdict(set) + for ioc_type, pattern in PATTERNS.items(): + for match in pattern.findall(text): + val = match.strip().lower() if ioc_type != 'url' else match.strip() + # Filter private IPs + if ioc_type == 'ipv4' and skip_private and _is_private_ip(val): + continue + # Filter hex strings that are too generic (< 32 chars not a hash) + result[ioc_type].add(val) + return result + + +async def 
extract_iocs_from_dataset( + dataset_id: str, + db: AsyncSession, + max_rows: int = 5000, + skip_private: bool = True, +) -> dict[str, list[str]]: + """Extract IOCs from all rows of a dataset. + + Returns {ioc_type: [sorted unique values]}. + """ + # Load rows in batches + all_iocs: dict[str, set[str]] = defaultdict(set) + offset = 0 + batch_size = 500 + + while offset < max_rows: + result = await db.execute( + select(DatasetRow.data) + .where(DatasetRow.dataset_id == dataset_id) + .order_by(DatasetRow.row_index) + .offset(offset) + .limit(batch_size) + ) + rows = result.scalars().all() + if not rows: + break + + for data in rows: + # Flatten all values to a single string for scanning + text = ' '.join(str(v) for v in data.values()) if isinstance(data, dict) else str(data) + batch_iocs = extract_iocs_from_text(text, skip_private) + for ioc_type, values in batch_iocs.items(): + all_iocs[ioc_type].update(values) + + offset += batch_size + + # Convert sets to sorted lists + return {k: sorted(v) for k, v in all_iocs.items() if v} + + +async def extract_host_groups( + hunt_id: str, + db: AsyncSession, +) -> list[dict]: + """Group all data by hostname across datasets in a hunt. + + Returns a list of host group dicts with dataset count, total rows, + artifact types, and time range. 
+ """ + # Get all datasets for this hunt + result = await db.execute( + select(Dataset).where(Dataset.hunt_id == hunt_id) + ) + ds_list = result.scalars().all() + if not ds_list: + return [] + + # Known host columns (check normalized data first, then raw) + HOST_COLS = [ + 'hostname', 'host', 'computer_name', 'computername', 'system', + 'machine', 'device_name', 'devicename', 'endpoint', + 'ClientId', 'Fqdn', 'client_id', 'fqdn', + ] + + hosts: dict[str, dict] = {} + + for ds in ds_list: + # Sample first few rows to find host column + sample_result = await db.execute( + select(DatasetRow.data, DatasetRow.normalized_data) + .where(DatasetRow.dataset_id == ds.id) + .limit(5) + ) + samples = sample_result.all() + if not samples: + continue + + # Find which host column exists + host_col = None + for row_data, norm_data in samples: + check = norm_data if norm_data else row_data + if not isinstance(check, dict): + continue + for col in HOST_COLS: + if col in check and check[col]: + host_col = col + break + if host_col: + break + + if not host_col: + continue + + # Count rows per host in this dataset + all_rows_result = await db.execute( + select(DatasetRow.data, DatasetRow.normalized_data) + .where(DatasetRow.dataset_id == ds.id) + ) + all_rows = all_rows_result.all() + for row_data, norm_data in all_rows: + check = norm_data if norm_data else row_data + if not isinstance(check, dict): + continue + host_val = check.get(host_col, '') + if not host_val or not isinstance(host_val, str): + continue + host_val = host_val.strip() + if not host_val: + continue + + if host_val not in hosts: + hosts[host_val] = { + 'hostname': host_val, + 'dataset_ids': set(), + 'total_rows': 0, + 'artifact_types': set(), + 'first_seen': None, + 'last_seen': None, + } + hosts[host_val]['dataset_ids'].add(ds.id) + hosts[host_val]['total_rows'] += 1 + if ds.artifact_type: + hosts[host_val]['artifact_types'].add(ds.artifact_type) + + # Convert to output format + result_list = [] + for h in 
sorted(hosts.values(), key=lambda x: x['total_rows'], reverse=True): + result_list.append({ + 'hostname': h['hostname'], + 'dataset_count': len(h['dataset_ids']), + 'total_rows': h['total_rows'], + 'artifact_types': sorted(h['artifact_types']), + 'first_seen': None, # TODO: extract from timestamp columns + 'last_seen': None, + 'risk_score': None, # TODO: link to host profiles + }) + + return result_list \ No newline at end of file diff --git a/backend/app/services/job_queue.py b/backend/app/services/job_queue.py new file mode 100644 index 0000000..f218cdf --- /dev/null +++ b/backend/app/services/job_queue.py @@ -0,0 +1,316 @@ +"""Async job queue for background AI tasks. + +Manages triage, profiling, report generation, anomaly detection, +and data queries as trackable jobs with status, progress, and +cancellation support. +""" + +from __future__ import annotations + +import asyncio +import logging +import time +import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Coroutine, Optional + +logger = logging.getLogger(__name__) + + +class JobStatus(str, Enum): + QUEUED = "queued" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class JobType(str, Enum): + TRIAGE = "triage" + HOST_PROFILE = "host_profile" + REPORT = "report" + ANOMALY = "anomaly" + QUERY = "query" + + +@dataclass +class Job: + id: str + job_type: JobType + status: JobStatus = JobStatus.QUEUED + progress: float = 0.0 # 0-100 + message: str = "" + result: Any = None + error: str | None = None + created_at: float = field(default_factory=time.time) + started_at: float | None = None + completed_at: float | None = None + params: dict = field(default_factory=dict) + _cancel_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False) + + @property + def elapsed_ms(self) -> int: + end = self.completed_at or time.time() + start = self.started_at or self.created_at + return int((end - start) * 
1000) + + def to_dict(self) -> dict: + return { + "id": self.id, + "job_type": self.job_type.value, + "status": self.status.value, + "progress": round(self.progress, 1), + "message": self.message, + "error": self.error, + "created_at": self.created_at, + "started_at": self.started_at, + "completed_at": self.completed_at, + "elapsed_ms": self.elapsed_ms, + "params": self.params, + } + + @property + def is_cancelled(self) -> bool: + return self._cancel_event.is_set() + + def cancel(self): + self._cancel_event.set() + self.status = JobStatus.CANCELLED + self.completed_at = time.time() + self.message = "Cancelled by user" + + +class JobQueue: + """In-memory async job queue with concurrency control. + + Jobs are tracked by ID and can be listed, polled, or cancelled. + A configurable number of workers process jobs from the queue. + """ + + def __init__(self, max_workers: int = 3): + self._jobs: dict[str, Job] = {} + self._queue: asyncio.Queue[str] = asyncio.Queue() + self._max_workers = max_workers + self._workers: list[asyncio.Task] = [] + self._handlers: dict[JobType, Callable] = {} + self._started = False + + def register_handler( + self, + job_type: JobType, + handler: Callable[[Job], Coroutine], + ): + """Register an async handler for a job type. + + Handler signature: async def handler(job: Job) -> Any + The handler can update job.progress and job.message during execution. + It should check job.is_cancelled periodically and return early. 
+ """ + self._handlers[job_type] = handler + logger.info(f"Registered handler for {job_type.value}") + + async def start(self): + """Start worker tasks.""" + if self._started: + return + self._started = True + for i in range(self._max_workers): + task = asyncio.create_task(self._worker(i)) + self._workers.append(task) + logger.info(f"Job queue started with {self._max_workers} workers") + + async def stop(self): + """Stop all workers.""" + self._started = False + for w in self._workers: + w.cancel() + await asyncio.gather(*self._workers, return_exceptions=True) + self._workers.clear() + logger.info("Job queue stopped") + + def submit(self, job_type: JobType, **params) -> Job: + """Submit a new job. Returns the Job object immediately.""" + job = Job( + id=str(uuid.uuid4()), + job_type=job_type, + params=params, + ) + self._jobs[job.id] = job + self._queue.put_nowait(job.id) + logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}") + return job + + def get_job(self, job_id: str) -> Job | None: + return self._jobs.get(job_id) + + def cancel_job(self, job_id: str) -> bool: + job = self._jobs.get(job_id) + if not job: + return False + if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED): + return False + job.cancel() + return True + + def list_jobs( + self, + status: JobStatus | None = None, + job_type: JobType | None = None, + limit: int = 50, + ) -> list[dict]: + """List jobs, newest first.""" + jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True) + if status: + jobs = [j for j in jobs if j.status == status] + if job_type: + jobs = [j for j in jobs if j.job_type == job_type] + return [j.to_dict() for j in jobs[:limit]] + + def get_stats(self) -> dict: + """Get queue statistics.""" + by_status = {} + for j in self._jobs.values(): + by_status[j.status.value] = by_status.get(j.status.value, 0) + 1 + return { + "total": len(self._jobs), + "queued": self._queue.qsize(), + "by_status": by_status, + 
"workers": self._max_workers, + "active_workers": sum( + 1 for j in self._jobs.values() if j.status == JobStatus.RUNNING + ), + } + + def cleanup(self, max_age_seconds: float = 3600): + """Remove old completed/failed/cancelled jobs.""" + now = time.time() + to_remove = [ + jid for jid, j in self._jobs.items() + if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED) + and (now - j.created_at) > max_age_seconds + ] + for jid in to_remove: + del self._jobs[jid] + if to_remove: + logger.info(f"Cleaned up {len(to_remove)} old jobs") + + async def _worker(self, worker_id: int): + """Worker loop: pull jobs from queue and execute handlers.""" + logger.info(f"Worker {worker_id} started") + while self._started: + try: + job_id = await asyncio.wait_for(self._queue.get(), timeout=5.0) + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + job = self._jobs.get(job_id) + if not job or job.is_cancelled: + continue + + handler = self._handlers.get(job.job_type) + if not handler: + job.status = JobStatus.FAILED + job.error = f"No handler for {job.job_type.value}" + job.completed_at = time.time() + logger.error(f"No handler for job type {job.job_type.value}") + continue + + job.status = JobStatus.RUNNING + job.started_at = time.time() + job.message = "Running..." 
+ logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})") + + try: + result = await handler(job) + if not job.is_cancelled: + job.status = JobStatus.COMPLETED + job.progress = 100.0 + job.result = result + job.message = "Completed" + job.completed_at = time.time() + logger.info( + f"Worker {worker_id}: completed {job.id} " + f"in {job.elapsed_ms}ms" + ) + except Exception as e: + if not job.is_cancelled: + job.status = JobStatus.FAILED + job.error = str(e) + job.message = f"Failed: {e}" + job.completed_at = time.time() + logger.error( + f"Worker {worker_id}: failed {job.id}: {e}", + exc_info=True, + ) + + +# Singleton + job handlers + +job_queue = JobQueue(max_workers=3) + + +async def _handle_triage(job: Job): + """Triage handler.""" + from app.services.triage import triage_dataset + dataset_id = job.params.get("dataset_id") + job.message = f"Triaging dataset {dataset_id}" + results = await triage_dataset(dataset_id) + return {"count": len(results) if results else 0} + + +async def _handle_host_profile(job: Job): + """Host profiling handler.""" + from app.services.host_profiler import profile_all_hosts, profile_host + hunt_id = job.params.get("hunt_id") + hostname = job.params.get("hostname") + if hostname: + job.message = f"Profiling host {hostname}" + await profile_host(hunt_id, hostname) + return {"hostname": hostname} + else: + job.message = f"Profiling all hosts in hunt {hunt_id}" + await profile_all_hosts(hunt_id) + return {"hunt_id": hunt_id} + + +async def _handle_report(job: Job): + """Report generation handler.""" + from app.services.report_generator import generate_report + hunt_id = job.params.get("hunt_id") + job.message = f"Generating report for hunt {hunt_id}" + report = await generate_report(hunt_id) + return {"report_id": report.id if report else None} + + +async def _handle_anomaly(job: Job): + """Anomaly detection handler.""" + from app.services.anomaly_detector import detect_anomalies + dataset_id = 
job.params.get("dataset_id") + k = job.params.get("k", 3) + threshold = job.params.get("threshold", 0.35) + job.message = f"Detecting anomalies in dataset {dataset_id}" + results = await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold) + return {"count": len(results) if results else 0} + + +async def _handle_query(job: Job): + """Data query handler (non-streaming).""" + from app.services.data_query import query_dataset + dataset_id = job.params.get("dataset_id") + question = job.params.get("question", "") + mode = job.params.get("mode", "quick") + job.message = f"Querying dataset {dataset_id}" + answer = await query_dataset(dataset_id, question, mode) + return {"answer": answer} + + +def register_all_handlers(): + """Register all job handlers.""" + job_queue.register_handler(JobType.TRIAGE, _handle_triage) + job_queue.register_handler(JobType.HOST_PROFILE, _handle_host_profile) + job_queue.register_handler(JobType.REPORT, _handle_report) + job_queue.register_handler(JobType.ANOMALY, _handle_anomaly) + job_queue.register_handler(JobType.QUERY, _handle_query) \ No newline at end of file diff --git a/backend/app/services/load_balancer.py b/backend/app/services/load_balancer.py new file mode 100644 index 0000000..3ccc57c --- /dev/null +++ b/backend/app/services/load_balancer.py @@ -0,0 +1,193 @@ +"""Smart load balancer for Wile & Roadrunner LLM nodes. + +Tracks active jobs per node, health status, and routes new work +to the least-busy healthy node. Periodically pings both nodes +to maintain an up-to-date health map. 
+""" + +from __future__ import annotations + +import asyncio +import logging +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + +from app.config import settings + +logger = logging.getLogger(__name__) + + +class NodeId(str, Enum): + WILE = "wile" + ROADRUNNER = "roadrunner" + + +class WorkloadTier(str, Enum): + """What kind of workload is this?""" + HEAVY = "heavy" # 70B models, deep analysis, reports + FAST = "fast" # 7-14B models, triage, quick queries + EMBEDDING = "embed" # bge-m3 embeddings + ANY = "any" + + +@dataclass +class NodeStatus: + node_id: NodeId + url: str + healthy: bool = True + last_check: float = 0.0 + active_jobs: int = 0 + total_completed: int = 0 + total_errors: int = 0 + avg_latency_ms: float = 0.0 + _latencies: list[float] = field(default_factory=list) + + def record_completion(self, latency_ms: float): + self.active_jobs = max(0, self.active_jobs - 1) + self.total_completed += 1 + self._latencies.append(latency_ms) + # Rolling average of last 50 + if len(self._latencies) > 50: + self._latencies = self._latencies[-50:] + self.avg_latency_ms = sum(self._latencies) / len(self._latencies) + + def record_error(self): + self.active_jobs = max(0, self.active_jobs - 1) + self.total_errors += 1 + + def record_start(self): + self.active_jobs += 1 + + +class LoadBalancer: + """Routes LLM work to the least-busy healthy node. 
+ + Node capabilities: + - Wile: Heavy models (70B), code models (32B) + - Roadrunner: Fast models (7-14B), embeddings (bge-m3), vision + """ + + # Which nodes can handle which tiers + TIER_NODES = { + WorkloadTier.HEAVY: [NodeId.WILE], + WorkloadTier.FAST: [NodeId.ROADRUNNER, NodeId.WILE], + WorkloadTier.EMBEDDING: [NodeId.ROADRUNNER], + WorkloadTier.ANY: [NodeId.ROADRUNNER, NodeId.WILE], + } + + def __init__(self): + self._nodes: dict[NodeId, NodeStatus] = { + NodeId.WILE: NodeStatus( + node_id=NodeId.WILE, + url=f"http://{settings.WILE_HOST}:{settings.WILE_OLLAMA_PORT}", + ), + NodeId.ROADRUNNER: NodeStatus( + node_id=NodeId.ROADRUNNER, + url=f"http://{settings.ROADRUNNER_HOST}:{settings.ROADRUNNER_OLLAMA_PORT}", + ), + } + self._lock = asyncio.Lock() + self._health_task: Optional[asyncio.Task] = None + + async def start_health_loop(self, interval: float = 30.0): + """Start background health-check loop.""" + if self._health_task and not self._health_task.done(): + return + self._health_task = asyncio.create_task(self._health_loop(interval)) + logger.info("Load balancer health loop started (%.0fs interval)", interval) + + async def stop_health_loop(self): + if self._health_task: + self._health_task.cancel() + try: + await self._health_task + except asyncio.CancelledError: + pass + self._health_task = None + + async def _health_loop(self, interval: float): + while True: + try: + await self.check_health() + except Exception as e: + logger.warning(f"Health check error: {e}") + await asyncio.sleep(interval) + + async def check_health(self): + """Ping both nodes and update status.""" + import httpx + async with httpx.AsyncClient(timeout=5) as client: + for nid, status in self._nodes.items(): + try: + resp = await client.get(f"{status.url}/api/tags") + status.healthy = resp.status_code == 200 + except Exception: + status.healthy = False + status.last_check = time.time() + logger.debug( + f"Health: {nid.value} = {'OK' if status.healthy else 'DOWN'} " + 
f"(active={status.active_jobs})" + ) + + def select_node(self, tier: WorkloadTier = WorkloadTier.ANY) -> NodeId: + """Select the best node for a workload tier. + + Strategy: among healthy nodes that support the tier, + pick the one with fewest active jobs. + Falls back to any node if none healthy. + """ + candidates = self.TIER_NODES.get(tier, [NodeId.ROADRUNNER, NodeId.WILE]) + + # Filter to healthy candidates + healthy = [ + nid for nid in candidates + if self._nodes[nid].healthy + ] + + if not healthy: + logger.warning(f"No healthy nodes for tier {tier.value}, using first candidate") + healthy = candidates + + # Pick least busy + best = min(healthy, key=lambda nid: self._nodes[nid].active_jobs) + return best + + def acquire(self, tier: WorkloadTier = WorkloadTier.ANY) -> NodeId: + """Select node and mark a job started.""" + node = self.select_node(tier) + self._nodes[node].record_start() + logger.info( + f"LB: dispatched {tier.value} -> {node.value} " + f"(active={self._nodes[node].active_jobs})" + ) + return node + + def release(self, node: NodeId, latency_ms: float = 0, error: bool = False): + """Mark a job completed on a node.""" + status = self._nodes.get(node) + if not status: + return + if error: + status.record_error() + else: + status.record_completion(latency_ms) + + def get_status(self) -> dict: + """Get current load balancer status.""" + return { + nid.value: { + "healthy": s.healthy, + "active_jobs": s.active_jobs, + "total_completed": s.total_completed, + "total_errors": s.total_errors, + "avg_latency_ms": round(s.avg_latency_ms, 1), + "last_check": s.last_check, + } + for nid, s in self._nodes.items() + } + + +# Singleton +lb = LoadBalancer() \ No newline at end of file diff --git a/backend/app/services/report_generator.py b/backend/app/services/report_generator.py new file mode 100644 index 0000000..ef0d57e --- /dev/null +++ b/backend/app/services/report_generator.py @@ -0,0 +1,198 @@ +"""Report generator - debate-powered hunt report generation 
using Wile + Roadrunner.""" + +from __future__ import annotations + +import json +import logging +import time + +import httpx +from sqlalchemy import select + +from app.config import settings +from app.db.engine import async_session +from app.db.models import ( + Dataset, HostProfile, HuntReport, TriageResult, +) +from app.services.triage import _parse_llm_response + +logger = logging.getLogger(__name__) + +WILE_URL = f"{settings.wile_url}/api/generate" +ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate" +HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL +FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M" + + +async def _llm_call(url: str, model: str, system: str, prompt: str, timeout: float = 300.0) -> str: + async with httpx.AsyncClient(timeout=timeout) as client: + resp = await client.post( + url, + json={ + "model": model, + "prompt": prompt, + "system": system, + "stream": False, + "options": {"temperature": 0.3, "num_predict": 8192}, + }, + ) + resp.raise_for_status() + return resp.json().get("response", "") + + +async def _gather_evidence(db, hunt_id: str) -> dict: + ds_result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id)) + datasets = ds_result.scalars().all() + + dataset_summary = [] + all_triage = [] + for ds in datasets: + ds_info = { + "name": ds.name, + "artifact_type": getattr(ds, "artifact_type", "Unknown"), + "row_count": ds.row_count or 0, + } + dataset_summary.append(ds_info) + + triage_result = await db.execute( + select(TriageResult) + .where(TriageResult.dataset_id == ds.id) + .where(TriageResult.risk_score >= 3.0) + .order_by(TriageResult.risk_score.desc()) + .limit(15) + ) + for t in triage_result.scalars().all(): + all_triage.append({ + "dataset": ds.name, + "artifact_type": ds_info["artifact_type"], + "rows": f"{t.row_start}-{t.row_end}", + "risk_score": t.risk_score, + "verdict": t.verdict, + "findings": t.findings[:5] if t.findings else [], + "indicators": t.suspicious_indicators[:5] if t.suspicious_indicators else 
[], + "mitre": t.mitre_techniques or [], + }) + + profile_result = await db.execute( + select(HostProfile) + .where(HostProfile.hunt_id == hunt_id) + .order_by(HostProfile.risk_score.desc()) + ) + profiles = profile_result.scalars().all() + host_summaries = [] + for p in profiles: + host_summaries.append({ + "hostname": p.hostname, + "risk_score": p.risk_score, + "risk_level": p.risk_level, + "findings": p.suspicious_findings[:5] if p.suspicious_findings else [], + "mitre": p.mitre_techniques or [], + "timeline": (p.timeline_summary or "")[:300], + }) + + return { + "datasets": dataset_summary, + "triage_findings": all_triage[:30], + "host_profiles": host_summaries, + "total_datasets": len(datasets), + "total_rows": sum(d["row_count"] for d in dataset_summary), + "high_risk_hosts": len([h for h in host_summaries if h["risk_score"] >= 7.0]), + } + + +async def generate_report(hunt_id: str) -> None: + logger.info("Generating report for hunt %s", hunt_id) + start = time.monotonic() + + async with async_session() as db: + report = HuntReport( + hunt_id=hunt_id, + status="generating", + models_used=[HEAVY_MODEL, FAST_MODEL], + ) + db.add(report) + await db.commit() + await db.refresh(report) + report_id = report.id + + try: + evidence = await _gather_evidence(db, hunt_id) + evidence_text = json.dumps(evidence, indent=1, default=str)[:12000] + + # Phase 1: Wile initial analysis + logger.info("Report phase 1: Wile initial analysis") + phase1 = await _llm_call( + WILE_URL, HEAVY_MODEL, + system=( + "You are a senior threat intelligence analyst writing a hunt report.\n" + "Analyze all evidence and produce a structured threat assessment.\n" + "Include: executive summary, detailed findings per host, MITRE mapping,\n" + "IOC table, risk rankings, and actionable recommendations.\n" + "Use markdown formatting. Be thorough and specific." 
+ ), + prompt=f"Hunt evidence:\n{evidence_text}\n\nProduce your initial threat assessment.", + ) + + # Phase 2: Roadrunner critical review + logger.info("Report phase 2: Roadrunner critical review") + phase2 = await _llm_call( + ROADRUNNER_URL, FAST_MODEL, + system=( + "You are a critical reviewer of threat hunt reports.\n" + "Review the initial assessment and identify:\n" + "- Missing correlations or overlooked indicators\n" + "- False positive risks or overblown findings\n" + "- Additional MITRE techniques that should be mapped\n" + "- Gaps in recommendations\n" + "Be specific and constructive. Respond in markdown." + ), + prompt=f"Evidence:\n{evidence_text[:4000]}\n\nInitial Assessment:\n{phase1[:6000]}\n\nProvide your critical review.", + timeout=120.0, + ) + + # Phase 3: Wile final synthesis + logger.info("Report phase 3: Wile final synthesis") + synthesis_prompt = ( + f"Original evidence:\n{evidence_text[:6000]}\n\n" + f"Initial assessment:\n{phase1[:5000]}\n\n" + f"Critical review:\n{phase2[:3000]}\n\n" + "Produce the FINAL hunt report incorporating the review feedback.\n" + "Return JSON with these keys:\n" + "- executive_summary: 2-3 paragraph executive summary\n" + "- findings: list of {title, severity, description, evidence, mitre_ids}\n" + "- recommendations: list of {priority, action, rationale}\n" + "- mitre_mapping: dict of technique_id -> {name, description, evidence}\n" + "- ioc_table: list of {type, value, context, confidence}\n" + "- host_risk_summary: list of {hostname, risk_score, risk_level, key_findings}\n" + "Respond with valid JSON only." + ) + phase3_text = await _llm_call( + WILE_URL, HEAVY_MODEL, + system="You are producing the final, definitive threat hunt report. Incorporate all feedback. 
Respond with valid JSON only.", + prompt=synthesis_prompt, + ) + + parsed = _parse_llm_response(phase3_text) + elapsed_ms = int((time.monotonic() - start) * 1000) + + full_report = f"# Threat Hunt Report\n\n{phase1}\n\n---\n## Review Notes\n{phase2}\n\n---\n## Final Synthesis\n{phase3_text}" + + report.status = "complete" + report.exec_summary = parsed.get("executive_summary", phase1[:2000]) + report.full_report = full_report + report.findings = parsed.get("findings", []) + report.recommendations = parsed.get("recommendations", []) + report.mitre_mapping = parsed.get("mitre_mapping", {}) + report.ioc_table = parsed.get("ioc_table", []) + report.host_risk_summary = parsed.get("host_risk_summary", []) + report.generation_time_ms = elapsed_ms + await db.commit() + + logger.info("Report %s complete in %dms", report_id, elapsed_ms) + + except Exception as e: + logger.error("Report generation failed for hunt %s: %s", hunt_id, e) + report.status = "error" + report.exec_summary = f"Report generation failed: {e}" + report.generation_time_ms = int((time.monotonic() - start) * 1000) + await db.commit() \ No newline at end of file diff --git a/backend/app/services/triage.py b/backend/app/services/triage.py new file mode 100644 index 0000000..09be225 --- /dev/null +++ b/backend/app/services/triage.py @@ -0,0 +1,170 @@ +"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner.""" + +from __future__ import annotations + +import json +import logging +import re + +import httpx +from sqlalchemy import func, select + +from app.config import settings +from app.db.engine import async_session +from app.db.models import Dataset, DatasetRow, TriageResult + +logger = logging.getLogger(__name__) + +DEFAULT_FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M" +ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate" + +ARTIFACT_FOCUS = { + "Windows.System.Pslist": "Look for: suspicious parent-child, LOLBins, unsigned, injection indicators, abnormal paths.", + 
"Windows.Network.Netstat": "Look for: C2 beaconing, unusual ports, connections to rare IPs, non-browser high-port listeners.", + "Windows.System.Services": "Look for: services in temp dirs, misspelled names, unsigned ServiceDll, unusual start modes.", + "Windows.Forensics.Prefetch": "Look for: recon tools, lateral movement tools, rarely-run executables with high run counts.", + "Windows.EventLogs.EvtxHunter": "Look for: logon type 10/3 anomalies, service installs, PowerShell script blocks, clearing.", + "Windows.Sys.Autoruns": "Look for: recently added entries, entries in temp/user dirs, encoded commands, suspicious DLLs.", + "Windows.Registry.Finder": "Look for: run keys, image file execution options, hidden services, encoded payloads.", + "Windows.Search.FileFinder": "Look for: files in unusual locations, recently modified system files, known tool names.", +} + + +def _parse_llm_response(text: str) -> dict: + text = text.strip() + fence = re.search(r"`(?:json)?\s*\n?(.*?)\n?\s*`", text, re.DOTALL) + if fence: + text = fence.group(1).strip() + try: + return json.loads(text) + except json.JSONDecodeError: + brace = text.find("{") + bracket = text.rfind("}") + if brace != -1 and bracket != -1 and bracket > brace: + try: + return json.loads(text[brace : bracket + 1]) + except json.JSONDecodeError: + pass + return {"raw_response": text[:3000]} + + +async def triage_dataset(dataset_id: str) -> None: + logger.info("Starting triage for dataset %s", dataset_id) + + async with async_session() as db: + ds_result = await db.execute( + select(Dataset).where(Dataset.id == dataset_id) + ) + dataset = ds_result.scalar_one_or_none() + if not dataset: + logger.error("Dataset %s not found", dataset_id) + return + + artifact_type = getattr(dataset, "artifact_type", None) or "Unknown" + focus = ARTIFACT_FOCUS.get(artifact_type, "Analyze for any suspicious indicators.") + + count_result = await db.execute( + select(func.count()).where(DatasetRow.dataset_id == dataset_id) + ) + 
total_rows = count_result.scalar() or 0 + + batch_size = settings.TRIAGE_BATCH_SIZE + suspicious_count = 0 + offset = 0 + + while offset < total_rows: + if suspicious_count >= settings.TRIAGE_MAX_SUSPICIOUS_ROWS: + logger.info("Reached suspicious row cap for dataset %s", dataset_id) + break + + rows_result = await db.execute( + select(DatasetRow) + .where(DatasetRow.dataset_id == dataset_id) + .order_by(DatasetRow.row_number) + .offset(offset) + .limit(batch_size) + ) + rows = rows_result.scalars().all() + if not rows: + break + + batch_data = [] + for r in rows: + data = r.normalized_data or r.data + compact = {k: str(v)[:200] for k, v in data.items() if v} + batch_data.append(compact) + + system_prompt = f"""You are a cybersecurity triage analyst. Analyze this batch of {artifact_type} forensic data. +{focus} + +Return JSON with: +- risk_score: 0.0 (benign) to 10.0 (critical threat) +- verdict: "clean", "suspicious", "malicious", or "inconclusive" +- findings: list of key observations +- suspicious_indicators: list of specific IOCs or anomalies +- mitre_techniques: list of MITRE ATT&CK IDs if applicable + +Be precise. Only flag genuinely suspicious items. 
Respond with valid JSON only.""" + + prompt = f"Rows {offset+1}-{offset+len(rows)} of {total_rows}:\n{json.dumps(batch_data, default=str)[:6000]}" + + try: + async with httpx.AsyncClient(timeout=120.0) as client: + resp = await client.post( + ROADRUNNER_URL, + json={ + "model": DEFAULT_FAST_MODEL, + "prompt": prompt, + "system": system_prompt, + "stream": False, + "options": {"temperature": 0.2, "num_predict": 2048}, + }, + ) + resp.raise_for_status() + result = resp.json() + llm_text = result.get("response", "") + + parsed = _parse_llm_response(llm_text) + risk = float(parsed.get("risk_score", 0.0)) + + triage = TriageResult( + dataset_id=dataset_id, + row_start=offset, + row_end=offset + len(rows) - 1, + risk_score=risk, + verdict=parsed.get("verdict", "inconclusive"), + findings=parsed.get("findings", []), + suspicious_indicators=parsed.get("suspicious_indicators", []), + mitre_techniques=parsed.get("mitre_techniques", []), + model_used=DEFAULT_FAST_MODEL, + node_used="roadrunner", + ) + db.add(triage) + await db.commit() + + if risk >= settings.TRIAGE_ESCALATION_THRESHOLD: + suspicious_count += len(rows) + + logger.debug( + "Triage batch %d-%d: risk=%.1f verdict=%s", + offset, offset + len(rows) - 1, risk, triage.verdict, + ) + + except Exception as e: + logger.error("Triage batch %d failed: %s", offset, e) + triage = TriageResult( + dataset_id=dataset_id, + row_start=offset, + row_end=offset + len(rows) - 1, + risk_score=0.0, + verdict="error", + findings=[f"Error: {e}"], + model_used=DEFAULT_FAST_MODEL, + node_used="roadrunner", + ) + db.add(triage) + await db.commit() + + offset += batch_size + + logger.info("Triage complete for dataset %s", dataset_id) \ No newline at end of file diff --git a/backend/run.py b/backend/run.py index c8fa0b6..1bb8002 100644 --- a/backend/run.py +++ b/backend/run.py @@ -13,5 +13,5 @@ if __name__ == "__main__": "app.main:app", host="0.0.0.0", port=8000, - reload=True, - ) + reload=False, + ) \ No newline at end of file diff --git 
a/frontend/nginx.conf b/frontend/nginx.conf index 959445f..432f927 100644 --- a/frontend/nginx.conf +++ b/frontend/nginx.conf @@ -15,10 +15,10 @@ server { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; - proxy_read_timeout 120s; + proxy_read_timeout 300s; } - # SPA fallback — serve index.html for all non-file routes + # SPA fallback serve index.html for all non-file routes location / { try_files $uri $uri/ /index.html; } @@ -28,4 +28,4 @@ server { expires 1y; add_header Cache-Control "public, immutable"; } -} +} \ No newline at end of file diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 4ec96b1..bf3df36 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,5 +1,5 @@ /** - * ThreatHunt — MUI-powered analyst-assist platform. + * ThreatHunt MUI-powered analyst-assist platform. */ import React, { useState, useCallback } from 'react'; @@ -18,6 +18,7 @@ import ScienceIcon from '@mui/icons-material/Science'; import CompareArrowsIcon from '@mui/icons-material/CompareArrows'; import GppMaybeIcon from '@mui/icons-material/GppMaybe'; import HubIcon from '@mui/icons-material/Hub'; +import AssessmentIcon from '@mui/icons-material/Assessment'; import { SnackbarProvider } from 'notistack'; import theme from './theme'; @@ -32,6 +33,7 @@ import HypothesisTracker from './components/HypothesisTracker'; import CorrelationView from './components/CorrelationView'; import AUPScanner from './components/AUPScanner'; import NetworkMap from './components/NetworkMap'; +import AnalysisDashboard from './components/AnalysisDashboard'; const DRAWER_WIDTH = 240; @@ -42,13 +44,14 @@ const NAV: NavItem[] = [ { label: 'Hunts', path: '/hunts', icon: }, { label: 'Datasets', path: '/datasets', icon: }, { label: 'Upload', path: '/upload', icon: }, + { label: 'AI Analysis', path: '/analysis', icon: }, { label: 'Agent', path: '/agent', icon: }, { label: 'Enrichment', path: 
'/enrichment', icon: }, { label: 'Annotations', path: '/annotations', icon: }, { label: 'Hypotheses', path: '/hypotheses', icon: }, { label: 'Correlation', path: '/correlation', icon: }, - { label: 'Network Map', path: '/network', icon: }, - { label: 'AUP Scanner', path: '/aup', icon: }, + { label: 'Network Map', path: '/network', icon: }, + { label: 'AUP Scanner', path: '/aup', icon: }, ]; function Shell() { @@ -109,6 +112,7 @@ function Shell() { } /> } /> } /> + } /> } /> } /> } /> @@ -135,4 +139,4 @@ function App() { ); } -export default App; +export default App; \ No newline at end of file diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index d83cabd..37476ac 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -1,11 +1,11 @@ /* ==================================================================== - ThreatHunt API Client — mirrors every backend endpoint. + ThreatHunt API Client -- mirrors every backend endpoint. All requests go through the CRA proxy (see package.json "proxy"). 
==================================================================== */ const BASE = ''; // proxied to http://localhost:8000 by CRA -// ── Helpers ────────────────────────────────────────────────────────── +// -- Helpers -- let authToken: string | null = localStorage.getItem('th_token'); @@ -36,7 +36,7 @@ async function api( return res.text() as unknown as T; } -// ── Auth ───────────────────────────────────────────────────────────── +// -- Auth -- export interface UserPayload { id: string; username: string; email: string; @@ -63,7 +63,7 @@ export const auth = { me: () => api('/api/auth/me'), }; -// ── Hunts ──────────────────────────────────────────────────────────── +// -- Hunts -- export interface Hunt { id: string; name: string; description: string | null; status: string; @@ -82,7 +82,7 @@ export const hunts = { delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }), }; -// ── Datasets ───────────────────────────────────────────────────────── +// -- Datasets -- export interface DatasetSummary { id: string; name: string; filename: string; source_tool: string | null; @@ -92,6 +92,8 @@ export interface DatasetSummary { file_size_bytes: number; encoding: string | null; delimiter: string | null; time_range_start: string | null; time_range_end: string | null; hunt_id: string | null; created_at: string; + processing_status?: string; artifact_type?: string | null; + error_message?: string | null; file_path?: string | null; } export interface UploadResult { @@ -155,7 +157,7 @@ export const datasets = { delete: (id: string) => api(`/api/datasets/${id}`, { method: 'DELETE' }), }; -// ── Agent ──────────────────────────────────────────────────────────── +// -- Agent -- export interface AssistRequest { query: string; @@ -198,7 +200,7 @@ export const agent = { }, }; -// ── Annotations ────────────────────────────────────────────────────── +// -- Annotations -- export interface AnnotationData { id: string; row_id: number | null; dataset_id: string | null; @@ 
-224,7 +226,7 @@ export const annotations = { delete: (id: string) => api(`/api/annotations/${id}`, { method: 'DELETE' }), }; -// ── Hypotheses ─────────────────────────────────────────────────────── +// -- Hypotheses -- export interface HypothesisData { id: string; hunt_id: string | null; title: string; description: string | null; @@ -249,7 +251,7 @@ export const hypotheses = { delete: (id: string) => api(`/api/hypotheses/${id}`, { method: 'DELETE' }), }; -// ── Enrichment ─────────────────────────────────────────────────────── +// -- Enrichment -- export interface EnrichmentResult { ioc_value: string; ioc_type: string; source: string; verdict: string; @@ -274,7 +276,7 @@ export const enrichment = { status: () => api>('/api/enrichment/status'), }; -// ── Correlation ────────────────────────────────────────────────────── +// -- Correlation -- export interface CorrelationResult { hunt_ids: string[]; summary: string; total_correlations: number; @@ -292,7 +294,7 @@ export const correlation = { api<{ ioc_value: string; occurrences: any[]; total: number }>(`/api/correlation/ioc/${encodeURIComponent(ioc_value)}`), }; -// ── Reports ────────────────────────────────────────────────────────── +// -- Reports -- export const reports = { json: (huntId: string) => @@ -305,13 +307,13 @@ export const reports = { api>(`/api/reports/hunt/${huntId}/summary`), }; -// ── Root / misc ────────────────────────────────────────────────────── +// -- Root / misc -- export const misc = { root: () => api<{ name: string; version: string; status: string }>('/'), }; -// ── AUP Keywords ───────────────────────────────────────────────────── +// -- AUP Keywords -- export interface KeywordOut { id: number; theme_id: string; value: string; is_regex: boolean; created_at: string; @@ -368,3 +370,216 @@ export const keywords = { quickScan: (datasetId: string) => api(`/api/keywords/scan/quick?dataset_id=${encodeURIComponent(datasetId)}`), }; + + +// -- Analysis (Phase 2+) -- + +export interface 
TriageResultData { + id: string; dataset_id: string; row_start: number; row_end: number; + risk_score: number; verdict: string; + findings: any[] | null; suspicious_indicators: any[] | null; + mitre_techniques: any[] | null; + model_used: string | null; node_used: string | null; +} + +export interface HostProfileData { + id: string; hunt_id: string; hostname: string; fqdn: string | null; + risk_score: number; risk_level: string; + artifact_summary: Record | null; + timeline_summary: string | null; + suspicious_findings: any[] | null; + mitre_techniques: any[] | null; + llm_analysis: string | null; + model_used: string | null; +} + +export interface HuntReportData { + id: string; hunt_id: string; status: string; + exec_summary: string | null; full_report: string | null; + findings: any[] | null; recommendations: any[] | null; + mitre_mapping: Record | null; + ioc_table: any[] | null; host_risk_summary: any[] | null; + models_used: any[] | null; generation_time_ms: number | null; +} + +export interface AnomalyResultData { + id: string; dataset_id: string; row_id: number | null; + anomaly_score: number; distance_from_centroid: number | null; + cluster_id: number | null; is_outlier: boolean; + explanation: string | null; +} + +export interface HostGroupData { + hostname: string; + dataset_count: number; + total_rows: number; + artifact_types: string[]; + first_seen: string | null; + last_seen: string | null; + risk_score: number | null; +} + +// -- Job queue types (Phase 10) -- + +export interface JobData { + id: string; + job_type: string; + status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled'; + progress: number; + message: string; + error: string | null; + created_at: number; + started_at: number | null; + completed_at: number | null; + elapsed_ms: number; + params: Record; +} + +export interface JobStats { + total: number; + queued: number; + by_status: Record; + workers: number; + active_workers: number; +} + +export interface LBNodeStatus { + 
healthy: boolean; + active_jobs: number; + total_completed: number; + total_errors: number; + avg_latency_ms: number; + last_check: number; +} + +export const analysis = { + // Triage + triageResults: (datasetId: string, minRisk = 0) => + api(`/api/analysis/triage/${datasetId}?min_risk=${minRisk}`), + triggerTriage: (datasetId: string) => + api<{ status: string; dataset_id: string }>(`/api/analysis/triage/${datasetId}`, { method: 'POST' }), + + // Host profiles + hostProfiles: (huntId: string, minRisk = 0) => + api(`/api/analysis/profiles/${huntId}?min_risk=${minRisk}`), + triggerAllProfiles: (huntId: string) => + api<{ status: string; hunt_id: string }>(`/api/analysis/profiles/${huntId}`, { method: 'POST' }), + triggerHostProfile: (huntId: string, hostname: string) => + api<{ status: string }>(`/api/analysis/profiles/${huntId}/${encodeURIComponent(hostname)}`, { method: 'POST' }), + + // Reports + listReports: (huntId: string) => + api(`/api/analysis/reports/${huntId}`), + getReport: (huntId: string, reportId: string) => + api(`/api/analysis/reports/${huntId}/${reportId}`), + generateReport: (huntId: string) => + api<{ status: string; hunt_id: string }>(`/api/analysis/reports/${huntId}/generate`, { method: 'POST' }), + + // Anomaly detection + anomalies: (datasetId: string, outliersOnly = false) => + api(`/api/analysis/anomalies/${datasetId}?outliers_only=${outliersOnly}`), + triggerAnomalyDetection: (datasetId: string, k = 3, threshold = 0.35) => + api<{ status: string; dataset_id: string }>( + `/api/analysis/anomalies/${datasetId}?k=${k}&threshold=${threshold}`, { method: 'POST' }, + ), + + // IOC extraction + extractIocs: (datasetId: string) => + api<{ dataset_id: string; iocs: Record; total: number }>( + `/api/analysis/iocs/${datasetId}`, + ), + + // Host grouping + hostGroups: (huntId: string) => + api<{ hunt_id: string; hosts: HostGroupData[] }>( + `/api/analysis/hosts/${huntId}`, + ), + + // Data query (Phase 9) - SSE streaming + queryStream: async 
(datasetId: string, question: string, mode: string = 'quick'): Promise => { + const headers: Record = { 'Content-Type': 'application/json' }; + if (authToken) headers['Authorization'] = `Bearer ${authToken}`; + return fetch(`${BASE}/api/analysis/query/${datasetId}`, { + method: 'POST', + headers, + body: JSON.stringify({ question, mode }), + }); + }, + + // Data query (Phase 9) - sync + querySync: (datasetId: string, question: string, mode: string = 'quick') => + api<{ dataset_id: string; question: string; answer: string; mode: string }>( + `/api/analysis/query/${datasetId}/sync`, { + method: 'POST', + body: JSON.stringify({ question, mode }), + }, + ), + + // Job queue (Phase 10) + listJobs: (status?: string, jobType?: string, limit = 50) => { + const q = new URLSearchParams(); + if (status) q.set('status', status); + if (jobType) q.set('job_type', jobType); + q.set('limit', String(limit)); + return api<{ jobs: JobData[]; stats: JobStats }>(`/api/analysis/jobs?${q}`); + }, + getJob: (jobId: string) => + api(`/api/analysis/jobs/${jobId}`), + cancelJob: (jobId: string) => + api<{ status: string; job_id: string }>(`/api/analysis/jobs/${jobId}`, { method: 'DELETE' }), + submitJob: (jobType: string, params: Record = {}) => + api<{ job_id: string; status: string; job_type: string }>( + `/api/analysis/jobs/submit/${jobType}`, { + method: 'POST', + body: JSON.stringify(params), + }, + ), + + // Load balancer (Phase 10) + lbStatus: () => + api>('/api/analysis/lb/status'), + lbCheck: () => + api>('/api/analysis/lb/check', { method: 'POST' }), +}; + +// -- Network Topology -- + +export interface InventoryHost { + id: string; + hostname: string; + fqdn: string; + client_id: string; + ips: string[]; + os: string; + users: string[]; + datasets: string[]; + row_count: number; +} + +export interface InventoryConnection { + source: string; + target: string; + target_ip: string; + port: string; + count: number; +} + +export interface InventoryStats { + total_hosts: number; + 
total_datasets_scanned: number; + datasets_with_hosts: number; + total_rows_scanned: number; + hosts_with_ips: number; + hosts_with_users: number; +} + +export interface HostInventory { + hosts: InventoryHost[]; + connections: InventoryConnection[]; + stats: InventoryStats; +} + +export const network = { + hostInventory: (huntId: string) => + api(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}`), +}; \ No newline at end of file diff --git a/frontend/src/components/AnalysisDashboard.tsx b/frontend/src/components/AnalysisDashboard.tsx new file mode 100644 index 0000000..f79253d --- /dev/null +++ b/frontend/src/components/AnalysisDashboard.tsx @@ -0,0 +1,818 @@ +/** + * AnalysisDashboard -- 6-tab view covering the full AI analysis pipeline: + * 1. Triage results + * 2. Host profiles + * 3. Reports + * 4. Anomalies + * 5. Ask Data (natural language query with SSE streaming) -- Phase 9 + * 6. Jobs & Load Balancer status -- Phase 10 + */ + +import React, { useState, useEffect, useCallback, useRef } from 'react'; +import { + Box, Typography, Tabs, Tab, Paper, Button, Chip, Stack, CircularProgress, + Table, TableBody, TableCell, TableContainer, TableHead, TableRow, + Accordion, AccordionSummary, AccordionDetails, Alert, Select, MenuItem, + FormControl, InputLabel, LinearProgress, Tooltip, IconButton, Divider, + Card, CardContent, CardActions, Grid, TextField, ToggleButton, + ToggleButtonGroup, +} from '@mui/material'; +import ExpandMoreIcon from '@mui/icons-material/ExpandMore'; +import PlayArrowIcon from '@mui/icons-material/PlayArrow'; +import RefreshIcon from '@mui/icons-material/Refresh'; +import AssessmentIcon from '@mui/icons-material/Assessment'; +import SecurityIcon from '@mui/icons-material/Security'; +import PersonIcon from '@mui/icons-material/Person'; +import WarningAmberIcon from '@mui/icons-material/WarningAmber'; +import ShieldIcon from '@mui/icons-material/Shield'; +import BubbleChartIcon from '@mui/icons-material/BubbleChart'; +import 
QuestionAnswerIcon from '@mui/icons-material/QuestionAnswer'; +import WorkIcon from '@mui/icons-material/Work'; +import SendIcon from '@mui/icons-material/Send'; +import StopIcon from '@mui/icons-material/Stop'; +import DeleteIcon from '@mui/icons-material/Delete'; +import CheckCircleIcon from '@mui/icons-material/CheckCircle'; +import ErrorIcon from '@mui/icons-material/Error'; +import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty'; +import CancelIcon from '@mui/icons-material/Cancel'; +import { useSnackbar } from 'notistack'; +import { + analysis, hunts, datasets, + type Hunt, type DatasetSummary, + type TriageResultData, type HostProfileData, type HuntReportData, + type AnomalyResultData, type JobData, type JobStats, type LBNodeStatus, +} from '../api/client'; + +/* helpers */ + +function riskColor(score: number): 'error' | 'warning' | 'info' | 'success' | 'default' { + if (score >= 8) return 'error'; + if (score >= 5) return 'warning'; + if (score >= 2) return 'info'; + return 'success'; +} + +function riskLabel(level: string): 'error' | 'warning' | 'info' | 'success' | 'default' { + if (level === 'critical' || level === 'high') return 'error'; + if (level === 'medium') return 'warning'; + if (level === 'low') return 'success'; + return 'default'; +} + +function fmtMs(ms: number): string { + if (ms < 1000) return `${ms}ms`; + return `${(ms / 1000).toFixed(1)}s`; +} + +function fmtTime(ts: number): string { + if (!ts) return '--'; + return new Date(ts * 1000).toLocaleTimeString(); +} + +const statusIcon = (s: string) => { + switch (s) { + case 'completed': return ; + case 'failed': return ; + case 'running': return ; + case 'queued': return ; + case 'cancelled': return ; + default: return null; + } +}; + +/* TabPanel */ + +function TabPanel({ children, value, index }: { children: React.ReactNode; value: number; index: number }) { + return value === index ? 
{children} : null; +} + +/* Main component */ + +export default function AnalysisDashboard() { + const { enqueueSnackbar } = useSnackbar(); + const [tab, setTab] = useState(0); + + // Selectors + const [huntList, setHuntList] = useState([]); + const [dsList, setDsList] = useState([]); + const [huntId, setHuntId] = useState(''); + const [dsId, setDsId] = useState(''); + + // Data tabs 0-3 + const [triageResults, setTriageResults] = useState([]); + const [profiles, setProfiles] = useState([]); + const [reports, setReports] = useState([]); + const [anomalies, setAnomalies] = useState([]); + + // Loading states + const [loadingTriage, setLoadingTriage] = useState(false); + const [loadingProfiles, setLoadingProfiles] = useState(false); + const [loadingReports, setLoadingReports] = useState(false); + const [loadingAnomalies, setLoadingAnomalies] = useState(false); + const [triggering, setTriggering] = useState(false); + + // Phase 9: Ask Data + const [queryText, setQueryText] = useState(''); + const [queryMode, setQueryMode] = useState('quick'); + const [queryAnswer, setQueryAnswer] = useState(''); + const [queryStreaming, setQueryStreaming] = useState(false); + const [queryMeta, setQueryMeta] = useState | null>(null); + const [queryDone, setQueryDone] = useState | null>(null); + const abortRef = useRef(null); + const answerRef = useRef(null); + + // Phase 10: Jobs + const [jobs, setJobs] = useState([]); + const [jobStats, setJobStats] = useState(null); + const [lbStatus, setLbStatus] = useState | null>(null); + const [loadingJobs, setLoadingJobs] = useState(false); + + // Load hunts and datasets + useEffect(() => { + hunts.list(0, 200).then(r => setHuntList(r.hunts)).catch(() => {}); + datasets.list(0, 200).then(r => setDsList(r.datasets)).catch(() => {}); + }, []); + + useEffect(() => { + if (huntList.length > 0 && !huntId) setHuntId(huntList[0].id); + }, [huntList, huntId]); + + useEffect(() => { + if (dsList.length > 0 && !dsId) setDsId(dsList[0].id); + }, [dsList, 
dsId]); + + /* Fetch triage results */ + const fetchTriage = useCallback(async () => { + if (!dsId) return; + setLoadingTriage(true); + try { + const data = await analysis.triageResults(dsId); + setTriageResults(data); + } catch (e: any) { + enqueueSnackbar(`Triage fetch failed: ${e.message}`, { variant: 'error' }); + } finally { setLoadingTriage(false); } + }, [dsId, enqueueSnackbar]); + + const fetchProfiles = useCallback(async () => { + if (!huntId) return; + setLoadingProfiles(true); + try { + const data = await analysis.hostProfiles(huntId); + setProfiles(data); + } catch (e: any) { + enqueueSnackbar(`Profiles fetch failed: ${e.message}`, { variant: 'error' }); + } finally { setLoadingProfiles(false); } + }, [huntId, enqueueSnackbar]); + + const fetchReports = useCallback(async () => { + if (!huntId) return; + setLoadingReports(true); + try { + const data = await analysis.listReports(huntId); + setReports(data); + } catch (e: any) { + enqueueSnackbar(`Reports fetch failed: ${e.message}`, { variant: 'error' }); + } finally { setLoadingReports(false); } + }, [huntId, enqueueSnackbar]); + + const fetchAnomalies = useCallback(async () => { + if (!dsId) return; + setLoadingAnomalies(true); + try { + const data = await analysis.anomalies(dsId); + setAnomalies(data); + } catch (e: any) { + enqueueSnackbar('Anomaly fetch failed: ' + e.message, { variant: 'error' }); + } finally { setLoadingAnomalies(false); } + }, [dsId, enqueueSnackbar]); + + const fetchJobs = useCallback(async () => { + setLoadingJobs(true); + try { + const data = await analysis.listJobs(); + setJobs(data.jobs); + setJobStats(data.stats); + } catch (e: any) { + enqueueSnackbar('Jobs fetch failed: ' + e.message, { variant: 'error' }); + } finally { setLoadingJobs(false); } + }, [enqueueSnackbar]); + + const fetchLbStatus = useCallback(async () => { + try { + const data = await analysis.lbStatus(); + setLbStatus(data); + } catch {} + }, []); + + // Load data when selectors change + useEffect(() => { 
if (dsId) fetchTriage(); }, [dsId, fetchTriage]); + useEffect(() => { if (huntId) { fetchProfiles(); fetchReports(); } }, [huntId, fetchProfiles, fetchReports]); + + // Auto-refresh jobs when on jobs tab + useEffect(() => { + if (tab === 5) { + fetchJobs(); + fetchLbStatus(); + const iv = setInterval(() => { fetchJobs(); fetchLbStatus(); }, 5000); + return () => clearInterval(iv); + } + }, [tab, fetchJobs, fetchLbStatus]); + + /* Trigger actions */ + const doTriggerTriage = useCallback(async () => { + if (!dsId) return; + setTriggering(true); + try { + await analysis.triggerTriage(dsId); + enqueueSnackbar('Triage started', { variant: 'info' }); + setTimeout(fetchTriage, 5000); + } catch (e: any) { + enqueueSnackbar(`Triage trigger failed: ${e.message}`, { variant: 'error' }); + } finally { setTriggering(false); } + }, [dsId, enqueueSnackbar, fetchTriage]); + + const doTriggerProfiles = useCallback(async () => { + if (!huntId) return; + setTriggering(true); + try { + await analysis.triggerAllProfiles(huntId); + enqueueSnackbar('Host profiling started', { variant: 'info' }); + setTimeout(fetchProfiles, 10000); + } catch (e: any) { + enqueueSnackbar(`Profile trigger failed: ${e.message}`, { variant: 'error' }); + } finally { setTriggering(false); } + }, [huntId, enqueueSnackbar, fetchProfiles]); + + const doGenerateReport = useCallback(async () => { + if (!huntId) return; + setTriggering(true); + try { + await analysis.generateReport(huntId); + enqueueSnackbar('Report generation started', { variant: 'info' }); + setTimeout(fetchReports, 15000); + } catch (e: any) { + enqueueSnackbar(`Report generation failed: ${e.message}`, { variant: 'error' }); + } finally { setTriggering(false); } + }, [huntId, enqueueSnackbar, fetchReports]); + + const doTriggerAnomalies = useCallback(async () => { + if (!dsId) return; + setTriggering(true); + try { + await analysis.triggerAnomalyDetection(dsId); + enqueueSnackbar('Anomaly detection started', { variant: 'info' }); + 
setTimeout(fetchAnomalies, 20000); + } catch (e: any) { + enqueueSnackbar('Anomaly trigger failed: ' + e.message, { variant: 'error' }); + } finally { setTriggering(false); } + }, [dsId, enqueueSnackbar, fetchAnomalies]); + + /* Phase 9: Streaming data query */ + const doQuery = useCallback(async () => { + if (!dsId || !queryText.trim()) return; + setQueryStreaming(true); + setQueryAnswer(''); + setQueryMeta(null); + setQueryDone(null); + + const controller = new AbortController(); + abortRef.current = controller; + + try { + const resp = await analysis.queryStream(dsId, queryText.trim(), queryMode); + if (!resp.body) throw new Error('No response body'); + + const reader = resp.body.getReader(); + const decoder = new TextDecoder(); + let buf = ''; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + buf += decoder.decode(value, { stream: true }); + + const lines = buf.split('\n'); + buf = lines.pop() || ''; + + for (const line of lines) { + if (!line.startsWith('data: ')) continue; + try { + const evt = JSON.parse(line.slice(6)); + switch (evt.type) { + case 'token': + setQueryAnswer(prev => prev + evt.content); + if (answerRef.current) { + answerRef.current.scrollTop = answerRef.current.scrollHeight; + } + break; + case 'metadata': + setQueryMeta(evt.dataset); + break; + case 'done': + setQueryDone(evt); + break; + case 'error': + enqueueSnackbar(`Query error: ${evt.message}`, { variant: 'error' }); + break; + } + } catch {} + } + } + } catch (e: any) { + if (e.name !== 'AbortError') { + enqueueSnackbar('Query failed: ' + e.message, { variant: 'error' }); + } + } finally { + setQueryStreaming(false); + abortRef.current = null; + } + }, [dsId, queryText, queryMode, enqueueSnackbar]); + + const stopQuery = useCallback(() => { + if (abortRef.current) { + abortRef.current.abort(); + setQueryStreaming(false); + } + }, []); + + /* Phase 10: Cancel job */ + const doCancelJob = useCallback(async (jobId: string) => { + try { + await 
analysis.cancelJob(jobId); + enqueueSnackbar('Job cancelled', { variant: 'info' }); + fetchJobs(); + } catch (e: any) { + enqueueSnackbar('Cancel failed: ' + e.message, { variant: 'error' }); + } + }, [enqueueSnackbar, fetchJobs]); + + return ( + + + + AI Analysis + {triggering && } + + + {/* Selectors */} + + + + Hunt + + + + Dataset + + + + + + {/* Tabs */} + setTab(v)} variant="scrollable" scrollButtons="auto" sx={{ mb: 1 }}> + } iconPosition="start" label={`Triage (${triageResults.length})`} /> + } iconPosition="start" label={`Host Profiles (${profiles.length})`} /> + } iconPosition="start" label={`Reports (${reports.length})`} /> + } iconPosition="start" label={`Anomalies (${anomalies.filter(a => a.is_outlier).length})`} /> + } iconPosition="start" label="Ask Data" /> + } iconPosition="start" label={`Jobs${jobStats ? ` (${jobStats.active_workers})` : ''}`} /> + + + {/* Tab 0: Triage */} + + + + + + {loadingTriage && } + {triageResults.length === 0 && !loadingTriage ? ( + No triage results yet. Select a dataset and click "Run Triage". + ) : ( + + + + + RowsRiskVerdict + FindingsMITREModel + + + + {triageResults.map(tr => ( + + {tr.row_start}-{tr.row_end} + + + + {tr.findings?.join('; ') || ''} + + + + {tr.mitre_techniques?.map((t: string, i: number) => ( + + ))} + + + {tr.model_used || ''} + + ))} + +
+
+ )} +
+ + {/* Tab 1: Host Profiles */} + + + + + + {loadingProfiles && } + {profiles.length === 0 && !loadingProfiles ? ( + No host profiles yet. Select a hunt and click "Profile All Hosts". + ) : ( + + {profiles.map(hp => ( + + + + + {hp.hostname} + + + {hp.fqdn && {hp.fqdn}} + + {hp.timeline_summary && ( + + {hp.timeline_summary.slice(0, 300)}{hp.timeline_summary.length > 300 ? '...' : ''} + + )} + {hp.suspicious_findings && hp.suspicious_findings.length > 0 && ( + + + + {hp.suspicious_findings.length} suspicious finding(s) + + + )} + {hp.mitre_techniques && hp.mitre_techniques.length > 0 && ( + + {hp.mitre_techniques.map((t: string, i: number) => ( + + ))} + + )} + + + Model: {hp.model_used || 'N/A'} + + + + ))} + + )} + + + {/* Tab 2: Reports */} + + + + + + {loadingReports && } + {reports.length === 0 && !loadingReports ? ( + No reports yet. Select a hunt and click "Generate Report". + ) : ( + reports.map(rpt => ( + + }> + + + Report - {rpt.status} + {rpt.generation_time_ms && ( + + )} + + + + {rpt.exec_summary && ( + + Executive Summary + {rpt.exec_summary} + + )} + {rpt.findings && rpt.findings.length > 0 && ( + + Findings +
    + {rpt.findings.map((f: any, i: number) => ( +
  • + {typeof f === 'string' ? f : JSON.stringify(f)} +
  • + ))} +
+
+ )} + {rpt.recommendations && rpt.recommendations.length > 0 && ( + + Recommendations +
    + {rpt.recommendations.map((r: any, i: number) => ( +
  • + {typeof r === 'string' ? r : JSON.stringify(r)} +
  • + ))} +
+
+ )} + {rpt.ioc_table && rpt.ioc_table.length > 0 && ( + + IOC Table + + + + + {Object.keys(rpt.ioc_table[0]).map(k => ( + {k} + ))} + + + + {rpt.ioc_table.map((row: any, i: number) => ( + + {Object.values(row).map((v: any, j: number) => ( + {String(v)} + ))} + + ))} + +
+
+
+ )} + {rpt.full_report && ( + + }> + Full Report + + + + {rpt.full_report} + + + + )} + + {rpt.models_used?.map((m: string, i: number) => ( + + ))} + +
+
+ )) + )} +
+ + {/* Tab 3: Anomalies */} + + + + + + {loadingAnomalies && } + {anomalies.length === 0 && !loadingAnomalies ? ( + No anomaly results yet. Select a dataset and click "Detect Anomalies". + ) : ( + <> + + {anomalies.filter(a => a.is_outlier).length} outlier(s) detected out of {anomalies.length} rows + + + + + + RowScore + DistanceClusterOutlier + + + + {anomalies.filter(a => a.is_outlier).concat(anomalies.filter(a => !a.is_outlier).slice(0, 20)).map((a, i) => ( + + {a.row_id ?? ''} + + 0.5 ? 'error' : a.anomaly_score > 0.35 ? 'warning' : 'success'} /> + + {a.distance_from_centroid?.toFixed(4) ?? ''} + + + {a.is_outlier + ? + : } + + + ))} + +
+
+ + )} +
+ + {/* Tab 4: Ask Data (Phase 9) */} + + + + Ask a question about the selected dataset in plain English + + + setQueryText(e.target.value)} + onKeyDown={e => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); doQuery(); } }} + disabled={queryStreaming} + /> + { if (v) setQueryMode(v); }} + > + + Quick + + + Deep + + + {queryStreaming ? ( + + ) : ( + + + + )} + + + + {queryMeta && ( + + Querying {queryMeta.name} ({queryMeta.row_count} rows,{' '} + {queryMeta.sample_rows_shown} sampled) | Mode: {queryMode} + + )} + + {queryStreaming && } + + {queryAnswer && ( + + {queryAnswer} + + )} + + {queryDone && ( + + + + + + + )} + + + {/* Tab 5: Jobs & Load Balancer (Phase 10) */} + + {/* LB Status Cards */} + {lbStatus && ( + + {Object.entries(lbStatus).map(([name, st]) => ( + + + + + {name} + + + + Active: {st.active_jobs} + Done: {st.total_completed} + Errors: {st.total_errors} + Avg: {st.avg_latency_ms.toFixed(0)}ms + + + + + ))} + + )} + + {/* Job queue stats */} + {jobStats && ( + + + + {Object.entries(jobStats.by_status).map(([s, c]) => ( + + ))} + + )} + + + + + + {loadingJobs && } + + {jobs.length === 0 && !loadingJobs ? ( + No jobs yet. Jobs appear here when you trigger triage, profiling, reports, anomaly detection, or data queries. + ) : ( + + + + + Status + Type + Progress + Message + Time + Created + Actions + + + + {jobs.map(j => ( + + + + {statusIcon(j.status)} + {j.status} + + + + + {j.status === 'running' ? ( + + ) : j.status === 'completed' ? ( + 100% + ) : null} + + + {j.error || j.message} + + {fmtMs(j.elapsed_ms)} + {fmtTime(j.created_at)} + + {(j.status === 'queued' || j.status === 'running') && ( + doCancelJob(j.id)}> + + + )} + + + ))} + +
+
+ )} +
+
+ ); +} \ No newline at end of file diff --git a/frontend/src/components/DatasetViewer.tsx b/frontend/src/components/DatasetViewer.tsx index 6ee666b..2e06ea2 100644 --- a/frontend/src/components/DatasetViewer.tsx +++ b/frontend/src/components/DatasetViewer.tsx @@ -144,6 +144,12 @@ export default function DatasetViewer() { {selected.source_tool && } + {selected.artifact_type && } + {selected.processing_status && selected.processing_status !== 'ready' && ( + + )} {selected.ioc_columns && Object.keys(selected.ioc_columns).length > 0 && ( )} diff --git a/frontend/src/components/NetworkMap.tsx b/frontend/src/components/NetworkMap.tsx index 478bbc9..7ede886 100644 --- a/frontend/src/components/NetworkMap.tsx +++ b/frontend/src/components/NetworkMap.tsx @@ -1,477 +1,578 @@ /** - * NetworkMap — interactive hunt-scoped force-directed network graph. + * NetworkMap - Host-centric force-directed network graph. * - * • Select a hunt → loads only that hunt's datasets - * • Nodes = unique IPs / hostnames / domains pulled from IOC columns - * • Edges = "seen together in the same row" co-occurrence - * • Click a node → popover showing hostname, IP, OS, dataset sources, connections - * • Responsive canvas with ResizeObserver - * • Zero extra npm dependencies + * Loads a deduplicated host inventory from the backend, showing one node + * per unique host with hostname, IPs, OS, logged-in users, and network + * connections. No more duplicated IOC nodes. 
+ * + * Features: + * - Calls /api/network/host-inventory for clean, deduped host data + * - HiDPI / Retina canvas rendering + * - Radial-gradient nodes with neon glow effects + * - Curved edges with animated flow on active connections + * - Animated force-directed layout + * - Node drag with springy neighbor physics + * - Glassmorphism toolbar + floating legend overlay + * - Rich popover: hostname, IP, OS, users, datasets + * - Zero extra npm dependencies */ import React, { useEffect, useState, useRef, useCallback, useMemo } from 'react'; import { - Box, Typography, Paper, Stack, Alert, Chip, Button, TextField, + Box, Typography, Paper, Stack, Alert, Chip, TextField, LinearProgress, FormControl, InputLabel, Select, MenuItem, - Popover, Divider, IconButton, + Popover, Divider, IconButton, Tooltip, Fade, useTheme, } from '@mui/material'; import RefreshIcon from '@mui/icons-material/Refresh'; import CloseIcon from '@mui/icons-material/Close'; import ZoomInIcon from '@mui/icons-material/ZoomIn'; import ZoomOutIcon from '@mui/icons-material/ZoomOut'; import CenterFocusStrongIcon from '@mui/icons-material/CenterFocusStrong'; -import { datasets, hunts, type Hunt, type DatasetSummary } from '../api/client'; +import HubIcon from '@mui/icons-material/Hub'; +import SearchIcon from '@mui/icons-material/Search'; +import ComputerIcon from '@mui/icons-material/Computer'; +import { + hunts, network, + type Hunt, type InventoryHost, type InventoryConnection, type InventoryStats, +} from '../api/client'; -// ── Graph primitives ───────────────────────────────────────────────── +// == Graph primitives ===================================================== -type NodeType = 'ip' | 'hostname' | 'domain' | 'url'; - -interface NodeMeta { - hostnames: Set; - ips: Set; - os: Set; - datasets: Set; - type: NodeType; -} +type NodeType = 'host' | 'external_ip'; interface GNode { id: string; label: string; x: number; y: number; vx: number; vy: number; radius: number; color: string; count: 
number; - meta: { hostnames: string[]; ips: string[]; os: string[]; datasets: string[]; type: NodeType }; + pinned?: boolean; + meta: { + type: NodeType; + hostname: string; + fqdn: string; + client_id: string; + ips: string[]; + os: string; + users: string[]; + datasets: string[]; + row_count: number; + }; } interface GEdge { source: string; target: string; weight: number } interface Graph { nodes: GNode[]; edges: GEdge[] } const TYPE_COLORS: Record = { - ip: '#3b82f6', hostname: '#22c55e', domain: '#eab308', url: '#8b5cf6', + host: '#60a5fa', + external_ip: '#fbbf24', +}; +const GLOW_COLORS: Record = { + host: 'rgba(96,165,250,0.45)', + external_ip: 'rgba(251,191,36,0.35)', }; -// ── Helpers: find context columns from dataset schema ──────────────── +// == Build graph from inventory ========================================== -/** Best-effort detection of hostname, IP, and OS columns from raw column names + normalized mapping. */ -function findContextColumns(ds: DatasetSummary) { - const norm = ds.normalized_columns || {}; - const schema = ds.column_schema || {}; - const rawCols = Object.keys(schema).length > 0 ? 
Object.keys(schema) : Object.keys(norm); +function buildGraphFromInventory( + hosts: InventoryHost[], connections: InventoryConnection[], + canvasW: number, canvasH: number, +): Graph { + const nodeMap = new Map(); - const hostCols: string[] = []; - const ipCols: string[] = []; - const osCols: string[] = []; - - for (const raw of rawCols) { - const canonical = norm[raw] || ''; - const lower = raw.toLowerCase(); - // Hostname columns - if (canonical === 'hostname' || /^(hostname|host|fqdn|computer_?name|system_?name|machinename)$/i.test(lower)) { - hostCols.push(raw); - } - // IP columns - if (['src_ip', 'dst_ip', 'ip_address'].includes(canonical) || /^(ip|ip_?address|src_?ip|dst_?ip|source_?ip|dest_?ip)$/i.test(lower)) { - ipCols.push(raw); - } - // OS columns (best-effort — raw name scan + normalized canonical) - if (canonical === 'os' || /^(os|operating_?system|os_?version|os_?name|platform|os_?type)$/i.test(lower)) { - osCols.push(raw); - } - } - return { hostCols, ipCols, osCols }; -} - -function cleanVal(v: any): string { - const s = (v ?? '').toString().trim(); - return (s && s !== '-' && s !== '0.0.0.0' && s !== '::') ? s : ''; -} - -// ── Build graph with per-node metadata ─────────────────────────────── - -interface RowBatch { - rows: Record[]; - iocColumns: Record; - dsName: string; - ds: DatasetSummary; -} - -function buildGraph(allBatches: RowBatch[], canvasW: number, canvasH: number): Graph { - const countMap = new Map(); - const edgeMap = new Map(); - const metaMap = new Map(); - - const getOrCreateMeta = (id: string, type: NodeType): NodeMeta => { - let m = metaMap.get(id); - if (!m) { m = { hostnames: new Set(), ips: new Set(), os: new Set(), datasets: new Set(), type }; metaMap.set(id, m); } - return m; - }; - - for (const { rows, iocColumns, dsName, ds } of allBatches) { - // IOC columns that produce graph nodes - const iocEntries = Object.entries(iocColumns).filter(([, t]) => { - const typ = Array.isArray(t) ? 
t[0] : t; - return typ === 'ip' || typ === 'hostname' || typ === 'domain' || typ === 'url'; - }).map(([col, t]) => { - const typ = (Array.isArray(t) ? t[0] : t) as NodeType; - return { col, typ }; - }); - - if (iocEntries.length === 0) continue; - - // Context columns for enrichment - const ctx = findContextColumns(ds); - - for (const row of rows) { - // Collect IOC values for this row (nodes + edges) - const vals: { v: string; typ: NodeType }[] = []; - for (const { col, typ } of iocEntries) { - const v = cleanVal(row[col]); - if (v) vals.push({ v, typ }); - } - const unique = [...new Map(vals.map(x => [x.v, x])).values()]; - - // Count occurrences - for (const { v } of unique) countMap.set(v, (countMap.get(v) ?? 0) + 1); - - // Create edges (co-occurrence) - for (let i = 0; i < unique.length; i++) { - for (let j = i + 1; j < unique.length; j++) { - const key = [unique[i].v, unique[j].v].sort().join('||'); - edgeMap.set(key, (edgeMap.get(key) ?? 0) + 1); - } - } - - // Extract context values from this row - const rowHosts = ctx.hostCols.map(c => cleanVal(row[c])).filter(Boolean); - const rowIps = ctx.ipCols.map(c => cleanVal(row[c])).filter(Boolean); - const rowOs = ctx.osCols.map(c => cleanVal(row[c])).filter(Boolean); - - // Attach context to each node in this row - for (const { v, typ } of unique) { - const meta = getOrCreateMeta(v, typ); - meta.datasets.add(dsName); - for (const h of rowHosts) meta.hostnames.add(h); - for (const ip of rowIps) meta.ips.add(ip); - for (const o of rowOs) meta.os.add(o); - } - } - } - - const nodes: GNode[] = [...countMap.entries()].map(([id, count]) => { - const raw = metaMap.get(id); - const type: NodeType = raw?.type || 'ip'; - return { - id, label: id, count, + // Create host nodes + for (const h of hosts) { + const r = Math.max(8, Math.min(26, 6 + Math.sqrt(h.row_count / 100) * 3)); + nodeMap.set(h.id, { + id: h.id, + label: h.hostname || h.fqdn || h.client_id, x: canvasW / 2 + (Math.random() - 0.5) * canvasW * 0.75, y: 
canvasH / 2 + (Math.random() - 0.5) * canvasH * 0.65, - vx: 0, vy: 0, - radius: Math.max(5, Math.min(18, 4 + Math.sqrt(count) * 1.6)), - color: TYPE_COLORS[type], + vx: 0, vy: 0, radius: r, + color: TYPE_COLORS.host, + count: h.row_count, meta: { - hostnames: [...(raw?.hostnames ?? [])], - ips: [...(raw?.ips ?? [])], - os: [...(raw?.os ?? [])], - datasets: [...(raw?.datasets ?? [])], - type, + type: 'host' as NodeType, + hostname: h.hostname, + fqdn: h.fqdn, + client_id: h.client_id, + ips: h.ips, + os: h.os, + users: h.users, + datasets: h.datasets, + row_count: h.row_count, }, - }; - }); + }); + } - const edges: GEdge[] = [...edgeMap.entries()].map(([key, weight]) => { - const [source, target] = key.split('||'); - return { source, target, weight }; - }); + // Create edges + external IP nodes (for unresolved remote IPs) + const edges: GEdge[] = []; + for (const c of connections) { + if (!nodeMap.has(c.target)) { + nodeMap.set(c.target, { + id: c.target, + label: c.target_ip || c.target, + x: canvasW / 2 + (Math.random() - 0.5) * canvasW * 0.75, + y: canvasH / 2 + (Math.random() - 0.5) * canvasH * 0.65, + vx: 0, vy: 0, radius: 6, + color: TYPE_COLORS.external_ip, + count: c.count, + meta: { + type: 'external_ip' as NodeType, + hostname: '', fqdn: '', client_id: '', + ips: [c.target_ip || c.target], + os: '', users: [], datasets: [], row_count: 0, + }, + }); + } + edges.push({ source: c.source, target: c.target, weight: c.count }); + } - return { nodes, edges }; + return { nodes: [...nodeMap.values()], edges }; } -// ── Force simulation ───────────────────────────────────────────────── +// == Simulation =========================================================== -function simulate(graph: Graph, cx: number, cy: number, steps = 120) { +function simulationStep(graph: Graph, cx: number, cy: number, alpha: number) { const { nodes, edges } = graph; const nodeMap = new Map(nodes.map(n => [n.id, n])); - const k = 80; - const repulsion = 6000; - const damping = 0.85; + const 
k = 120; + const repulsion = 12000; + const damping = 0.82; - for (let step = 0; step < steps; step++) { - for (let i = 0; i < nodes.length; i++) { - for (let j = i + 1; j < nodes.length; j++) { - const a = nodes[i], b = nodes[j]; - const dx = b.x - a.x, dy = b.y - a.y; - const dist = Math.max(1, Math.sqrt(dx * dx + dy * dy)); - const force = repulsion / (dist * dist); - const fx = (dx / dist) * force, fy = (dy / dist) * force; - a.vx -= fx; a.vy -= fy; - b.vx += fx; b.vy += fy; - } - } - for (const e of edges) { - const a = nodeMap.get(e.source), b = nodeMap.get(e.target); - if (!a || !b) continue; + for (let i = 0; i < nodes.length; i++) { + for (let j = i + 1; j < nodes.length; j++) { + const a = nodes[i], b = nodes[j]; + if (a.pinned && b.pinned) continue; const dx = b.x - a.x, dy = b.y - a.y; const dist = Math.max(1, Math.sqrt(dx * dx + dy * dy)); - const force = (dist - k) * 0.05; + const force = (repulsion * alpha) / (dist * dist); const fx = (dx / dist) * force, fy = (dy / dist) * force; - a.vx += fx; a.vy += fy; - b.vx -= fx; b.vy -= fy; - } - for (const n of nodes) { - n.vx += (cx - n.x) * 0.001; - n.vy += (cy - n.y) * 0.001; - n.vx *= damping; n.vy *= damping; - n.x += n.vx; n.y += n.vy; + if (!a.pinned) { a.vx -= fx; a.vy -= fy; } + if (!b.pinned) { b.vx += fx; b.vy += fy; } } } + for (const e of edges) { + const a = nodeMap.get(e.source), b = nodeMap.get(e.target); + if (!a || !b) continue; + const dx = b.x - a.x, dy = b.y - a.y; + const dist = Math.max(1, Math.sqrt(dx * dx + dy * dy)); + const force = (dist - k) * 0.06 * alpha; + const fx = (dx / dist) * force, fy = (dy / dist) * force; + if (!a.pinned) { a.vx += fx; a.vy += fy; } + if (!b.pinned) { b.vx -= fx; b.vy -= fy; } + } + for (const n of nodes) { + if (n.pinned) continue; + n.vx += (cx - n.x) * 0.0012 * alpha; + n.vy += (cy - n.y) * 0.0012 * alpha; + n.vx *= damping; n.vy *= damping; + n.x += n.vx; n.y += n.vy; + } } -// ── Viewport (zoom / pan) ──────────────────────────────────────────── 
+function simulate(graph: Graph, cx: number, cy: number, steps = 150) { + for (let i = 0; i < steps; i++) { + const alpha = 1 - i / steps; + simulationStep(graph, cx, cy, Math.max(0.05, alpha)); + } +} + +// == Viewport ============================================================= interface Viewport { x: number; y: number; scale: number } +const MIN_ZOOM = 0.08; +const MAX_ZOOM = 10; -const MIN_ZOOM = 0.1; -const MAX_ZOOM = 8; +// == Canvas renderer ===================================================== -// ── Canvas renderer ────────────────────────────────────────────────── +const BG_COLOR = '#0a101e'; +const GRID_DOT_COLOR = 'rgba(148,163,184,0.04)'; +const GRID_SPACING = 32; -function drawGraph( - ctx: CanvasRenderingContext2D, graph: Graph, - hovered: string | null, selected: string | null, search: string, - vp: Viewport, +function drawBackground( + ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number, ) { - const { nodes, edges } = graph; - const nodeMap = new Map(nodes.map(n => [n.id, n])); - const matchSet = new Set(); - if (search) { - const lc = search.toLowerCase(); - for (const n of nodes) if (n.label.toLowerCase().includes(lc)) matchSet.add(n.id); - } - - ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height); + ctx.fillStyle = BG_COLOR; + ctx.fillRect(0, 0, w, h); ctx.save(); - ctx.translate(vp.x, vp.y); - ctx.scale(vp.scale, vp.scale); + ctx.translate(vp.x * dpr, vp.y * dpr); + ctx.scale(vp.scale * dpr, vp.scale * dpr); + const startX = -vp.x / vp.scale - GRID_SPACING; + const startY = -vp.y / vp.scale - GRID_SPACING; + const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2; + const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2; + ctx.fillStyle = GRID_DOT_COLOR; + for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) { + for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) { + ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 
2); ctx.fill(); + } + } + ctx.restore(); + const vignette = ctx.createRadialGradient(w / 2, h / 2, w * 0.2, w / 2, h / 2, w * 0.7); + vignette.addColorStop(0, 'rgba(10,16,30,0)'); + vignette.addColorStop(1, 'rgba(10,16,30,0.55)'); + ctx.fillStyle = vignette; + ctx.fillRect(0, 0, w, h); +} - // Edges - for (const e of edges) { +function drawEdges( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, + nodeMap: Map, animTime: number, +) { + for (const e of graph.edges) { const a = nodeMap.get(e.source), b = nodeMap.get(e.target); if (!a || !b) continue; const isActive = (hovered && (e.source === hovered || e.target === hovered)) || (selected && (e.source === selected || e.target === selected)); - ctx.beginPath(); - ctx.strokeStyle = isActive ? 'rgba(96,165,250,0.7)' : 'rgba(100,116,139,0.25)'; - ctx.lineWidth = Math.min(4, 0.5 + e.weight * 0.3) / vp.scale; - ctx.moveTo(a.x, a.y); ctx.lineTo(b.x, b.y); ctx.stroke(); - } + const mx = (a.x + b.x) / 2, my = (a.y + b.y) / 2; + const dx = b.x - a.x, dy = b.y - a.y; + const len = Math.sqrt(dx * dx + dy * dy); + const perpScale = Math.min(20, len * 0.08); + const cpx = mx + (-dy / (len || 1)) * perpScale; + const cpy = my + (dx / (len || 1)) * perpScale; - // Nodes - for (const n of nodes) { - const highlighted = hovered === n.id || selected === n.id || (search && matchSet.has(n.id)); - ctx.beginPath(); - ctx.arc(n.x, n.y, n.radius, 0, Math.PI * 2); - ctx.fillStyle = highlighted ? '#fff' : n.color; - ctx.globalAlpha = (search && !matchSet.has(n.id)) ? 
0.15 : 1; - ctx.fill(); - ctx.globalAlpha = 1; - if (highlighted) { ctx.strokeStyle = n.color; ctx.lineWidth = 2.5 / vp.scale; ctx.stroke(); } - } + ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); - // Labels — show more labels when zoomed in - const labelThreshold = Math.max(1, Math.round(3 / vp.scale)); - const fontSize = Math.max(8, Math.round(11 / vp.scale)); - ctx.font = `${fontSize}px Inter, sans-serif`; + if (isActive) { + ctx.strokeStyle = 'rgba(96,165,250,0.8)'; + ctx.lineWidth = Math.min(3.5, 1 + e.weight * 0.15); + ctx.setLineDash([6, 4]); ctx.lineDashOffset = -animTime * 0.03; + ctx.stroke(); ctx.setLineDash([]); + ctx.save(); + ctx.shadowColor = 'rgba(96,165,250,0.5)'; ctx.shadowBlur = 8; + ctx.strokeStyle = 'rgba(96,165,250,0.3)'; + ctx.lineWidth = Math.min(5, 2 + e.weight * 0.2); + ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); + ctx.stroke(); ctx.restore(); + } else { + const alpha = Math.min(0.35, 0.08 + e.weight * 0.01); + ctx.strokeStyle = `rgba(100,116,139,${alpha})`; + ctx.lineWidth = Math.min(2.5, 0.4 + e.weight * 0.08); + ctx.stroke(); + } + } +} + +function drawNodes( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, + search: string, matchSet: Set, +) { + const dimmed = search.length > 0; + for (const n of graph.nodes) { + const isHighlight = hovered === n.id || selected === n.id || (search && matchSet.has(n.id)); + const isDim = dimmed && !matchSet.has(n.id); + ctx.save(); + ctx.globalAlpha = isDim ? 
0.12 : 1; + + if (isHighlight && !isDim) { + ctx.save(); + ctx.shadowColor = GLOW_COLORS[n.meta.type] || 'rgba(96,165,250,0.4)'; + ctx.shadowBlur = 18; + ctx.beginPath(); ctx.arc(n.x, n.y, n.radius + 4, 0, Math.PI * 2); + ctx.fillStyle = 'rgba(0,0,0,0)'; ctx.fill(); + ctx.restore(); + } + + const grad = ctx.createRadialGradient( + n.x - n.radius * 0.3, n.y - n.radius * 0.3, n.radius * 0.1, n.x, n.y, n.radius, + ); + if (isHighlight && !isDim) { + grad.addColorStop(0, '#ffffff'); + grad.addColorStop(0.4, n.color); + grad.addColorStop(1, n.color); + } else { + grad.addColorStop(0, n.color + 'cc'); + grad.addColorStop(0.5, n.color); + grad.addColorStop(1, n.color + '88'); + } + ctx.beginPath(); ctx.arc(n.x, n.y, n.radius, 0, Math.PI * 2); + ctx.fillStyle = grad; ctx.fill(); + ctx.strokeStyle = isHighlight ? '#ffffff' : (n.color + '55'); + ctx.lineWidth = isHighlight ? 2 : 1; + ctx.stroke(); + + if (n.pinned) { + ctx.beginPath(); ctx.arc(n.x, n.y, 2.5, 0, Math.PI * 2); + ctx.fillStyle = '#ffffff'; ctx.fill(); + } + ctx.restore(); + } +} + +function drawLabels( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, + search: string, matchSet: Set, vp: Viewport, +) { + const dimmed = search.length > 0; + const fontSize = Math.max(9, Math.round(12 / vp.scale)); + ctx.font = `500 ${fontSize}px Inter, system-ui, sans-serif`; ctx.textAlign = 'center'; - for (const n of nodes) { - const show = hovered === n.id || selected === n.id - || (search && matchSet.has(n.id)) || n.count >= labelThreshold; - if (!show) continue; - ctx.fillStyle = (search && !matchSet.has(n.id)) ? 'rgba(241,245,249,0.15)' : '#f1f5f9'; - ctx.fillText(n.label, n.x, n.y - n.radius - 5); - } + ctx.textBaseline = 'bottom'; + const sorted = [...graph.nodes].sort((a, b) => { + const aH = hovered === a.id || selected === a.id || matchSet.has(a.id) ? 1 : 0; + const bH = hovered === b.id || selected === b.id || matchSet.has(b.id) ? 
1 : 0; + if (aH !== bH) return aH - bH; + return b.count - a.count; + }); + + for (const n of sorted) { + const isHighlight = hovered === n.id || selected === n.id || matchSet.has(n.id); + // Always show labels for hosts (since they're deduped and fewer) + const show = isHighlight || n.meta.type === 'host' || n.count >= 2; + if (!show) continue; + const isDim = dimmed && !matchSet.has(n.id); + if (isDim) continue; + + // Two-line label: hostname + IP (if available) + const line1 = n.label; + const line2 = n.meta.ips.length > 0 ? n.meta.ips[0] : ''; + const tw = Math.max(ctx.measureText(line1).width, line2 ? ctx.measureText(line2).width : 0); + const px = 5, py = 2; + const totalH = line2 ? fontSize * 2 + py * 2 : fontSize + py * 2; + const lx = n.x, ly = n.y - n.radius - 6; + + const rx = lx - tw / 2 - px; + const ry = ly - totalH; + const rw = tw + px * 2; + const rh = totalH; + const cr = 4; + + ctx.save(); + ctx.globalAlpha = isHighlight ? 0.92 : 0.75; + ctx.fillStyle = 'rgba(10,16,30,0.80)'; + ctx.beginPath(); + ctx.moveTo(rx + cr, ry); ctx.lineTo(rx + rw - cr, ry); + ctx.arcTo(rx + rw, ry, rx + rw, ry + cr, cr); + ctx.lineTo(rx + rw, ry + rh - cr); + ctx.arcTo(rx + rw, ry + rh, rx + rw - cr, ry + rh, cr); + ctx.lineTo(rx + cr, ry + rh); + ctx.arcTo(rx, ry + rh, rx, ry + rh - cr, cr); + ctx.lineTo(rx, ry + cr); + ctx.arcTo(rx, ry, rx + cr, ry, cr); + ctx.closePath(); ctx.fill(); + ctx.strokeStyle = isHighlight ? n.color + 'aa' : 'rgba(148,163,184,0.15)'; + ctx.lineWidth = 0.8; ctx.stroke(); + ctx.restore(); + + // Hostname line + ctx.fillStyle = isHighlight ? '#ffffff' : n.color; + ctx.globalAlpha = isHighlight ? 1 : 0.85; + ctx.fillText(line1, lx, ly - (line2 ? 
fontSize * 0.5 : 0)); + // IP line (smaller, dimmer) + if (line2) { + ctx.fillStyle = 'rgba(148,163,184,0.6)'; + ctx.fillText(line2, lx, ly + fontSize * 0.5); + } + ctx.globalAlpha = 1; + } +} + +function drawGraph( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, search: string, + vp: Viewport, animTime: number, dpr: number, +) { + const w = ctx.canvas.width, h = ctx.canvas.height; + const nodeMap = new Map(graph.nodes.map(n => [n.id, n])); + const matchSet = new Set(); + if (search) { + const lc = search.toLowerCase(); + for (const n of graph.nodes) { + if (n.label.toLowerCase().includes(lc) + || n.meta.ips.some(ip => ip.includes(lc)) + || n.meta.users.some(u => u.toLowerCase().includes(lc)) + || n.meta.os.toLowerCase().includes(lc) + ) matchSet.add(n.id); + } + } + drawBackground(ctx, w, h, vp, dpr); + ctx.save(); + ctx.translate(vp.x * dpr, vp.y * dpr); + ctx.scale(vp.scale * dpr, vp.scale * dpr); + drawEdges(ctx, graph, hovered, selected, nodeMap, animTime); + drawNodes(ctx, graph, hovered, selected, search, matchSet); + drawLabels(ctx, graph, hovered, selected, search, matchSet, vp); ctx.restore(); } -// ── Hit-test helper (viewport-aware) ───────────────────────────────── +// == Hit-test ============================================================= function screenToWorld( canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport, ): { wx: number; wy: number } { const rect = canvas.getBoundingClientRect(); - const cssToCanvas_x = canvas.width / rect.width; - const cssToCanvas_y = canvas.height / rect.height; - const cx = (clientX - rect.left) * cssToCanvas_x; - const cy = (clientY - rect.top) * cssToCanvas_y; - return { wx: (cx - vp.x) / vp.scale, wy: (cy - vp.y) / vp.scale }; + return { wx: (clientX - rect.left - vp.x) / vp.scale, wy: (clientY - rect.top - vp.y) / vp.scale }; } function hitTest( - graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, - vp: Viewport, + 
graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport, ): GNode | null { const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp); for (const n of graph.nodes) { const dx = n.x - wx, dy = n.y - wy; - if (dx * dx + dy * dy < (n.radius + 4) ** 2) return n; + if (dx * dx + dy * dy < (n.radius + 5) ** 2) return n; } return null; } - -// ── Component ──────────────────────────────────────────────────────── +// == Component ============================================================= export default function NetworkMap() { - // Hunt selector + const theme = useTheme(); + const [huntList, setHuntList] = useState([]); const [selectedHuntId, setSelectedHuntId] = useState(''); - // Graph state const [loading, setLoading] = useState(false); const [progress, setProgress] = useState(''); const [error, setError] = useState(''); const [graph, setGraph] = useState(null); + const [stats, setStats] = useState(null); const [hovered, setHovered] = useState(null); const [selectedNode, setSelectedNode] = useState(null); const [search, setSearch] = useState(''); - const [dsCount, setDsCount] = useState(0); - const [totalRows, setTotalRows] = useState(0); - // Node type filters - const [visibleTypes, setVisibleTypes] = useState>( - new Set(['ip', 'hostname', 'domain', 'url']), - ); - - // Canvas sizing const canvasRef = useRef(null); const wrapperRef = useRef(null); const [canvasSize, setCanvasSize] = useState({ w: 900, h: 600 }); - // Viewport (zoom / pan) const vpRef = useRef({ x: 0, y: 0, scale: 1 }); - const [vpScale, setVpScale] = useState(1); // for UI display only + const [vpScale, setVpScale] = useState(1); const isPanning = useRef(false); const panStart = useRef({ x: 0, y: 0 }); + const dragNode = useRef(null); + + const animFrameRef = useRef(0); + const animTimeRef = useRef(0); + const simAlphaRef = useRef(0); + const isAnimatingRef = useRef(false); + const hoveredRef = useRef(null); + const selectedNodeRef = useRef(null); + const 
searchRef = useRef(''); + const graphRef = useRef(null); - // Popover anchor const [popoverAnchor, setPopoverAnchor] = useState<{ top: number; left: number } | null>(null); - // ── Load hunts on mount ──────────────────────────────────────────── + useEffect(() => { hoveredRef.current = hovered; }, [hovered]); + useEffect(() => { selectedNodeRef.current = selectedNode; }, [selectedNode]); + useEffect(() => { searchRef.current = search; }, [search]); + + // Load hunts on mount useEffect(() => { hunts.list(0, 200).then(r => setHuntList(r.hunts)).catch(() => {}); }, []); - // ── Resize observer ──────────────────────────────────────────────── + // Resize observer useEffect(() => { const el = wrapperRef.current; if (!el) return; const ro = new ResizeObserver(entries => { for (const entry of entries) { const w = Math.round(entry.contentRect.width); - if (w > 100) setCanvasSize({ w, h: Math.max(450, Math.round(w * 0.55)) }); + if (w > 100) setCanvasSize({ w, h: Math.max(500, Math.round(w * 0.56)) }); } }); ro.observe(el); return () => ro.disconnect(); }, []); - // ── Load graph for selected hunt ────────────────────────────────── + // HiDPI canvas sizing + useEffect(() => { + const canvas = canvasRef.current; + if (!canvas) return; + const dpr = window.devicePixelRatio || 1; + canvas.width = canvasSize.w * dpr; + canvas.height = canvasSize.h * dpr; + canvas.style.width = canvasSize.w + 'px'; + canvas.style.height = canvasSize.h + 'px'; + }, [canvasSize]); + + // Load host inventory for selected hunt const loadGraph = useCallback(async (huntId: string) => { if (!huntId) return; - setLoading(true); setError(''); setGraph(null); + setLoading(true); setError(''); setGraph(null); setStats(null); setSelectedNode(null); setPopoverAnchor(null); try { - setProgress('Fetching datasets for hunt…'); - const dsRes = await datasets.list(0, 500, huntId); - const dsList = dsRes.datasets; - setDsCount(dsList.length); + setProgress('Building host inventory (scanning all datasets)\u2026'); 
+ const inv = await network.hostInventory(huntId); + setStats(inv.stats); - if (dsList.length === 0) { - setError('This hunt has no datasets. Upload CSV files to this hunt first.'); + if (inv.hosts.length === 0) { + setError('No hosts found. Upload CSV files with host-identifying columns (ClientId, Fqdn, Hostname) to this hunt.'); setLoading(false); setProgress(''); return; } - const allBatches: RowBatch[] = []; - let rowTotal = 0; - - for (let i = 0; i < dsList.length; i++) { - const ds = dsList[i]; - setProgress(`Loading ${ds.name} (${i + 1}/${dsList.length})…`); - try { - const detail = await datasets.get(ds.id); - const ioc = detail.ioc_columns || {}; - const hasIoc = Object.values(ioc).some(t => { - const typ = Array.isArray(t) ? t[0] : t; - return typ === 'ip' || typ === 'hostname' || typ === 'domain' || typ === 'url'; - }); - if (hasIoc) { - const r = await datasets.rows(ds.id, 0, 5000); - allBatches.push({ rows: r.rows, iocColumns: ioc, dsName: ds.name, ds: detail }); - rowTotal += r.rows.length; - } - } catch { /* skip failed datasets */ } - } - - setTotalRows(rowTotal); - - if (allBatches.length === 0) { - setError('No datasets in this hunt contain IP/hostname/domain IOC columns.'); - setLoading(false); setProgress(''); - return; - } - - setProgress('Building graph…'); - const g = buildGraph(allBatches, canvasSize.w, canvasSize.h); - if (g.nodes.length === 0) { - setError('No network nodes found in the data.'); - } else { - simulate(g, canvasSize.w / 2, canvasSize.h / 2); - setGraph(g); - } + setProgress(`Building graph for ${inv.stats.total_hosts} hosts\u2026`); + const g = buildGraphFromInventory(inv.hosts, inv.connections, canvasSize.w, canvasSize.h); + simulate(g, canvasSize.w / 2, canvasSize.h / 2, 30); + simAlphaRef.current = 0.8; + setGraph(g); } catch (e: any) { setError(e.message); } setLoading(false); setProgress(''); }, [canvasSize]); - // When hunt changes, load graph useEffect(() => { if (selectedHuntId) loadGraph(selectedHuntId); }, 
[selectedHuntId, loadGraph]); - // Reset viewport when graph changes useEffect(() => { vpRef.current = { x: 0, y: 0, scale: 1 }; setVpScale(1); }, [graph]); - // Filtered graph — only visible node types + edges between them - const filteredGraph = useMemo(() => { - if (!graph) return null; - const nodes = graph.nodes.filter(n => visibleTypes.has(n.meta.type)); - const nodeIds = new Set(nodes.map(n => n.id)); - const edges = graph.edges.filter(e => nodeIds.has(e.source) && nodeIds.has(e.target)); - return { nodes, edges }; - }, [graph, visibleTypes]); + useEffect(() => { graphRef.current = graph; }, [graph]); - // Toggle a node type filter - const toggleType = useCallback((t: NodeType) => { - setVisibleTypes(prev => { - const next = new Set(prev); - if (next.has(t)) { - // Don't allow all to be hidden - if (next.size > 1) next.delete(t); - } else { - next.add(t); + // Animation loop + const startAnimLoop = useCallback(() => { + if (isAnimatingRef.current) return; + isAnimatingRef.current = true; + const tick = (ts: number) => { + animTimeRef.current = ts; + const canvas = canvasRef.current; + const g = graphRef.current; + if (!canvas || !g) { isAnimatingRef.current = false; return; } + const dpr = window.devicePixelRatio || 1; + const ctx = canvas.getContext('2d'); + if (!ctx) { isAnimatingRef.current = false; return; } + + if (simAlphaRef.current > 0.01) { + simulationStep(g, canvasSize.w / 2, canvasSize.h / 2, simAlphaRef.current); + simAlphaRef.current *= 0.97; + if (simAlphaRef.current < 0.01) simAlphaRef.current = 0; } - return next; - }); - }, []); + drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? 
null, searchRef.current, vpRef.current, ts, dpr); + + const needsAnim = simAlphaRef.current > 0.01 + || hoveredRef.current !== null + || selectedNodeRef.current !== null + || dragNode.current !== null; + if (needsAnim) { + animFrameRef.current = requestAnimationFrame(tick); + } else { + isAnimatingRef.current = false; + } + }; + animFrameRef.current = requestAnimationFrame(tick); + }, [canvasSize]); + + useEffect(() => { + if (graph) startAnimLoop(); + return () => { cancelAnimationFrame(animFrameRef.current); isAnimatingRef.current = false; }; + }, [graph, startAnimLoop]); + + useEffect(() => { startAnimLoop(); }, [hovered, selectedNode, search, startAnimLoop]); - // Redraw helper — uses filteredGraph const redraw = useCallback(() => { - if (!filteredGraph || !canvasRef.current) return; + if (!graph || !canvasRef.current) return; const ctx = canvasRef.current.getContext('2d'); - if (ctx) drawGraph(ctx, filteredGraph, hovered, selectedNode?.id ?? null, search, vpRef.current); - }, [filteredGraph, hovered, selectedNode, search]); + const dpr = window.devicePixelRatio || 1; + if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? 
null, search, vpRef.current, animTimeRef.current, dpr); + }, [graph, hovered, selectedNode, search]); - // Redraw on every render-affecting state change - useEffect(() => { redraw(); }, [redraw]); + useEffect(() => { if (!isAnimatingRef.current) redraw(); }, [redraw]); - // ── Mouse wheel → zoom ───────────────────────────────────────────── + // Mouse wheel -> zoom useEffect(() => { const canvas = canvasRef.current; if (!canvas) return; @@ -479,210 +580,329 @@ export default function NetworkMap() { e.preventDefault(); const vp = vpRef.current; const rect = canvas.getBoundingClientRect(); - const cssToCanvasX = canvas.width / rect.width; - const cssToCanvasY = canvas.height / rect.height; - // Mouse position in canvas pixel coords - const mx = (e.clientX - rect.left) * cssToCanvasX; - const my = (e.clientY - rect.top) * cssToCanvasY; - - const zoomFactor = e.deltaY < 0 ? 1.12 : 1 / 1.12; + const mx = e.clientX - rect.left, my = e.clientY - rect.top; + const zoomFactor = e.deltaY < 0 ? 1.15 : 1 / 1.15; const newScale = Math.min(MAX_ZOOM, Math.max(MIN_ZOOM, vp.scale * zoomFactor)); - // Zoom toward cursor: adjust offset so world-point under cursor stays fixed vp.x = mx - (mx - vp.x) * (newScale / vp.scale); vp.y = my - (my - vp.y) * (newScale / vp.scale); vp.scale = newScale; setVpScale(newScale); - // Immediate redraw (bypass React state for smoothness) - const ctx = canvas.getContext('2d'); - if (ctx && filteredGraph) drawGraph(ctx, filteredGraph, hovered, selectedNode?.id ?? 
null, search, vp); + startAnimLoop(); }; canvas.addEventListener('wheel', onWheel, { passive: false }); return () => canvas.removeEventListener('wheel', onWheel); - }, [filteredGraph, hovered, selectedNode, search]); + }, [graph, startAnimLoop]); - // ── Mouse drag → pan ─────────────────────────────────────────────── + // Mouse handlers const onMouseDown = useCallback((e: React.MouseEvent) => { - if (!filteredGraph || !canvasRef.current) return; - const node = hitTest(filteredGraph, canvasRef.current, e.clientX, e.clientY, vpRef.current); - if (!node) { - isPanning.current = true; - panStart.current = { x: e.clientX, y: e.clientY }; - } - }, [filteredGraph]); + if (!graph || !canvasRef.current) return; + const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current); + if (node) { dragNode.current = node; node.pinned = true; startAnimLoop(); } + else { isPanning.current = true; panStart.current = { x: e.clientX, y: e.clientY }; } + }, [graph, startAnimLoop]); const onMouseMove = useCallback((e: React.MouseEvent) => { - if (!filteredGraph || !canvasRef.current) return; - + if (!graph || !canvasRef.current) return; + if (dragNode.current) { + const { wx, wy } = screenToWorld(canvasRef.current, e.clientX, e.clientY, vpRef.current); + dragNode.current.x = wx; dragNode.current.y = wy; + dragNode.current.vx = 0; dragNode.current.vy = 0; + if (simAlphaRef.current < 0.15) simAlphaRef.current = 0.15; + startAnimLoop(); return; + } if (isPanning.current) { const vp = vpRef.current; - const rect = canvasRef.current.getBoundingClientRect(); - const cssToCanvasX = canvasRef.current.width / rect.width; - const cssToCanvasY = canvasRef.current.height / rect.height; - vp.x += (e.clientX - panStart.current.x) * cssToCanvasX; - vp.y += (e.clientY - panStart.current.y) * cssToCanvasY; + vp.x += e.clientX - panStart.current.x; + vp.y += e.clientY - panStart.current.y; panStart.current = { x: e.clientX, y: e.clientY }; - redraw(); - return; + redraw(); return; } - 
- const node = hitTest(filteredGraph, canvasRef.current, e.clientX, e.clientY, vpRef.current); + const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current); setHovered(node?.id ?? null); - }, [filteredGraph, redraw]); + }, [graph, redraw, startAnimLoop]); - const onMouseUp = useCallback(() => { - isPanning.current = false; - }, []); + const onMouseUp = useCallback(() => { dragNode.current = null; isPanning.current = false; }, []); - // ── Mouse click → select node + show popover ───────────────────── const onClick = useCallback((e: React.MouseEvent) => { - if (!filteredGraph || !canvasRef.current) return; - const node = hitTest(filteredGraph, canvasRef.current, e.clientX, e.clientY, vpRef.current); - if (node) { - setSelectedNode(node); - setPopoverAnchor({ top: e.clientY, left: e.clientX }); - } else { - setSelectedNode(null); - setPopoverAnchor(null); + if (!graph || !canvasRef.current) return; + const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current); + if (node) { setSelectedNode(node); setPopoverAnchor({ top: e.clientY, left: e.clientX }); } + else { setSelectedNode(null); setPopoverAnchor(null); } + }, [graph]); + + const onDoubleClick = useCallback((e: React.MouseEvent) => { + if (!graph || !canvasRef.current) return; + const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current); + if (node && node.pinned) { + node.pinned = false; simAlphaRef.current = Math.max(simAlphaRef.current, 0.3); startAnimLoop(); } - }, [filteredGraph]); + }, [graph, startAnimLoop]); const closePopover = () => { setSelectedNode(null); setPopoverAnchor(null); }; - // ── Zoom controls ────────────────────────────────────────────────── const zoomBy = useCallback((factor: number) => { const vp = vpRef.current; const cw = canvasSize.w, ch = canvasSize.h; const newScale = Math.min(MAX_ZOOM, Math.max(MIN_ZOOM, vp.scale * factor)); - // Zoom toward canvas center vp.x = cw / 2 - (cw / 2 - vp.x) * (newScale / 
vp.scale); vp.y = ch / 2 - (ch / 2 - vp.y) * (newScale / vp.scale); - vp.scale = newScale; - setVpScale(newScale); - redraw(); + vp.scale = newScale; setVpScale(newScale); redraw(); }, [canvasSize, redraw]); const resetView = useCallback(() => { - vpRef.current = { x: 0, y: 0, scale: 1 }; - setVpScale(1); - redraw(); + vpRef.current = { x: 0, y: 0, scale: 1 }; setVpScale(1); redraw(); }, [redraw]); - // Count connections for selected node - const connectionCount = selectedNode && filteredGraph - ? filteredGraph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length + const connectionCount = selectedNode && graph + ? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length : 0; - // ── Render ───────────────────────────────────────────────────────── + const connectedNodes = useMemo(() => { + if (!selectedNode || !graph) return []; + const neighbors: { id: string; type: NodeType; weight: number }[] = []; + for (const e of graph.edges) { + if (e.source === selectedNode.id) { + const n = graph.nodes.find(x => x.id === e.target); + if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight }); + } else if (e.target === selectedNode.id) { + const n = graph.nodes.find(x => x.id === e.source); + if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight }); + } + } + return neighbors.sort((a, b) => b.weight - a.weight).slice(0, 12); + }, [selectedNode, graph]); + + const hostCount = graph ? graph.nodes.filter(n => n.meta.type === 'host').length : 0; + const extCount = graph ? 
graph.nodes.filter(n => n.meta.type === 'external_ip').length : 0; + + const getCursor = () => { + if (dragNode.current) return 'grabbing'; + if (isPanning.current) return 'grabbing'; + if (hovered) return 'pointer'; + return 'grab'; + }; + // == Render ============================================================== return ( - {/* Header row */} - - Network Map - - - Hunt - - - setSearch(e.target.value)} sx={{ width: 200 }} /> - + {/* Glassmorphism toolbar */} + + + + + Network Map + - + + + + + Hunt + + + + setSearch(e.target.value)} + sx={{ width: 200, '& .MuiInputBase-input': { py: 0.8 } }} + slotProps={{ + input: { + startAdornment: , + }, + }} + /> + + + + loadGraph(selectedHuntId)} + disabled={loading || !selectedHuntId} + size="small" + sx={{ bgcolor: 'rgba(96,165,250,0.1)', '&:hover': { bgcolor: 'rgba(96,165,250,0.2)' } }} + > + + + + + + + {/* Stats summary cards */} + {stats && !loading && ( + + {[ + { label: 'Hosts', value: stats.total_hosts, color: TYPE_COLORS.host }, + { label: 'With IPs', value: stats.hosts_with_ips, color: '#34d399' }, + { label: 'With Users', value: stats.hosts_with_users, color: '#a78bfa' }, + { label: 'Datasets Scanned', value: stats.total_datasets_scanned, color: '#fbbf24' }, + { label: 'Rows Scanned', value: stats.total_rows_scanned.toLocaleString(), color: '#f87171' }, + ].map(s => ( + + + {s.label} + + + {s.value} + + + ))} + + )} {/* Loading indicator */} - {loading && ( - - {progress} - + + + + + {progress} + + + - )} + {error && {error}} - {/* Legend — clickable type filters */} - {graph && filteredGraph && ( - - - - - - - {([['ip', 'IP'], ['hostname', 'Host'], ['domain', 'Domain'], ['url', 'URL']] as [NodeType, string][]).map(([type, label]) => { - const active = visibleTypes.has(type); - const count = graph.nodes.filter(n => n.meta.type === type).length; - return ( - toggleType(type)} - sx={{ - bgcolor: active ? TYPE_COLORS[type] : 'transparent', - color: active ? 
'#fff' : TYPE_COLORS[type], - border: `2px solid ${TYPE_COLORS[type]}`, - fontWeight: 600, - cursor: 'pointer', - opacity: active ? 1 : 0.5, - transition: 'all 0.15s ease', - '&:hover': { opacity: 1 }, - }} - /> - ); - })} - - )} - - {/* Canvas */} - {filteredGraph && ( - + {/* Canvas area */} + {graph && ( + { isPanning.current = false; setHovered(null); }} + onMouseLeave={() => { isPanning.current = false; dragNode.current = null; setHovered(null); }} onClick={onClick} + onDoubleClick={onDoubleClick} /> - {/* Zoom controls overlay */} + + {/* Legend overlay - bottom left */} - zoomBy(1.3)} - sx={{ bgcolor: 'rgba(30,41,59,0.85)', color: '#f1f5f9', '&:hover': { bgcolor: 'rgba(51,65,85,0.95)' } }} - aria-label="Zoom in"> - zoomBy(1 / 1.3)} - sx={{ bgcolor: 'rgba(30,41,59,0.85)', color: '#f1f5f9', '&:hover': { bgcolor: 'rgba(51,65,85,0.95)' } }} - aria-label="Zoom out"> - - + } + label={`Hosts (${hostCount})`} + size="small" + sx={{ + bgcolor: TYPE_COLORS.host + '22', color: TYPE_COLORS.host, + border: `1.5px solid ${TYPE_COLORS.host}88`, + fontWeight: 600, fontSize: 11, + }} + /> + {extCount > 0 && ( + + )} + + {/* Stats badge - bottom right */} + + + {graph.nodes.length} nodes + + {'\u00B7'} + + {graph.edges.length} connections + + + + {/* Zoom controls - top right */} + + {[ + { tip: 'Zoom in', icon: , fn: () => zoomBy(1.3) }, + { tip: 'Zoom out', icon: , fn: () => zoomBy(1 / 1.3) }, + { tip: 'Reset view', icon: , fn: resetView }, + ].map(z => ( + + {z.icon} + + ))} + + + + {/* Drag hint */} + + + Drag nodes to reposition {'\u00B7'} Double-click to unpin {'\u00B7'} Scroll to zoom + + )} - {/* Node detail popover */} {selectedNode && ( - - - {selectedNode.label} - - - - - - - {/* Hostnames */} - Hostname - - {selectedNode.meta.hostnames.length > 0 - ? selectedNode.meta.hostnames.join(', ') - : Unknown} - - - {/* IPs */} - IP Address - - {selectedNode.meta.ips.length > 0 - ? selectedNode.meta.ips.join(', ') - : (selectedNode.meta.type === 'ip' ? 
selectedNode.label : Unknown)} - - - {/* OS */} - Operating System - - {selectedNode.meta.os.length > 0 - ? selectedNode.meta.os.join(', ') - : Unknown} - - - - - {/* Stats */} - - - - - - {/* Datasets */} - {selectedNode.meta.datasets.length > 0 && ( - - Seen in datasets - - {selectedNode.meta.datasets.map(d => ( - - ))} + + + + + + + {selectedNode.meta.hostname || selectedNode.label} + + - - )} + + + + + + + {selectedNode.meta.fqdn && ( + + + FQDN + + + {selectedNode.meta.fqdn} + + + )} + + + + IP Address + + + {selectedNode.meta.ips.length > 0 ? selectedNode.meta.ips.join(', ') : No IP detected} + + + + + + Operating System + + + {selectedNode.meta.os || Unknown} + + + + + + Logged-In Users + + {selectedNode.meta.users.length > 0 ? ( + + {selectedNode.meta.users.map(u => ( + + ))} + + ) : ( + + No user data + + )} + + + {selectedNode.meta.client_id && ( + + + Client ID + + + {selectedNode.meta.client_id} + + + )} + + + + + + + + + + + {connectedNodes.length > 0 && ( + + + Connected To + + + {connectedNodes.map(cn => ( + 30 ? cn.id.slice(0, 28) + '\u2026' : cn.id} size="small" + sx={{ + fontSize: 10, height: 22, fontFamily: 'monospace', + bgcolor: TYPE_COLORS[cn.type] + '15', color: TYPE_COLORS[cn.type], + border: `1px solid ${TYPE_COLORS[cn.type]}33`, cursor: 'pointer', + '&:hover': { bgcolor: TYPE_COLORS[cn.type] + '30' }, + }} + onClick={() => { setSearch(cn.id); closePopover(); }} + /> + ))} + + + )} + + {selectedNode.meta.datasets.length > 0 && ( + + + Seen In Datasets + + + {selectedNode.meta.datasets.map(d => ( + + ))} + + + )} + )} {/* Empty states */} {!selectedHuntId && !loading && ( - - + + + Select a hunt to visualize its network - - Choose a hunt from the dropdown above. The map will display IP addresses, - hostnames, and domains found across the hunt's datasets, with connections - showing co-occurrence in the same log rows. + + Choose a hunt from the dropdown above. 
The map builds a clean, + deduplicated host inventory showing each endpoint with its hostname, + IP address, OS, and logged-in users. )} {selectedHuntId && !graph && !loading && !error && ( + - No network data to display. Upload datasets with IP/hostname columns to this hunt. + No host data to display. Upload datasets with host-identifying columns (ClientId, Fqdn, Hostname). )} ); -} +} \ No newline at end of file diff --git a/update.md b/update.md new file mode 100644 index 0000000..c278fe2 --- /dev/null +++ b/update.md @@ -0,0 +1,41 @@ +# ThreatHunt Update Log + +## 2026-02-20: Host-Centric Network Map & Analysis Platform + +### Network Map Overhaul +- **Problem**: Network Map showed 409 misclassified "domain" nodes (mostly process names like svchost.exe) and 0 hosts. No deduplication — the same host was counted once per dataset. +- **Root Cause**: IOC column detection misclassified `Fqdn` as "domain" instead of "hostname"; `Name` column (process names) wrongly tagged as "domain" IOC; `ClientId` was in `normalized_columns` as "hostname" but not in `ioc_columns`. +- **Solution**: Created a new host-centric inventory system that scans all datasets, groups by `Fqdn`/`ClientId`, and extracts IPs, users, OS, and network connections. + +#### New Backend Files +- `backend/app/services/host_inventory.py` — Deduplicated host inventory builder. Scans all datasets in a hunt, identifies unique hosts via regex-based column detection (`ClientId`, `Fqdn`, `User`/`Username`, `Laddr.IP`/`Raddr.IP`), groups rows, extracts metadata. Filters system accounts (DWM-*, UMFD-*, LOCAL SERVICE, NETWORK SERVICE). Infers OS from hostname patterns (W10-* → Windows 10). Builds network connection graph from netstat remote IPs. +- `backend/app/api/routes/network.py` — `GET /api/network/host-inventory?hunt_id=X` endpoint returning `{hosts, connections, stats}`. +- `backend/app/services/ioc_extractor.py` — IOC extraction service (IP, domain, hash, email, URL patterns).
+- `backend/app/services/anomaly_detector.py` — Statistical anomaly detection across datasets. + - `backend/app/services/data_query.py` — Natural language to structured query translation. + - `backend/app/services/load_balancer.py` — Round-robin load balancer for Ollama LLM nodes. + - `backend/app/services/job_queue.py` — Async job queue for long-running analysis tasks. + - `backend/app/api/routes/analysis.py` — 16 analysis endpoints (IOC extraction, anomaly detection, host profiling, triage, reports, job management). + +#### Modified Backend Files +- `backend/app/main.py` — Added `network_router` and `analysis_router` includes. +- `backend/app/db/models.py` — Added 4 AI/analysis ORM models (`ProcessingJob`, `AnalysisResult`, `HostProfile`, `IOCEntry`). +- `backend/app/db/engine.py` — Connection pool tuning for SQLite async. + +#### Frontend Changes +- `frontend/src/components/NetworkMap.tsx` — Complete rewrite: host-centric force-directed graph using Canvas 2D. Two node types (Host / External IP). Shows hostname, IP, OS in labels. Click popover shows FQDN, IPs, OS, logged-in users, datasets, connections. Search across hostname/IP/user/OS. Stats cards showing host counts. +- `frontend/src/components/AnalysisDashboard.tsx` — New 6-tab analysis dashboard (IOC Extraction, Anomaly Detection, Host Profiling, Query, Triage, Reports). +- `frontend/src/api/client.ts` — Added `network.hostInventory()` method + `InventoryHost`, `InventoryConnection`, `InventoryStats` types. Added analysis API namespace with 16 endpoint methods. +- `frontend/src/App.tsx` — Added Analysis Dashboard route and navigation.
+ + ### Results (Radio Hunt — 20 Velociraptor datasets, 394K rows) + + | Metric | Before | After | + |--------|--------|-------| + | Nodes shown | 409 misclassified "domains" | **163 unique hosts** | + | Hosts identified | 0 | **163** | + | With IP addresses | N/A | **48** (172.17.x.x LAN) | + | With logged-in users | N/A | **43** (real names only) | + | OS detected | None | **Windows 10** (inferred from hostnames) | + | Deduplication | None (same host × 20 datasets) | **Full** (by FQDN/ClientId) | + | System account filtering | None | **DWM-*, UMFD-*, LOCAL/NETWORK SERVICE removed** | \ No newline at end of file