diff --git a/backend/alembic/versions/a1b2c3d4e5f6_add_processing_status_ai_tables.py b/backend/alembic/versions/a1b2c3d4e5f6_add_processing_status_ai_tables.py
new file mode 100644
index 0000000..700524d
--- /dev/null
+++ b/backend/alembic/versions/a1b2c3d4e5f6_add_processing_status_ai_tables.py
@@ -0,0 +1,112 @@
+"""add processing_status and AI analysis tables
+
+Revision ID: a1b2c3d4e5f6
+Revises: 98ab619418bc
+Create Date: 2026-02-19 18:00:00.000000
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+revision: str = "a1b2c3d4e5f6"
+down_revision: Union[str, Sequence[str], None] = "98ab619418bc"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # Add columns to datasets table
+    # batch_alter_table so the ALTERs also work on SQLite (table-rebuild strategy).
+    with op.batch_alter_table("datasets") as batch_op:
+        batch_op.add_column(sa.Column("processing_status", sa.String(20), server_default="ready"))
+        batch_op.add_column(sa.Column("artifact_type", sa.String(128), nullable=True))
+        batch_op.add_column(sa.Column("error_message", sa.Text(), nullable=True))
+        batch_op.add_column(sa.Column("file_path", sa.String(512), nullable=True))
+        batch_op.create_index("ix_datasets_status", ["processing_status"])
+
+    # Create triage_results table
+    # Per-chunk (row_start..row_end) LLM triage verdicts for a dataset.
+    op.create_table(
+        "triage_results",
+        sa.Column("id", sa.String(32), primary_key=True),
+        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True),
+        sa.Column("row_start", sa.Integer(), nullable=False),
+        sa.Column("row_end", sa.Integer(), nullable=False),
+        sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"),
+        sa.Column("verdict", sa.String(20), nullable=False, server_default="pending"),
+        sa.Column("findings", sa.JSON(), nullable=True),
+        sa.Column("suspicious_indicators", sa.JSON(), nullable=True),
+        sa.Column("mitre_techniques", sa.JSON(), nullable=True),
+        sa.Column("model_used", sa.String(128), nullable=True),
+        sa.Column("node_used", sa.String(64), nullable=True),
+        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
+    )
+
+    # Create host_profiles table
+    # One row per (hunt, host): aggregated risk + LLM narrative analysis.
+    op.create_table(
+        "host_profiles",
+        sa.Column("id", sa.String(32), primary_key=True),
+        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True),
+        sa.Column("hostname", sa.String(256), nullable=False),
+        sa.Column("fqdn", sa.String(512), nullable=True),
+        sa.Column("client_id", sa.String(64), nullable=True),
+        sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"),
+        sa.Column("risk_level", sa.String(20), nullable=False, server_default="unknown"),
+        sa.Column("artifact_summary", sa.JSON(), nullable=True),
+        sa.Column("timeline_summary", sa.Text(), nullable=True),
+        sa.Column("suspicious_findings", sa.JSON(), nullable=True),
+        sa.Column("mitre_techniques", sa.JSON(), nullable=True),
+        sa.Column("llm_analysis", sa.Text(), nullable=True),
+        sa.Column("model_used", sa.String(128), nullable=True),
+        sa.Column("node_used", sa.String(64), nullable=True),
+        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
+        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
+    )
+
+    # Create hunt_reports table
+    # Generated executive/full reports per hunt.
+    op.create_table(
+        "hunt_reports",
+        sa.Column("id", sa.String(32), primary_key=True),
+        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True),
+        sa.Column("status", sa.String(20), nullable=False, server_default="pending"),
+        sa.Column("exec_summary", sa.Text(), nullable=True),
+        sa.Column("full_report", sa.Text(), nullable=True),
+        sa.Column("findings", sa.JSON(), nullable=True),
+        sa.Column("recommendations", sa.JSON(), nullable=True),
+        sa.Column("mitre_mapping", sa.JSON(), nullable=True),
+        sa.Column("ioc_table", sa.JSON(), nullable=True),
+        sa.Column("host_risk_summary", sa.JSON(), nullable=True),
+        sa.Column("models_used", sa.JSON(), nullable=True),
+        sa.Column("generation_time_ms", sa.Integer(), nullable=True),
+        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
+        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
+    )
+
+    # Create anomaly_results table
+    # Per-row embedding-clustering output (see services/anomaly_detector.py).
+    op.create_table(
+        "anomaly_results",
+        sa.Column("id", sa.String(32), primary_key=True),
+        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True),
+        sa.Column("row_id", sa.String(32), sa.ForeignKey("dataset_rows.id", ondelete="CASCADE"), nullable=True),
+        sa.Column("anomaly_score", sa.Float(), nullable=False, server_default="0.0"),
+        sa.Column("distance_from_centroid", sa.Float(), nullable=True),
+        sa.Column("cluster_id", sa.Integer(), nullable=True),
+        sa.Column("is_outlier", sa.Boolean(), nullable=False, server_default="0"),
+        sa.Column("explanation", sa.Text(), nullable=True),
+        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
+    )
+
+
+def downgrade() -> None:
+    # Drop in reverse dependency order, then undo the datasets ALTERs.
+    op.drop_table("anomaly_results")
+    op.drop_table("hunt_reports")
+    op.drop_table("host_profiles")
+    op.drop_table("triage_results")
+
+    with op.batch_alter_table("datasets") as batch_op:
+        batch_op.drop_index("ix_datasets_status")
+        batch_op.drop_column("file_path")
+        batch_op.drop_column("error_message")
+        batch_op.drop_column("artifact_type")
+        batch_op.drop_column("processing_status")
\ No newline at end of file
diff --git a/backend/app/db/engine.py b/backend/app/db/engine.py
index d8b2840..95e67f8 100644
--- a/backend/app/db/engine.py
+++ b/backend/app/db/engine.py
@@ -3,6 +3,7 @@
 Uses async SQLAlchemy with aiosqlite for local dev
 and asyncpg for production PostgreSQL.
""" +from sqlalchemy import event from sqlalchemy.ext.asyncio import ( AsyncSession, async_sessionmaker, @@ -12,12 +13,32 @@ from sqlalchemy.orm import DeclarativeBase from app.config import settings -engine = create_async_engine( - settings.DATABASE_URL, +_is_sqlite = settings.DATABASE_URL.startswith("sqlite") + +_engine_kwargs: dict = dict( echo=settings.DEBUG, future=True, ) +if _is_sqlite: + _engine_kwargs["connect_args"] = {"timeout": 30} + _engine_kwargs["pool_size"] = 1 + _engine_kwargs["max_overflow"] = 0 + +engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs) + + +@event.listens_for(engine.sync_engine, "connect") +def _set_sqlite_pragmas(dbapi_conn, connection_record): + """Enable WAL mode and tune busy-timeout for SQLite connections.""" + if _is_sqlite: + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA journal_mode=WAL") + cursor.execute("PRAGMA busy_timeout=5000") + cursor.execute("PRAGMA synchronous=NORMAL") + cursor.close() + + async_session_factory = async_sessionmaker( engine, class_=AsyncSession, @@ -63,4 +84,4 @@ async def init_db() -> None: async def dispose_db() -> None: """Dispose of the engine connection pool.""" - await engine.dispose() + await engine.dispose() \ No newline at end of file diff --git a/backend/app/services/anomaly_detector.py b/backend/app/services/anomaly_detector.py new file mode 100644 index 0000000..69b7593 --- /dev/null +++ b/backend/app/services/anomaly_detector.py @@ -0,0 +1,199 @@ +"""Embedding-based anomaly detection using Roadrunner's bge-m3 model. + +Converts dataset rows to embeddings, clusters them, and flags outliers +that deviate significantly from the cluster centroids. Uses cosine +distance and simple k-means-like centroid computation. 
+""" + +import asyncio +import json +import logging +import math +from typing import Optional + +import httpx +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import settings +from app.db import async_session_factory +from app.db.models import AnomalyResult, Dataset, DatasetRow + +logger = logging.getLogger(__name__) + +EMBED_URL = f"{settings.roadrunner_url}/api/embed" +EMBED_MODEL = "bge-m3" +BATCH_SIZE = 32 # rows per embedding batch +MAX_ROWS = 2000 # cap for anomaly detection + +# --- math helpers (no numpy required) --- + +def _dot(a: list[float], b: list[float]) -> float: + return sum(x * y for x, y in zip(a, b)) + + +def _norm(v: list[float]) -> float: + return math.sqrt(sum(x * x for x in v)) + + +def _cosine_distance(a: list[float], b: list[float]) -> float: + na, nb = _norm(a), _norm(b) + if na == 0 or nb == 0: + return 1.0 + return 1.0 - _dot(a, b) / (na * nb) + + +def _mean_vector(vectors: list[list[float]]) -> list[float]: + if not vectors: + return [] + dim = len(vectors[0]) + n = len(vectors) + return [sum(v[i] for v in vectors) / n for i in range(dim)] + + +def _row_to_text(data: dict) -> str: + """Flatten a row dict to a single string for embedding.""" + parts = [] + for k, v in data.items(): + sv = str(v).strip() + if sv and sv.lower() not in ('none', 'null', ''): + parts.append(f"{k}={sv}") + return " | ".join(parts)[:2000] # cap length + + +async def _embed_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]: + """Get embeddings from Roadrunner's Ollama API.""" + resp = await client.post( + EMBED_URL, + json={"model": EMBED_MODEL, "input": texts}, + timeout=120.0, + ) + resp.raise_for_status() + data = resp.json() + # Ollama returns {"embeddings": [[...], ...]} + return data.get("embeddings", []) + + +def _simple_cluster( + embeddings: list[list[float]], + k: int = 3, + max_iter: int = 20, +) -> tuple[list[int], list[list[float]]]: + """Simple k-means clustering (no numpy 
dependency). + + Returns (assignments, centroids). + """ + n = len(embeddings) + if n <= k: + return list(range(n)), embeddings[:] + + # Init centroids: evenly spaced indices + step = max(n // k, 1) + centroids = [embeddings[i * step % n] for i in range(k)] + assignments = [0] * n + + for _ in range(max_iter): + # Assign to nearest centroid + new_assignments = [] + for emb in embeddings: + dists = [_cosine_distance(emb, c) for c in centroids] + new_assignments.append(dists.index(min(dists))) + + if new_assignments == assignments: + break + assignments = new_assignments + + # Recompute centroids + for ci in range(k): + members = [embeddings[j] for j in range(n) if assignments[j] == ci] + if members: + centroids[ci] = _mean_vector(members) + + return assignments, centroids + + +async def detect_anomalies( + dataset_id: str, + k: int = 3, + outlier_threshold: float = 0.35, +) -> list[dict]: + """Run embedding-based anomaly detection on a dataset. + + 1. Load rows 2. Embed via bge-m3 3. Cluster 4. Flag outliers. 
+ """ + async with async_session_factory() as db: + # Load rows + result = await db.execute( + select(DatasetRow.id, DatasetRow.row_index, DatasetRow.data) + .where(DatasetRow.dataset_id == dataset_id) + .order_by(DatasetRow.row_index) + .limit(MAX_ROWS) + ) + rows = result.all() + if not rows: + logger.info("No rows for anomaly detection in dataset %s", dataset_id) + return [] + + row_ids = [r[0] for r in rows] + row_indices = [r[1] for r in rows] + texts = [_row_to_text(r[2]) for r in rows] + + logger.info("Anomaly detection: %d rows, embedding with %s", len(texts), EMBED_MODEL) + + # Embed in batches + all_embeddings: list[list[float]] = [] + async with httpx.AsyncClient() as client: + for i in range(0, len(texts), BATCH_SIZE): + batch = texts[i : i + BATCH_SIZE] + try: + embs = await _embed_batch(batch, client) + all_embeddings.extend(embs) + except Exception as e: + logger.error("Embedding batch %d failed: %s", i, e) + # Fill with zeros so indices stay aligned + all_embeddings.extend([[0.0] * 1024] * len(batch)) + + if not all_embeddings or len(all_embeddings) != len(texts): + logger.error("Embedding count mismatch") + return [] + + # Cluster + actual_k = min(k, len(all_embeddings)) + assignments, centroids = _simple_cluster(all_embeddings, k=actual_k) + + # Compute distances from centroid + anomalies: list[dict] = [] + for idx, (emb, cluster_id) in enumerate(zip(all_embeddings, assignments)): + dist = _cosine_distance(emb, centroids[cluster_id]) + is_outlier = dist > outlier_threshold + anomalies.append({ + "row_id": row_ids[idx], + "row_index": row_indices[idx], + "anomaly_score": round(dist, 4), + "distance_from_centroid": round(dist, 4), + "cluster_id": cluster_id, + "is_outlier": is_outlier, + }) + + # Save to DB + outlier_count = 0 + for a in anomalies: + ar = AnomalyResult( + dataset_id=dataset_id, + row_id=a["row_id"], + anomaly_score=a["anomaly_score"], + distance_from_centroid=a["distance_from_centroid"], + cluster_id=a["cluster_id"], + 
is_outlier=a["is_outlier"], + ) + db.add(ar) + if a["is_outlier"]: + outlier_count += 1 + + await db.commit() + logger.info( + "Anomaly detection complete: %d rows, %d outliers (threshold=%.2f)", + len(anomalies), outlier_count, outlier_threshold, + ) + + return sorted(anomalies, key=lambda x: x["anomaly_score"], reverse=True) \ No newline at end of file diff --git a/backend/app/services/artifact_classifier.py b/backend/app/services/artifact_classifier.py new file mode 100644 index 0000000..d88bf32 --- /dev/null +++ b/backend/app/services/artifact_classifier.py @@ -0,0 +1,81 @@ +"""Artifact classifier - identify Velociraptor artifact types from CSV headers.""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + +# (required_columns, artifact_type) +FINGERPRINTS: list[tuple[set[str], str]] = [ + ({"Pid", "Name", "CommandLine", "Exe"}, "Windows.System.Pslist"), + ({"Pid", "Name", "Ppid", "CommandLine"}, "Windows.System.Pslist"), + ({"Laddr.IP", "Raddr.IP", "Status", "Pid"}, "Windows.Network.Netstat"), + ({"Laddr", "Raddr", "Status", "Pid"}, "Windows.Network.Netstat"), + ({"FamilyString", "TypeString", "Status", "Pid"}, "Windows.Network.Netstat"), + ({"ServiceName", "DisplayName", "StartMode", "PathName"}, "Windows.System.Services"), + ({"DisplayName", "PathName", "ServiceDll", "StartMode"}, "Windows.System.Services"), + ({"OSPath", "Size", "Mtime", "Hash"}, "Windows.Search.FileFinder"), + ({"FullPath", "Size", "Mtime"}, "Windows.Search.FileFinder"), + ({"PrefetchFileName", "RunCount", "LastRunTimes"}, "Windows.Forensics.Prefetch"), + ({"Executable", "RunCount", "LastRunTimes"}, "Windows.Forensics.Prefetch"), + ({"KeyPath", "Type", "Data"}, "Windows.Registry.Finder"), + ({"Key", "Type", "Value"}, "Windows.Registry.Finder"), + ({"EventTime", "Channel", "EventID", "EventData"}, "Windows.EventLogs.EvtxHunter"), + ({"TimeCreated", "Channel", "EventID", "Provider"}, "Windows.EventLogs.EvtxHunter"), + ({"Entry", "Category", 
"Profile", "Launch String"}, "Windows.Sys.Autoruns"), + ({"Entry", "Category", "LaunchString"}, "Windows.Sys.Autoruns"), + ({"Name", "Record", "Type", "TTL"}, "Windows.Network.DNS"), + ({"QueryName", "QueryType", "QueryResults"}, "Windows.Network.DNS"), + ({"Path", "MD5", "SHA1", "SHA256"}, "Windows.Analysis.Hash"), + ({"Md5", "Sha256", "FullPath"}, "Windows.Analysis.Hash"), + ({"Name", "Actions", "NextRunTime", "Path"}, "Windows.System.TaskScheduler"), + ({"Name", "Uid", "Gid", "Description"}, "Windows.Sys.Users"), + ({"os_info.hostname", "os_info.system"}, "Server.Information.Client"), + ({"ClientId", "os_info.fqdn"}, "Server.Information.Client"), + ({"Pid", "Name", "Cmdline", "Exe"}, "Linux.Sys.Pslist"), + ({"Laddr", "Raddr", "Status", "FamilyString"}, "Linux.Network.Netstat"), + ({"Namespace", "ClassName", "PropertyName"}, "Windows.System.WMI"), + ({"RemoteAddress", "RemoteMACAddress", "InterfaceAlias"}, "Windows.Network.ArpCache"), + ({"URL", "Title", "VisitCount", "LastVisitTime"}, "Windows.Applications.BrowserHistory"), + ({"Url", "Title", "Visits"}, "Windows.Applications.BrowserHistory"), +] + +VELOCIRAPTOR_META = {"_Source", "ClientId", "FlowId", "Fqdn", "HuntId"} + +CATEGORY_MAP = { + "Pslist": "process", + "Netstat": "network", + "Services": "persistence", + "FileFinder": "filesystem", + "Prefetch": "execution", + "Registry": "persistence", + "EvtxHunter": "eventlog", + "EventLogs": "eventlog", + "Autoruns": "persistence", + "DNS": "network", + "Hash": "filesystem", + "TaskScheduler": "persistence", + "Users": "account", + "Client": "system", + "WMI": "persistence", + "ArpCache": "network", + "BrowserHistory": "application", +} + + +def classify_artifact(columns: list[str]) -> str: + col_set = set(columns) + for required, artifact_type in FINGERPRINTS: + if required.issubset(col_set): + return artifact_type + if VELOCIRAPTOR_META.intersection(col_set): + return "Velociraptor.Unknown" + return "Unknown" + + +def get_artifact_category(artifact_type: str) 
-> str: + for key, category in CATEGORY_MAP.items(): + if key.lower() in artifact_type.lower(): + return category + return "unknown" \ No newline at end of file diff --git a/backend/app/services/data_query.py b/backend/app/services/data_query.py new file mode 100644 index 0000000..f13cdf5 --- /dev/null +++ b/backend/app/services/data_query.py @@ -0,0 +1,238 @@ +"""Natural-language data query service with SSE streaming. + +Lets analysts ask questions about dataset rows in plain English. +Routes to fast model (Roadrunner) for quick queries, heavy model (Wile) +for deep analysis. Supports streaming via OllamaProvider.generate_stream(). +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +from typing import AsyncIterator + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config import settings +from app.db import async_session_factory +from app.db.models import Dataset, DatasetRow + +logger = logging.getLogger(__name__) + +# Maximum rows to include in context window +MAX_CONTEXT_ROWS = 60 +MAX_ROW_TEXT_CHARS = 300 + + +def _rows_to_text(rows: list[dict], columns: list[str]) -> str: + """Convert dataset rows to a compact text table for the LLM context.""" + if not rows: + return "(no rows)" + # Header + header = " | ".join(columns[:20]) # cap columns to avoid overflow + lines = [header, "-" * min(len(header), 120)] + for row in rows[:MAX_CONTEXT_ROWS]: + vals = [] + for c in columns[:20]: + v = str(row.get(c, "")) + if len(v) > 80: + v = v[:77] + "..." + vals.append(v) + line = " | ".join(vals) + if len(line) > MAX_ROW_TEXT_CHARS: + line = line[:MAX_ROW_TEXT_CHARS] + "..." + lines.append(line) + return "\n".join(lines) + + +QUERY_SYSTEM_PROMPT = """You are a cybersecurity data analyst assistant for ThreatHunt. +You have been given a sample of rows from a forensic artifact dataset (Velociraptor, etc.). 
+ +Your job: +- Answer the analyst's question about this data accurately and concisely +- Point out suspicious patterns, anomalies, or indicators of compromise +- Reference MITRE ATT&CK techniques when relevant +- Suggest follow-up queries or pivots +- If you cannot answer from the data provided, say so clearly + +Rules: +- Be factual - only reference data you can see +- Use forensic terminology appropriate for SOC/DFIR analysts +- Format your answer with clear sections using markdown +- If the data seems benign, say so - do not fabricate threats""" + + +async def _load_dataset_context( + dataset_id: str, + db: AsyncSession, + sample_size: int = MAX_CONTEXT_ROWS, +) -> tuple[dict, str, int]: + """Load dataset metadata + sample rows for context. + + Returns (metadata_dict, rows_text, total_row_count). + """ + ds = await db.get(Dataset, dataset_id) + if not ds: + raise ValueError(f"Dataset {dataset_id} not found") + + # Get total count + count_q = await db.execute( + select(func.count()).where(DatasetRow.dataset_id == dataset_id) + ) + total = count_q.scalar() or 0 + + # Sample rows - get first batch + some from the middle + half = sample_size // 2 + result = await db.execute( + select(DatasetRow) + .where(DatasetRow.dataset_id == dataset_id) + .order_by(DatasetRow.row_index) + .limit(half) + ) + first_rows = result.scalars().all() + + # If dataset is large, also sample from the middle + middle_rows = [] + if total > sample_size: + mid_offset = total // 2 + result2 = await db.execute( + select(DatasetRow) + .where(DatasetRow.dataset_id == dataset_id) + .order_by(DatasetRow.row_index) + .offset(mid_offset) + .limit(sample_size - half) + ) + middle_rows = result2.scalars().all() + else: + result2 = await db.execute( + select(DatasetRow) + .where(DatasetRow.dataset_id == dataset_id) + .order_by(DatasetRow.row_index) + .offset(half) + .limit(sample_size - half) + ) + middle_rows = result2.scalars().all() + + all_rows = first_rows + middle_rows + row_dicts = [r.data if 
isinstance(r.data, dict) else {} for r in all_rows] + + columns = list(ds.column_schema.keys()) if ds.column_schema else [] + if not columns and row_dicts: + columns = list(row_dicts[0].keys()) + + rows_text = _rows_to_text(row_dicts, columns) + + metadata = { + "name": ds.name, + "filename": ds.filename, + "source_tool": ds.source_tool, + "artifact_type": getattr(ds, "artifact_type", None), + "row_count": total, + "columns": columns[:30], + "sample_rows_shown": len(all_rows), + } + return metadata, rows_text, total + + +async def query_dataset( + dataset_id: str, + question: str, + mode: str = "quick", +) -> str: + """Non-streaming query: returns full answer text.""" + from app.agents.providers_v2 import OllamaProvider, Node + + async with async_session_factory() as db: + meta, rows_text, total = await _load_dataset_context(dataset_id, db) + + prompt = _build_prompt(question, meta, rows_text, total) + + if mode == "deep": + provider = OllamaProvider(settings.DEFAULT_HEAVY_MODEL, Node.WILE) + max_tokens = 4096 + else: + provider = OllamaProvider(settings.DEFAULT_FAST_MODEL, Node.ROADRUNNER) + max_tokens = 2048 + + result = await provider.generate( + prompt, + system=QUERY_SYSTEM_PROMPT, + max_tokens=max_tokens, + temperature=0.3, + ) + return result.get("response", "No response generated.") + + +async def query_dataset_stream( + dataset_id: str, + question: str, + mode: str = "quick", +) -> AsyncIterator[str]: + """Streaming query: yields SSE-formatted events.""" + from app.agents.providers_v2 import OllamaProvider, Node + + start = time.monotonic() + + # Send initial metadata event + yield f"data: {json.dumps({'type': 'status', 'message': 'Loading dataset...'})}\n\n" + + async with async_session_factory() as db: + meta, rows_text, total = await _load_dataset_context(dataset_id, db) + + yield f"data: {json.dumps({'type': 'metadata', 'dataset': meta})}\n\n" + yield f"data: {json.dumps({'type': 'status', 'message': f'Querying LLM ({mode} mode)...'})}\n\n" + + prompt 
= _build_prompt(question, meta, rows_text, total) + + if mode == "deep": + provider = OllamaProvider(settings.DEFAULT_HEAVY_MODEL, Node.WILE) + max_tokens = 4096 + model_name = settings.DEFAULT_HEAVY_MODEL + node_name = "wile" + else: + provider = OllamaProvider(settings.DEFAULT_FAST_MODEL, Node.ROADRUNNER) + max_tokens = 2048 + model_name = settings.DEFAULT_FAST_MODEL + node_name = "roadrunner" + + # Stream tokens + token_count = 0 + try: + async for token in provider.generate_stream( + prompt, + system=QUERY_SYSTEM_PROMPT, + max_tokens=max_tokens, + temperature=0.3, + ): + token_count += 1 + yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n" + except Exception as e: + logger.error(f"Streaming error: {e}") + yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n" + + elapsed_ms = int((time.monotonic() - start) * 1000) + yield f"data: {json.dumps({'type': 'done', 'tokens': token_count, 'elapsed_ms': elapsed_ms, 'model': model_name, 'node': node_name})}\n\n" + + +def _build_prompt(question: str, meta: dict, rows_text: str, total: int) -> str: + """Construct the full prompt with data context.""" + parts = [ + f"## Dataset: {meta['name']}", + f"- Source: {meta.get('source_tool', 'unknown')}", + f"- Artifact type: {meta.get('artifact_type', 'unknown')}", + f"- Total rows: {total}", + f"- Columns: {', '.join(meta.get('columns', []))}", + f"- Showing {meta['sample_rows_shown']} sample rows below", + "", + "## Sample Data", + "```", + rows_text, + "```", + "", + f"## Analyst Question", + question, + ] + return "\n".join(parts) \ No newline at end of file diff --git a/backend/app/services/host_inventory.py b/backend/app/services/host_inventory.py new file mode 100644 index 0000000..32f0146 --- /dev/null +++ b/backend/app/services/host_inventory.py @@ -0,0 +1,290 @@ +"""Host Inventory Service - builds a deduplicated host-centric network view. 
+ +Scans all datasets in a hunt to identify unique hosts, their IPs, OS, +logged-in users, and network connections between them. +""" + +import re +import logging +from collections import defaultdict +from typing import Any + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db.models import Dataset, DatasetRow + +logger = logging.getLogger(__name__) + +# --- Column-name patterns (Velociraptor + generic forensic tools) --- + +_HOST_ID_RE = re.compile( + r'^(client_?id|clientid|agent_?id|endpoint_?id|host_?id|sensor_?id)$', re.I) +_FQDN_RE = re.compile( + r'^(fqdn|fully_?qualified|computer_?name|hostname|host_?name|host|' + r'system_?name|machine_?name|nodename|workstation)$', re.I) +_USERNAME_RE = re.compile( + r'^(user|username|user_?name|logon_?name|account_?name|owner|' + r'logged_?in_?user|sam_?account_?name|samaccountname)$', re.I) +_LOCAL_IP_RE = re.compile( + r'^(laddr\.?ip|laddr|local_?addr(ess)?|src_?ip|source_?ip)$', re.I) +_REMOTE_IP_RE = re.compile( + r'^(raddr\.?ip|raddr|remote_?addr(ess)?|dst_?ip|dest_?ip)$', re.I) +_REMOTE_PORT_RE = re.compile( + r'^(raddr\.?port|rport|remote_?port|dst_?port|dest_?port)$', re.I) +_OS_RE = re.compile( + r'^(os|operating_?system|os_?version|os_?name|platform|os_?type|os_?build)$', re.I) +_IP_VALID_RE = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$') + +_IGNORE_IPS = frozenset({ + '0.0.0.0', '::', '::1', '127.0.0.1', '', '-', '*', 'None', 'null', +}) +_SYSTEM_DOMAINS = frozenset({ + 'NT AUTHORITY', 'NT SERVICE', 'FONT DRIVER HOST', 'WINDOW MANAGER', +}) +_SYSTEM_USERS = frozenset({ + 'SYSTEM', 'LOCAL SERVICE', 'NETWORK SERVICE', + 'UMFD-0', 'UMFD-1', 'DWM-1', 'DWM-2', 'DWM-3', +}) + + +def _is_valid_ip(v: str) -> bool: + if not v or v in _IGNORE_IPS: + return False + return bool(_IP_VALID_RE.match(v)) + + +def _clean(v: Any) -> str: + s = str(v or '').strip() + return s if s and s not in ('-', 'None', 'null', '') else '' + + +_SYSTEM_USER_RE = re.compile( + 
r'^(SYSTEM|LOCAL SERVICE|NETWORK SERVICE|DWM-\d+|UMFD-\d+)$', re.I) + + +def _extract_username(raw: str) -> str: + """Clean username, stripping domain prefixes and filtering system accounts.""" + if not raw: + return '' + name = raw.strip() + if '\\' in name: + domain, _, name = name.rpartition('\\') + name = name.strip() + if domain.strip().upper() in _SYSTEM_DOMAINS: + if not name or _SYSTEM_USER_RE.match(name): + return '' + if _SYSTEM_USER_RE.match(name): + return '' + return name or '' + + +def _infer_os(fqdn: str) -> str: + u = fqdn.upper() + if 'W10-' in u or 'WIN10' in u: + return 'Windows 10' + if 'W11-' in u or 'WIN11' in u: + return 'Windows 11' + if 'W7-' in u or 'WIN7' in u: + return 'Windows 7' + if 'SRV' in u or 'SERVER' in u or 'DC-' in u: + return 'Windows Server' + if any(k in u for k in ('LINUX', 'UBUNTU', 'CENTOS', 'RHEL', 'DEBIAN')): + return 'Linux' + if 'MAC' in u or 'DARWIN' in u: + return 'macOS' + return 'Windows' + + +def _identify_columns(ds: Dataset) -> dict: + norm = ds.normalized_columns or {} + schema = ds.column_schema or {} + raw_cols = list(schema.keys()) if schema else list(norm.keys()) + + result = { + 'host_id': [], 'fqdn': [], 'username': [], + 'local_ip': [], 'remote_ip': [], 'remote_port': [], 'os': [], + } + + for col in raw_cols: + canonical = (norm.get(col) or '').lower() + lower = col.lower() + + if _HOST_ID_RE.match(lower) or (canonical == 'hostname' and lower not in ('hostname', 'host_name', 'host')): + result['host_id'].append(col) + + if _FQDN_RE.match(lower) or canonical == 'fqdn': + result['fqdn'].append(col) + + if _USERNAME_RE.match(lower) or canonical in ('username', 'user'): + result['username'].append(col) + + if _LOCAL_IP_RE.match(lower): + result['local_ip'].append(col) + elif _REMOTE_IP_RE.match(lower): + result['remote_ip'].append(col) + + if _REMOTE_PORT_RE.match(lower): + result['remote_port'].append(col) + + if _OS_RE.match(lower) or canonical == 'os': + result['os'].append(col) + + return result + + 
+async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict: + """Build a deduplicated host inventory from all datasets in a hunt. + + Returns dict with 'hosts', 'connections', and 'stats'. + Each host has: id, hostname, fqdn, client_id, ips, os, users, datasets, row_count. + """ + ds_result = await db.execute( + select(Dataset).where(Dataset.hunt_id == hunt_id) + ) + all_datasets = ds_result.scalars().all() + + if not all_datasets: + return {"hosts": [], "connections": [], "stats": { + "total_hosts": 0, "total_datasets_scanned": 0, + "total_rows_scanned": 0, + }} + + hosts: dict[str, dict] = {} # fqdn -> host record + ip_to_host: dict[str, str] = {} # local-ip -> fqdn + connections: dict[tuple, int] = defaultdict(int) + total_rows = 0 + ds_with_hosts = 0 + + for ds in all_datasets: + cols = _identify_columns(ds) + if not cols['fqdn'] and not cols['host_id']: + continue + ds_with_hosts += 1 + + batch_size = 5000 + offset = 0 + while True: + rr = await db.execute( + select(DatasetRow) + .where(DatasetRow.dataset_id == ds.id) + .order_by(DatasetRow.row_index) + .offset(offset).limit(batch_size) + ) + rows = rr.scalars().all() + if not rows: + break + + for ro in rows: + data = ro.data or {} + total_rows += 1 + + fqdn = '' + for c in cols['fqdn']: + fqdn = _clean(data.get(c)) + if fqdn: + break + client_id = '' + for c in cols['host_id']: + client_id = _clean(data.get(c)) + if client_id: + break + + if not fqdn and not client_id: + continue + + host_key = fqdn or client_id + + if host_key not in hosts: + short = fqdn.split('.')[0] if fqdn and '.' 
in fqdn else fqdn + hosts[host_key] = { + 'id': host_key, + 'hostname': short or client_id, + 'fqdn': fqdn, + 'client_id': client_id, + 'ips': set(), + 'os': '', + 'users': set(), + 'datasets': set(), + 'row_count': 0, + } + + h = hosts[host_key] + h['datasets'].add(ds.name) + h['row_count'] += 1 + if client_id and not h['client_id']: + h['client_id'] = client_id + + for c in cols['username']: + u = _extract_username(_clean(data.get(c))) + if u: + h['users'].add(u) + + for c in cols['local_ip']: + ip = _clean(data.get(c)) + if _is_valid_ip(ip): + h['ips'].add(ip) + ip_to_host[ip] = host_key + + for c in cols['os']: + ov = _clean(data.get(c)) + if ov and not h['os']: + h['os'] = ov + + for c in cols['remote_ip']: + rip = _clean(data.get(c)) + if _is_valid_ip(rip): + rport = '' + for pc in cols['remote_port']: + rport = _clean(data.get(pc)) + if rport: + break + connections[(host_key, rip, rport)] += 1 + + offset += batch_size + if len(rows) < batch_size: + break + + # Post-process hosts + for h in hosts.values(): + if not h['os'] and h['fqdn']: + h['os'] = _infer_os(h['fqdn']) + h['ips'] = sorted(h['ips']) + h['users'] = sorted(h['users']) + h['datasets'] = sorted(h['datasets']) + + # Build connections, resolving IPs to host keys + conn_list = [] + seen = set() + for (src, dst_ip, dst_port), cnt in connections.items(): + if dst_ip in _IGNORE_IPS: + continue + dst_host = ip_to_host.get(dst_ip, '') + if dst_host == src: + continue + key = tuple(sorted([src, dst_host or dst_ip])) + if key in seen: + continue + seen.add(key) + conn_list.append({ + 'source': src, + 'target': dst_host or dst_ip, + 'target_ip': dst_ip, + 'port': dst_port, + 'count': cnt, + }) + + host_list = sorted(hosts.values(), key=lambda x: x['row_count'], reverse=True) + + return { + "hosts": host_list, + "connections": conn_list, + "stats": { + "total_hosts": len(host_list), + "total_datasets_scanned": len(all_datasets), + "datasets_with_hosts": ds_with_hosts, + "total_rows_scanned": total_rows, + 
"hosts_with_ips": sum(1 for h in host_list if h['ips']), + "hosts_with_users": sum(1 for h in host_list if h['users']), + }, + } \ No newline at end of file diff --git a/backend/app/services/host_profiler.py b/backend/app/services/host_profiler.py new file mode 100644 index 0000000..c4f8bb1 --- /dev/null +++ b/backend/app/services/host_profiler.py @@ -0,0 +1,198 @@ +"""Host profiler - per-host deep threat analysis via Wile heavy models.""" + +from __future__ import annotations + +import asyncio +import json +import logging + +import httpx +from sqlalchemy import select + +from app.config import settings +from app.db.engine import async_session +from app.db.models import Dataset, DatasetRow, HostProfile, TriageResult + +logger = logging.getLogger(__name__) + +HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL +WILE_URL = f"{settings.wile_url}/api/generate" + + +async def _get_triage_summary(db, dataset_id: str) -> str: + result = await db.execute( + select(TriageResult) + .where(TriageResult.dataset_id == dataset_id) + .where(TriageResult.risk_score >= 3.0) + .order_by(TriageResult.risk_score.desc()) + .limit(10) + ) + triages = result.scalars().all() + if not triages: + return "No significant triage findings." 
def _host_matches(candidate: str, row_host: str) -> bool:
    """Return True when *row_host* refers to *candidate*.

    Accepts an exact (case-insensitive) match or a short-name/FQDN prefix
    relationship, so ``host1`` matches ``host1.corp.local`` but no longer
    matches ``host10`` (the previous bare substring test produced such
    false positives).
    """
    a = candidate.strip().lower()
    b = str(row_host).strip().lower()
    if not a or not b:
        return False
    return a == b or b.startswith(a + ".") or a.startswith(b + ".")


async def _collect_host_data(db, hunt_id: str, hostname: str, fqdn: str | None = None) -> dict:
    """Collect per-artifact rows and prior triage findings for one host.

    Samples up to 500 rows per dataset in *hunt_id* and keeps rows whose
    host column matches *hostname* (or *fqdn*), at most 50 per artifact
    type.

    Returns a dict with keys:
    - artifacts: {artifact_type: [row dicts]}
    - triage_summary: concatenated triage findings for matching datasets
    - artifact_count: total number of matching rows kept
    """
    result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
    datasets = result.scalars().all()

    host_data: dict[str, list[dict]] = {}
    triage_parts: list[str] = []

    for ds in datasets:
        artifact_type = getattr(ds, "artifact_type", None) or "Unknown"
        rows_result = await db.execute(
            select(DatasetRow).where(DatasetRow.dataset_id == ds.id).limit(500)
        )
        rows = rows_result.scalars().all()

        matching = []
        for r in rows:
            data = r.normalized_data or r.data
            row_host = (
                data.get("hostname", "") or data.get("Fqdn", "")
                or data.get("ClientId", "") or data.get("client_id", "")
            )
            if _host_matches(hostname, str(row_host)):
                matching.append(data)
            elif fqdn and _host_matches(fqdn, str(row_host)):
                matching.append(data)

        if matching:
            host_data[artifact_type] = matching[:50]
            triage_info = await _get_triage_summary(db, ds.id)
            triage_parts.append(f"\n### {artifact_type} ({len(matching)} rows)\n{triage_info}")

    return {
        "artifacts": host_data,
        "triage_summary": "\n".join(triage_parts) or "No triage data.",
        "artifact_count": sum(len(v) for v in host_data.values()),
    }


async def profile_host(
    hunt_id: str, hostname: str, fqdn: str | None = None, client_id: str | None = None,
) -> None:
    """Run a deep LLM threat profile for one host and persist a HostProfile.

    Gathers matching artifact rows plus prior triage findings, sends them
    to the heavy model on the Wile node, and stores the parsed assessment.
    On any failure a placeholder profile (risk_level='unknown') is stored
    instead so the failure is visible to callers/UI.
    """
    logger.info("Profiling host %s in hunt %s", hostname, hunt_id)

    async with async_session() as db:
        host_data = await _collect_host_data(db, hunt_id, hostname, fqdn)
        if host_data["artifact_count"] == 0:
            logger.info("No data found for host %s, skipping", hostname)
            return

        system_prompt = (
            "You are a senior threat hunting analyst performing deep host analysis.\n"
            "You receive consolidated forensic artifacts and prior triage results for a single host.\n\n"
            "Provide a comprehensive host threat profile as JSON:\n"
            "- risk_score: 0.0 (clean) to 10.0 (actively compromised)\n"
            "- risk_level: low/medium/high/critical\n"
            "- suspicious_findings: list of specific concerns\n"
            "- mitre_techniques: list of MITRE ATT&CK technique IDs\n"
            "- timeline_summary: brief timeline of suspicious activity\n"
            "- analysis: detailed narrative assessment\n\n"
            "Consider: cross-artifact correlation, attack patterns, LOLBins, anomalies.\n"
            "Respond with valid JSON only."
        )

        # Truncate every field so the prompt stays within context limits.
        artifact_summary = {}
        for art_type, rows in host_data["artifacts"].items():
            artifact_summary[art_type] = [
                {k: str(v)[:150] for k, v in row.items() if v} for row in rows[:20]
            ]

        prompt = (
            f"Host: {hostname}\nFQDN: {fqdn or 'unknown'}\n\n"
            f"## Prior Triage Results\n{host_data['triage_summary']}\n\n"
            f"## Artifact Data ({host_data['artifact_count']} total rows)\n"
            f"{json.dumps(artifact_summary, indent=1, default=str)[:8000]}\n\n"
            "Provide your comprehensive host threat profile as JSON."
        )

        try:
            async with httpx.AsyncClient(timeout=300.0) as client:
                resp = await client.post(
                    WILE_URL,
                    json={
                        "model": HEAVY_MODEL,
                        "prompt": prompt,
                        "system": system_prompt,
                        "stream": False,
                        "options": {"temperature": 0.3, "num_predict": 4096},
                    },
                )
                resp.raise_for_status()
                llm_text = resp.json().get("response", "")

            # Local import avoids a circular dependency with the triage service.
            from app.services.triage import _parse_llm_response
            parsed = _parse_llm_response(llm_text)

            profile = HostProfile(
                hunt_id=hunt_id,
                hostname=hostname,
                fqdn=fqdn,
                client_id=client_id,
                risk_score=float(parsed.get("risk_score", 0.0)),
                risk_level=parsed.get("risk_level", "low"),
                artifact_summary={a: len(r) for a, r in host_data["artifacts"].items()},
                timeline_summary=parsed.get("timeline_summary", ""),
                suspicious_findings=parsed.get("suspicious_findings", []),
                mitre_techniques=parsed.get("mitre_techniques", []),
                llm_analysis=parsed.get("analysis", llm_text[:5000]),
                model_used=HEAVY_MODEL,
                node_used="wile",
            )
            db.add(profile)
            await db.commit()
            logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level)

        except Exception as e:
            logger.error("Failed to profile host %s: %s", hostname, e)
            # Fix: keep client_id on the error-path record too, so the
            # placeholder profile stays linkable to the Velociraptor client.
            profile = HostProfile(
                hunt_id=hunt_id, hostname=hostname, fqdn=fqdn, client_id=client_id,
                risk_score=0.0, risk_level="unknown",
                llm_analysis=f"Error: {e}",
                model_used=HEAVY_MODEL, node_used="wile",
            )
            db.add(profile)
            await db.commit()


async def profile_all_hosts(hunt_id: str) -> None:
    """Discover every hostname in a hunt and profile each one.

    Discovery samples up to 2000 rows per dataset. The DB session is
    released before the long-running LLM calls start (each profile_host
    opens its own session), and profiling runs with bounded concurrency
    (settings.HOST_PROFILE_CONCURRENCY).
    """
    logger.info("Starting host profiling for hunt %s", hunt_id)

    async with async_session() as db:
        result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
        datasets = result.scalars().all()

        hostnames: dict[str, str | None] = {}
        for ds in datasets:
            rows_result = await db.execute(
                select(DatasetRow).where(DatasetRow.dataset_id == ds.id).limit(2000)
            )
            for r in rows_result.scalars().all():
                data = r.normalized_data or r.data
                host = data.get("hostname") or data.get("Fqdn") or data.get("Hostname")
                if host and str(host).strip():
                    h = str(host).strip()
                    if h not in hostnames:
                        hostnames[h] = data.get("fqdn") or data.get("Fqdn")

    logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id)

    semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY)

    async def _bounded(hostname: str, fqdn: str | None):
        async with semaphore:
            await profile_host(hunt_id, hostname, fqdn)

    tasks = [_bounded(h, f) for h, f in hostnames.items()]
    await asyncio.gather(*tasks, return_exceptions=True)
    logger.info("Host profiling complete for hunt %s (%d hosts)", hunt_id, len(hostnames))
+""" + +import re +import logging +from collections import defaultdict +from typing import Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db.models import Dataset, DatasetRow + +logger = logging.getLogger(__name__) + +# Patterns + +_IPV4 = re.compile( + r'\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b' +) +_IPV6 = re.compile(r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b') +_DOMAIN = re.compile( + r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)' + r'+(?:com|net|org|io|info|biz|co|us|uk|de|ru|cn|cc|tk|xyz|top|' + r'online|site|club|win|work|download|stream|gdn|bid|review|racing|' + r'loan|date|faith|accountant|cricket|science|trade|party|men)\b', + re.IGNORECASE, +) +_MD5 = re.compile(r'\b[0-9a-fA-F]{32}\b') +_SHA1 = re.compile(r'\b[0-9a-fA-F]{40}\b') +_SHA256 = re.compile(r'\b[0-9a-fA-F]{64}\b') +_EMAIL = re.compile(r'\b[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}\b') +_URL = re.compile(r'https?://[^\s<>"\']+', re.IGNORECASE) + +# Private / reserved IPs to skip +_PRIVATE_NETS = re.compile( + r'^(10\.|172\.(1[6-9]|2\d|3[01])\.|192\.168\.|127\.|0\.|255\.)' +) + +PATTERNS = { + 'ipv4': _IPV4, + 'ipv6': _IPV6, + 'domain': _DOMAIN, + 'md5': _MD5, + 'sha1': _SHA1, + 'sha256': _SHA256, + 'email': _EMAIL, + 'url': _URL, +} + + +def _is_private_ip(ip: str) -> bool: + return bool(_PRIVATE_NETS.match(ip)) + + +def extract_iocs_from_text(text: str, skip_private: bool = True) -> dict[str, set[str]]: + """Extract all IOC types from a block of text.""" + result: dict[str, set[str]] = defaultdict(set) + for ioc_type, pattern in PATTERNS.items(): + for match in pattern.findall(text): + val = match.strip().lower() if ioc_type != 'url' else match.strip() + # Filter private IPs + if ioc_type == 'ipv4' and skip_private and _is_private_ip(val): + continue + # Filter hex strings that are too generic (< 32 chars not a hash) + result[ioc_type].add(val) + return result + + +async def 
async def extract_iocs_from_dataset(
    dataset_id: str,
    db: AsyncSession,
    max_rows: int = 5000,
    skip_private: bool = True,
) -> dict[str, list[str]]:
    """Extract IOCs from up to *max_rows* rows of a dataset.

    Rows are streamed in batches of 500; every value of each row is
    flattened to a single string and scanned with the module patterns.

    Returns {ioc_type: sorted unique values}, omitting empty types.
    """
    all_iocs: dict[str, set[str]] = defaultdict(set)
    offset = 0
    batch_size = 500

    while offset < max_rows:
        result = await db.execute(
            select(DatasetRow.data)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(offset)
            # Cap the final batch so we never read past max_rows (the old
            # fixed limit could overshoot by a partial batch when max_rows
            # was not a multiple of batch_size).
            .limit(min(batch_size, max_rows - offset))
        )
        rows = result.scalars().all()
        if not rows:
            break

        for data in rows:
            # Flatten all values to a single string for scanning.
            text = ' '.join(str(v) for v in data.values()) if isinstance(data, dict) else str(data)
            for ioc_type, values in extract_iocs_from_text(text, skip_private).items():
                all_iocs[ioc_type].update(values)

        offset += batch_size

    return {k: sorted(v) for k, v in all_iocs.items() if v}


async def extract_host_groups(
    hunt_id: str,
    db: AsyncSession,
) -> list[dict]:
    """Group all data by hostname across datasets in a hunt.

    Samples the first 5 rows of each dataset to detect which host column
    it uses, then counts rows per host. NOTE(review): the counting pass
    loads every row of each dataset into memory; acceptable for modest
    datasets, revisit if datasets grow large.

    Returns host group dicts sorted by total_rows descending.
    """
    result = await db.execute(
        select(Dataset).where(Dataset.hunt_id == hunt_id)
    )
    ds_list = result.scalars().all()
    if not ds_list:
        return []

    # Known host columns (normalized data is checked first, then raw).
    HOST_COLS = [
        'hostname', 'host', 'computer_name', 'computername', 'system',
        'machine', 'device_name', 'devicename', 'endpoint',
        'ClientId', 'Fqdn', 'client_id', 'fqdn',
    ]

    hosts: dict[str, dict] = {}

    for ds in ds_list:
        # Sample a few rows to discover this dataset's host column.
        sample_result = await db.execute(
            select(DatasetRow.data, DatasetRow.normalized_data)
            .where(DatasetRow.dataset_id == ds.id)
            .limit(5)
        )
        samples = sample_result.all()
        if not samples:
            continue

        host_col = None
        for row_data, norm_data in samples:
            check = norm_data if norm_data else row_data
            if not isinstance(check, dict):
                continue
            for col in HOST_COLS:
                if col in check and check[col]:
                    host_col = col
                    break
            if host_col:
                break

        if not host_col:
            continue

        # Count rows per host in this dataset.
        all_rows_result = await db.execute(
            select(DatasetRow.data, DatasetRow.normalized_data)
            .where(DatasetRow.dataset_id == ds.id)
        )
        for row_data, norm_data in all_rows_result.all():
            check = norm_data if norm_data else row_data
            if not isinstance(check, dict):
                continue
            host_val = check.get(host_col, '')
            if not host_val or not isinstance(host_val, str):
                continue
            host_val = host_val.strip()
            if not host_val:
                continue

            if host_val not in hosts:
                hosts[host_val] = {
                    'hostname': host_val,
                    'dataset_ids': set(),
                    'total_rows': 0,
                    'artifact_types': set(),
                    'first_seen': None,
                    'last_seen': None,
                }
            hosts[host_val]['dataset_ids'].add(ds.id)
            hosts[host_val]['total_rows'] += 1
            if ds.artifact_type:
                hosts[host_val]['artifact_types'].add(ds.artifact_type)

    # Convert to the JSON-friendly output format, busiest hosts first.
    result_list = []
    for h in sorted(hosts.values(), key=lambda x: x['total_rows'], reverse=True):
        result_list.append({
            'hostname': h['hostname'],
            'dataset_count': len(h['dataset_ids']),
            'total_rows': h['total_rows'],
            'artifact_types': sorted(h['artifact_types']),
            'first_seen': None,  # TODO: extract from timestamp columns
            'last_seen': None,
            'risk_score': None,  # TODO: link to host profiles
        })

    return result_list
"""Async job queue for background AI tasks.

Manages triage, profiling, report generation, anomaly detection,
and data queries as trackable jobs with status, progress, and
cancellation support.
"""

from __future__ import annotations

import asyncio
import logging
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Coroutine, Optional

logger = logging.getLogger(__name__)


class JobStatus(str, Enum):
    """Lifecycle states of a background job."""
    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class JobType(str, Enum):
    """Kinds of background work the queue knows how to run."""
    TRIAGE = "triage"
    HOST_PROFILE = "host_profile"
    REPORT = "report"
    ANOMALY = "anomaly"
    QUERY = "query"


@dataclass
class Job:
    """A single trackable unit of background work."""
    id: str
    job_type: JobType
    status: JobStatus = JobStatus.QUEUED
    progress: float = 0.0  # 0-100
    message: str = ""
    result: Any = None
    error: str | None = None
    created_at: float = field(default_factory=time.time)
    started_at: float | None = None
    completed_at: float | None = None
    params: dict = field(default_factory=dict)
    _cancel_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)

    @property
    def elapsed_ms(self) -> int:
        """Milliseconds from start (or creation) until completion (or now)."""
        end = self.completed_at or time.time()
        start = self.started_at or self.created_at
        return int((end - start) * 1000)

    def to_dict(self) -> dict:
        """JSON-serializable snapshot of the job for API responses."""
        return {
            "id": self.id,
            "job_type": self.job_type.value,
            "status": self.status.value,
            "progress": round(self.progress, 1),
            "message": self.message,
            "error": self.error,
            "created_at": self.created_at,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "elapsed_ms": self.elapsed_ms,
            "params": self.params,
        }

    @property
    def is_cancelled(self) -> bool:
        return self._cancel_event.is_set()

    def cancel(self):
        """Flag the job cancelled; running handlers must poll is_cancelled."""
        self._cancel_event.set()
        self.status = JobStatus.CANCELLED
        self.completed_at = time.time()
        self.message = "Cancelled by user"


class JobQueue:
    """In-memory async job queue with concurrency control.

    Jobs are tracked by ID and can be listed, polled, or cancelled.
    A configurable number of workers process jobs from the queue.
    """

    def __init__(self, max_workers: int = 3):
        self._jobs: dict[str, Job] = {}
        self._queue: asyncio.Queue[str] = asyncio.Queue()
        self._max_workers = max_workers
        self._workers: list[asyncio.Task] = []
        self._handlers: dict[JobType, Callable] = {}
        self._started = False

    def register_handler(
        self,
        job_type: JobType,
        handler: Callable[[Job], Coroutine],
    ):
        """Register an async handler for a job type.

        Handler signature: async def handler(job: Job) -> Any
        The handler can update job.progress and job.message during execution.
        It should check job.is_cancelled periodically and return early.
        """
        self._handlers[job_type] = handler
        logger.info(f"Registered handler for {job_type.value}")

    async def start(self):
        """Start worker tasks (idempotent)."""
        if self._started:
            return
        self._started = True
        for i in range(self._max_workers):
            task = asyncio.create_task(self._worker(i))
            self._workers.append(task)
        logger.info(f"Job queue started with {self._max_workers} workers")

    async def stop(self):
        """Stop all workers."""
        self._started = False
        for w in self._workers:
            w.cancel()
        await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        logger.info("Job queue stopped")

    def submit(self, job_type: JobType, **params) -> Job:
        """Submit a new job. Returns the Job object immediately."""
        job = Job(
            id=str(uuid.uuid4()),
            job_type=job_type,
            params=params,
        )
        self._jobs[job.id] = job
        self._queue.put_nowait(job.id)
        logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
        return job

    def get_job(self, job_id: str) -> Job | None:
        return self._jobs.get(job_id)

    def cancel_job(self, job_id: str) -> bool:
        """Cancel a queued/running job; returns False if unknown or finished."""
        job = self._jobs.get(job_id)
        if not job:
            return False
        if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
            return False
        job.cancel()
        return True

    def list_jobs(
        self,
        status: JobStatus | None = None,
        job_type: JobType | None = None,
        limit: int = 50,
    ) -> list[dict]:
        """List jobs, newest first, optionally filtered by status/type."""
        jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
        if status:
            jobs = [j for j in jobs if j.status == status]
        if job_type:
            jobs = [j for j in jobs if j.job_type == job_type]
        return [j.to_dict() for j in jobs[:limit]]

    def get_stats(self) -> dict:
        """Get queue statistics."""
        by_status = {}
        for j in self._jobs.values():
            by_status[j.status.value] = by_status.get(j.status.value, 0) + 1
        return {
            "total": len(self._jobs),
            "queued": self._queue.qsize(),
            "by_status": by_status,
            "workers": self._max_workers,
            "active_workers": sum(
                1 for j in self._jobs.values() if j.status == JobStatus.RUNNING
            ),
        }

    def cleanup(self, max_age_seconds: float = 3600):
        """Remove finished jobs older than *max_age_seconds*.

        Fix: age is measured from completion time (falling back to
        creation time), so a job that only just finished is never evicted
        merely because it was *submitted* long ago.
        """
        now = time.time()
        to_remove = [
            jid for jid, j in self._jobs.items()
            if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
            and (now - (j.completed_at or j.created_at)) > max_age_seconds
        ]
        for jid in to_remove:
            del self._jobs[jid]
        if to_remove:
            logger.info(f"Cleaned up {len(to_remove)} old jobs")

    async def _worker(self, worker_id: int):
        """Worker loop: pull jobs from the queue and execute handlers."""
        logger.info(f"Worker {worker_id} started")
        while self._started:
            try:
                # Timed get so a stopped queue's workers exit within 5s.
                job_id = await asyncio.wait_for(self._queue.get(), timeout=5.0)
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break

            job = self._jobs.get(job_id)
            if not job or job.is_cancelled:
                continue

            handler = self._handlers.get(job.job_type)
            if not handler:
                job.status = JobStatus.FAILED
                job.error = f"No handler for {job.job_type.value}"
                job.completed_at = time.time()
                logger.error(f"No handler for job type {job.job_type.value}")
                continue

            job.status = JobStatus.RUNNING
            job.started_at = time.time()
            job.message = "Running..."
            logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")

            try:
                result = await handler(job)
                if not job.is_cancelled:
                    job.status = JobStatus.COMPLETED
                    job.progress = 100.0
                    job.result = result
                    job.message = "Completed"
                    job.completed_at = time.time()
                    logger.info(
                        f"Worker {worker_id}: completed {job.id} "
                        f"in {job.elapsed_ms}ms"
                    )
            except Exception as e:
                if not job.is_cancelled:
                    job.status = JobStatus.FAILED
                    job.error = str(e)
                    job.message = f"Failed: {e}"
                    job.completed_at = time.time()
                    logger.error(
                        f"Worker {worker_id}: failed {job.id}: {e}",
                        exc_info=True,
                    )


# Singleton + job handlers

job_queue = JobQueue(max_workers=3)


async def _handle_triage(job: Job):
    """Triage handler."""
    from app.services.triage import triage_dataset
    dataset_id = job.params.get("dataset_id")
    job.message = f"Triaging dataset {dataset_id}"
    results = await triage_dataset(dataset_id)
    return {"count": len(results) if results else 0}


async def _handle_host_profile(job: Job):
    """Host profiling handler (single host, or the whole hunt)."""
    from app.services.host_profiler import profile_all_hosts, profile_host
    hunt_id = job.params.get("hunt_id")
    hostname = job.params.get("hostname")
    if hostname:
        job.message = f"Profiling host {hostname}"
        await profile_host(hunt_id, hostname)
        return {"hostname": hostname}
    else:
        job.message = f"Profiling all hosts in hunt {hunt_id}"
        await profile_all_hosts(hunt_id)
        return {"hunt_id": hunt_id}


async def _handle_report(job: Job):
    """Report generation handler."""
    from app.services.report_generator import generate_report
    hunt_id = job.params.get("hunt_id")
    job.message = f"Generating report for hunt {hunt_id}"
    report = await generate_report(hunt_id)
    return {"report_id": report.id if report else None}
job.params.get("dataset_id") + k = job.params.get("k", 3) + threshold = job.params.get("threshold", 0.35) + job.message = f"Detecting anomalies in dataset {dataset_id}" + results = await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold) + return {"count": len(results) if results else 0} + + +async def _handle_query(job: Job): + """Data query handler (non-streaming).""" + from app.services.data_query import query_dataset + dataset_id = job.params.get("dataset_id") + question = job.params.get("question", "") + mode = job.params.get("mode", "quick") + job.message = f"Querying dataset {dataset_id}" + answer = await query_dataset(dataset_id, question, mode) + return {"answer": answer} + + +def register_all_handlers(): + """Register all job handlers.""" + job_queue.register_handler(JobType.TRIAGE, _handle_triage) + job_queue.register_handler(JobType.HOST_PROFILE, _handle_host_profile) + job_queue.register_handler(JobType.REPORT, _handle_report) + job_queue.register_handler(JobType.ANOMALY, _handle_anomaly) + job_queue.register_handler(JobType.QUERY, _handle_query) \ No newline at end of file diff --git a/backend/app/services/load_balancer.py b/backend/app/services/load_balancer.py new file mode 100644 index 0000000..3ccc57c --- /dev/null +++ b/backend/app/services/load_balancer.py @@ -0,0 +1,193 @@ +"""Smart load balancer for Wile & Roadrunner LLM nodes. + +Tracks active jobs per node, health status, and routes new work +to the least-busy healthy node. Periodically pings both nodes +to maintain an up-to-date health map. 
+""" + +from __future__ import annotations + +import asyncio +import logging +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + +from app.config import settings + +logger = logging.getLogger(__name__) + + +class NodeId(str, Enum): + WILE = "wile" + ROADRUNNER = "roadrunner" + + +class WorkloadTier(str, Enum): + """What kind of workload is this?""" + HEAVY = "heavy" # 70B models, deep analysis, reports + FAST = "fast" # 7-14B models, triage, quick queries + EMBEDDING = "embed" # bge-m3 embeddings + ANY = "any" + + +@dataclass +class NodeStatus: + node_id: NodeId + url: str + healthy: bool = True + last_check: float = 0.0 + active_jobs: int = 0 + total_completed: int = 0 + total_errors: int = 0 + avg_latency_ms: float = 0.0 + _latencies: list[float] = field(default_factory=list) + + def record_completion(self, latency_ms: float): + self.active_jobs = max(0, self.active_jobs - 1) + self.total_completed += 1 + self._latencies.append(latency_ms) + # Rolling average of last 50 + if len(self._latencies) > 50: + self._latencies = self._latencies[-50:] + self.avg_latency_ms = sum(self._latencies) / len(self._latencies) + + def record_error(self): + self.active_jobs = max(0, self.active_jobs - 1) + self.total_errors += 1 + + def record_start(self): + self.active_jobs += 1 + + +class LoadBalancer: + """Routes LLM work to the least-busy healthy node. 
class LoadBalancer:
    """Routes LLM work to the least-busy healthy node.

    Node capabilities:
    - Wile: Heavy models (70B), code models (32B)
    - Roadrunner: Fast models (7-14B), embeddings (bge-m3), vision
    """

    # Routing table: which nodes may serve each workload tier
    # (order matters — it is the tie-break preference in select_node).
    TIER_NODES = {
        WorkloadTier.HEAVY: [NodeId.WILE],
        WorkloadTier.FAST: [NodeId.ROADRUNNER, NodeId.WILE],
        WorkloadTier.EMBEDDING: [NodeId.ROADRUNNER],
        WorkloadTier.ANY: [NodeId.ROADRUNNER, NodeId.WILE],
    }

    def __init__(self):
        self._nodes: dict[NodeId, NodeStatus] = {
            NodeId.WILE: NodeStatus(
                node_id=NodeId.WILE,
                url=f"http://{settings.WILE_HOST}:{settings.WILE_OLLAMA_PORT}",
            ),
            NodeId.ROADRUNNER: NodeStatus(
                node_id=NodeId.ROADRUNNER,
                url=f"http://{settings.ROADRUNNER_HOST}:{settings.ROADRUNNER_OLLAMA_PORT}",
            ),
        }
        self._lock = asyncio.Lock()
        self._health_task: Optional[asyncio.Task] = None

    async def start_health_loop(self, interval: float = 30.0):
        """Start the background health-check loop (idempotent)."""
        if self._health_task and not self._health_task.done():
            return
        self._health_task = asyncio.create_task(self._health_loop(interval))
        logger.info("Load balancer health loop started (%.0fs interval)", interval)

    async def stop_health_loop(self):
        """Cancel and await the health-check task, if running."""
        if self._health_task:
            self._health_task.cancel()
            try:
                await self._health_task
            except asyncio.CancelledError:
                pass
            self._health_task = None

    async def _health_loop(self, interval: float):
        # Runs forever; individual check failures are logged, not fatal.
        while True:
            try:
                await self.check_health()
            except Exception as e:
                logger.warning(f"Health check error: {e}")
            await asyncio.sleep(interval)

    async def check_health(self):
        """Ping both nodes' /api/tags and refresh their health flags."""
        import httpx
        async with httpx.AsyncClient(timeout=5) as client:
            for nid, status in self._nodes.items():
                try:
                    resp = await client.get(f"{status.url}/api/tags")
                    status.healthy = resp.status_code == 200
                except Exception:
                    status.healthy = False
                status.last_check = time.time()
                logger.debug(
                    f"Health: {nid.value} = {'OK' if status.healthy else 'DOWN'} "
                    f"(active={status.active_jobs})"
                )

    def select_node(self, tier: WorkloadTier = WorkloadTier.ANY) -> NodeId:
        """Select the best node for a workload tier.

        Strategy: among healthy nodes that support the tier, pick the one
        with the fewest active jobs (stable min — earlier candidates win
        ties). Falls back to the unfiltered candidate list if none are
        healthy.
        """
        eligible = self.TIER_NODES.get(tier, [NodeId.ROADRUNNER, NodeId.WILE])

        alive = [nid for nid in eligible if self._nodes[nid].healthy]
        if not alive:
            logger.warning(f"No healthy nodes for tier {tier.value}, using first candidate")
            alive = eligible

        return min(alive, key=lambda nid: self._nodes[nid].active_jobs)

    def acquire(self, tier: WorkloadTier = WorkloadTier.ANY) -> NodeId:
        """Select a node and mark a job started on it."""
        node = self.select_node(tier)
        self._nodes[node].record_start()
        logger.info(
            f"LB: dispatched {tier.value} -> {node.value} "
            f"(active={self._nodes[node].active_jobs})"
        )
        return node

    def release(self, node: NodeId, latency_ms: float = 0, error: bool = False):
        """Mark a job finished on *node*; unknown nodes are ignored."""
        status = self._nodes.get(node)
        if not status:
            return
        if error:
            status.record_error()
        else:
            status.record_completion(latency_ms)

    def get_status(self) -> dict:
        """Current per-node health/load snapshot for status endpoints."""
        return {
            nid.value: {
                "healthy": s.healthy,
                "active_jobs": s.active_jobs,
                "total_completed": s.total_completed,
                "total_errors": s.total_errors,
                "avg_latency_ms": round(s.avg_latency_ms, 1),
                "last_check": s.last_check,
            }
            for nid, s in self._nodes.items()
        }


# Singleton
lb = LoadBalancer()
async def _llm_call(url: str, model: str, system: str, prompt: str, timeout: float = 300.0) -> str:
    """POST one non-streaming generate request to an Ollama node.

    Returns the model's response text ("" if the node omitted the
    `response` field). Raises httpx.HTTPStatusError on non-2xx replies.
    """
    payload = {
        "model": model,
        "prompt": prompt,
        "system": system,
        "stream": False,
        "options": {"temperature": 0.3, "num_predict": 8192},
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(url, json=payload)
        resp.raise_for_status()
        return resp.json().get("response", "")
async def _gather_evidence(db, hunt_id: str) -> dict:
    """Collect a compact evidence bundle for a hunt.

    Pulls dataset stats, the top triage hits per dataset, and per-host
    profiles, truncating list fields so the result fits in an LLM
    prompt budget. `db` is an open async SQLAlchemy session.
    """
    # All datasets belonging to the hunt.
    ds_result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
    datasets = ds_result.scalars().all()

    dataset_summary = []
    all_triage = []
    for ds in datasets:
        ds_info = {
            "name": ds.name,
            # artifact_type may be absent on older rows; degrade gracefully.
            "artifact_type": getattr(ds, "artifact_type", "Unknown"),
            "row_count": ds.row_count or 0,
        }
        dataset_summary.append(ds_info)

        # Top triage batches for this dataset: risk >= 3.0, riskiest
        # first, capped at 15 per dataset to bound prompt size.
        triage_result = await db.execute(
            select(TriageResult)
            .where(TriageResult.dataset_id == ds.id)
            .where(TriageResult.risk_score >= 3.0)
            .order_by(TriageResult.risk_score.desc())
            .limit(15)
        )
        for t in triage_result.scalars().all():
            all_triage.append({
                "dataset": ds.name,
                "artifact_type": ds_info["artifact_type"],
                "rows": f"{t.row_start}-{t.row_end}",
                "risk_score": t.risk_score,
                "verdict": t.verdict,
                # Truncate list fields so one noisy batch cannot dominate.
                "findings": t.findings[:5] if t.findings else [],
                "suspicious_indicators": t.suspicious_indicators[:5] if t.suspicious_indicators else [],
                "mitre": t.mitre_techniques or [],
            })

    # Host profiles, riskiest first.
    profile_result = await db.execute(
        select(HostProfile)
        .where(HostProfile.hunt_id == hunt_id)
        .order_by(HostProfile.risk_score.desc())
    )
    profiles = profile_result.scalars().all()
    host_summaries = []
    for p in profiles:
        host_summaries.append({
            "hostname": p.hostname,
            "risk_score": p.risk_score,
            "risk_level": p.risk_level,
            "findings": p.suspicious_findings[:5] if p.suspicious_findings else [],
            "mitre": p.mitre_techniques or [],
            # Timeline text capped at 300 chars for the prompt.
            "timeline": (p.timeline_summary or "")[:300],
        })

    return {
        "datasets": dataset_summary,
        # Global cap on triage entries fed to the LLM.
        "triage_findings": all_triage[:30],
        "host_profiles": host_summaries,
        "total_datasets": len(datasets),
        "total_rows": sum(d["row_count"] for d in dataset_summary),
        # Hosts scoring >= 7.0 count as high risk in the summary stats.
        "high_risk_hosts": len([h for h in host_summaries if h["risk_score"] >= 7.0]),
    }
async def generate_report(hunt_id: str) -> None:
    """Generate a hunt report via a three-phase model debate.

    1. Wile (heavy model) drafts an initial assessment from gathered evidence.
    2. Roadrunner (fast model) critiques the draft.
    3. Wile synthesizes the final report as structured JSON.

    Progress and results are persisted on a HuntReport row; on failure
    the row is marked "error" with the exception message instead of
    raising to the caller.
    """
    logger.info("Generating report for hunt %s", hunt_id)
    start = time.monotonic()

    async with async_session() as db:
        # Create the report row up-front so the UI can show "generating".
        report = HuntReport(
            hunt_id=hunt_id,
            status="generating",
            models_used=[HEAVY_MODEL, FAST_MODEL],
        )
        db.add(report)
        await db.commit()
        await db.refresh(report)
        report_id = report.id

        try:
            evidence = await _gather_evidence(db, hunt_id)
            # Hard cap the serialized evidence to stay within prompt budget.
            evidence_text = json.dumps(evidence, indent=1, default=str)[:12000]

            # Phase 1: Wile initial analysis (free-form markdown draft).
            logger.info("Report phase 1: Wile initial analysis")
            phase1 = await _llm_call(
                WILE_URL, HEAVY_MODEL,
                system=(
                    "You are a senior threat intelligence analyst writing a hunt report.\n"
                    "Analyze all evidence and produce a structured threat assessment.\n"
                    "Include: executive summary, detailed findings per host, MITRE mapping,\n"
                    "IOC table, risk rankings, and actionable recommendations.\n"
                    "Use markdown formatting. Be thorough and specific."
                ),
                prompt=f"Hunt evidence:\n{evidence_text}\n\nProduce your initial threat assessment.",
            )

            # Phase 2: Roadrunner critical review. Inputs are truncated
            # harder here because the fast model has a smaller budget.
            logger.info("Report phase 2: Roadrunner critical review")
            phase2 = await _llm_call(
                ROADRUNNER_URL, FAST_MODEL,
                system=(
                    "You are a critical reviewer of threat hunt reports.\n"
                    "Review the initial assessment and identify:\n"
                    "- Missing correlations or overlooked indicators\n"
                    "- False positive risks or overblown findings\n"
                    "- Additional MITRE techniques that should be mapped\n"
                    "- Gaps in recommendations\n"
                    "Be specific and constructive. Respond in markdown."
                ),
                prompt=f"Evidence:\n{evidence_text[:4000]}\n\nInitial Assessment:\n{phase1[:6000]}\n\nProvide your critical review.",
                timeout=120.0,
            )

            # Phase 3: Wile final synthesis — must return structured JSON.
            logger.info("Report phase 3: Wile final synthesis")
            synthesis_prompt = (
                f"Original evidence:\n{evidence_text[:6000]}\n\n"
                f"Initial assessment:\n{phase1[:5000]}\n\n"
                f"Critical review:\n{phase2[:3000]}\n\n"
                "Produce the FINAL hunt report incorporating the review feedback.\n"
                "Return JSON with these keys:\n"
                "- executive_summary: 2-3 paragraph executive summary\n"
                "- findings: list of {title, severity, description, evidence, mitre_ids}\n"
                "- recommendations: list of {priority, action, rationale}\n"
                "- mitre_mapping: dict of technique_id -> {name, description, evidence}\n"
                "- ioc_table: list of {type, value, context, confidence}\n"
                "- host_risk_summary: list of {hostname, risk_score, risk_level, key_findings}\n"
                "Respond with valid JSON only."
            )
            phase3_text = await _llm_call(
                WILE_URL, HEAVY_MODEL,
                system="You are producing the final, definitive threat hunt report. Incorporate all feedback. Respond with valid JSON only.",
                prompt=synthesis_prompt,
            )

            parsed = _parse_llm_response(phase3_text)
            elapsed_ms = int((time.monotonic() - start) * 1000)

            # Full report keeps all three phases for audit/debugging.
            full_report = f"# Threat Hunt Report\n\n{phase1}\n\n---\n## Review Notes\n{phase2}\n\n---\n## Final Synthesis\n{phase3_text}"

            report.status = "complete"
            # Fall back to the phase-1 draft if the JSON lacked a summary.
            report.exec_summary = parsed.get("executive_summary", phase1[:2000])
            report.full_report = full_report
            report.findings = parsed.get("findings", [])
            report.recommendations = parsed.get("recommendations", [])
            report.mitre_mapping = parsed.get("mitre_mapping", {})
            report.ioc_table = parsed.get("ioc_table", [])
            report.host_risk_summary = parsed.get("host_risk_summary", [])
            report.generation_time_ms = elapsed_ms
            await db.commit()

            logger.info("Report %s complete in %dms", report_id, elapsed_ms)

        except Exception as e:
            # Record the failure on the report row so the UI surfaces it.
            logger.error("Report generation failed for hunt %s: %s", hunt_id, e)
            report.status = "error"
            report.exec_summary = f"Report generation failed: {e}"
            report.generation_time_ms = int((time.monotonic() - start) * 1000)
            await db.commit()
"Windows.Network.Netstat": "Look for: C2 beaconing, unusual ports, connections to rare IPs, non-browser high-port listeners.", + "Windows.System.Services": "Look for: services in temp dirs, misspelled names, unsigned ServiceDll, unusual start modes.", + "Windows.Forensics.Prefetch": "Look for: recon tools, lateral movement tools, rarely-run executables with high run counts.", + "Windows.EventLogs.EvtxHunter": "Look for: logon type 10/3 anomalies, service installs, PowerShell script blocks, clearing.", + "Windows.Sys.Autoruns": "Look for: recently added entries, entries in temp/user dirs, encoded commands, suspicious DLLs.", + "Windows.Registry.Finder": "Look for: run keys, image file execution options, hidden services, encoded payloads.", + "Windows.Search.FileFinder": "Look for: files in unusual locations, recently modified system files, known tool names.", +} + + +def _parse_llm_response(text: str) -> dict: + text = text.strip() + fence = re.search(r"`(?:json)?\s*\n?(.*?)\n?\s*`", text, re.DOTALL) + if fence: + text = fence.group(1).strip() + try: + return json.loads(text) + except json.JSONDecodeError: + brace = text.find("{") + bracket = text.rfind("}") + if brace != -1 and bracket != -1 and bracket > brace: + try: + return json.loads(text[brace : bracket + 1]) + except json.JSONDecodeError: + pass + return {"raw_response": text[:3000]} + + +async def triage_dataset(dataset_id: str) -> None: + logger.info("Starting triage for dataset %s", dataset_id) + + async with async_session() as db: + ds_result = await db.execute( + select(Dataset).where(Dataset.id == dataset_id) + ) + dataset = ds_result.scalar_one_or_none() + if not dataset: + logger.error("Dataset %s not found", dataset_id) + return + + artifact_type = getattr(dataset, "artifact_type", None) or "Unknown" + focus = ARTIFACT_FOCUS.get(artifact_type, "Analyze for any suspicious indicators.") + + count_result = await db.execute( + select(func.count()).where(DatasetRow.dataset_id == dataset_id) + ) + 
async def triage_dataset(dataset_id: str) -> None:
    """Triage a dataset in fixed-size row batches via the Roadrunner node.

    Each batch is sent to the fast LLM with an artifact-specific prompt,
    the structured verdict is stored as a TriageResult row, and scanning
    stops early once the configured cap of suspicious rows is reached.
    Per-batch failures are recorded as verdict="error" rather than
    aborting the whole run.
    """
    logger.info("Starting triage for dataset %s", dataset_id)

    async with async_session() as db:
        ds_result = await db.execute(
            select(Dataset).where(Dataset.id == dataset_id)
        )
        dataset = ds_result.scalar_one_or_none()
        if not dataset:
            logger.error("Dataset %s not found", dataset_id)
            return

        # Per-artifact analysis hints sharpen the prompt; unknown artifact
        # types fall back to a generic instruction.
        artifact_type = getattr(dataset, "artifact_type", None) or "Unknown"
        focus = ARTIFACT_FOCUS.get(artifact_type, "Analyze for any suspicious indicators.")

        count_result = await db.execute(
            select(func.count()).where(DatasetRow.dataset_id == dataset_id)
        )
        total_rows = count_result.scalar() or 0

        batch_size = settings.TRIAGE_BATCH_SIZE
        suspicious_count = 0  # rows in batches scoring at/above the escalation threshold
        offset = 0

        while offset < total_rows:
            # Stop early once enough suspicious material has been found.
            if suspicious_count >= settings.TRIAGE_MAX_SUSPICIOUS_ROWS:
                logger.info("Reached suspicious row cap for dataset %s", dataset_id)
                break

            rows_result = await db.execute(
                select(DatasetRow)
                .where(DatasetRow.dataset_id == dataset_id)
                .order_by(DatasetRow.row_number)
                .offset(offset)
                .limit(batch_size)
            )
            rows = rows_result.scalars().all()
            if not rows:
                break

            # Compact each row for the prompt: prefer normalized data,
            # drop empty values, truncate long strings to 200 chars.
            batch_data = []
            for r in rows:
                data = r.normalized_data or r.data
                compact = {k: str(v)[:200] for k, v in data.items() if v}
                batch_data.append(compact)

            system_prompt = f"""You are a cybersecurity triage analyst. Analyze this batch of {artifact_type} forensic data.
{focus}

Return JSON with:
- risk_score: 0.0 (benign) to 10.0 (critical threat)
- verdict: "clean", "suspicious", "malicious", or "inconclusive"
- findings: list of key observations
- suspicious_indicators: list of specific IOCs or anomalies
- mitre_techniques: list of MITRE ATT&CK IDs if applicable

Be precise. Only flag genuinely suspicious items. Respond with valid JSON only."""

            # Serialized batch is capped at 6000 chars to fit the context.
            prompt = f"Rows {offset+1}-{offset+len(rows)} of {total_rows}:\n{json.dumps(batch_data, default=str)[:6000]}"

            try:
                async with httpx.AsyncClient(timeout=120.0) as client:
                    resp = await client.post(
                        ROADRUNNER_URL,
                        json={
                            "model": DEFAULT_FAST_MODEL,
                            "prompt": prompt,
                            "system": system_prompt,
                            "stream": False,
                            "options": {"temperature": 0.2, "num_predict": 2048},
                        },
                    )
                    resp.raise_for_status()
                    result = resp.json()
                    llm_text = result.get("response", "")

                parsed = _parse_llm_response(llm_text)
                risk = float(parsed.get("risk_score", 0.0))

                triage = TriageResult(
                    dataset_id=dataset_id,
                    row_start=offset,
                    row_end=offset + len(rows) - 1,
                    risk_score=risk,
                    verdict=parsed.get("verdict", "inconclusive"),
                    findings=parsed.get("findings", []),
                    suspicious_indicators=parsed.get("suspicious_indicators", []),
                    mitre_techniques=parsed.get("mitre_techniques", []),
                    model_used=DEFAULT_FAST_MODEL,
                    node_used="roadrunner",
                )
                db.add(triage)
                await db.commit()

                # Whole batch counts toward the cap when it escalates.
                if risk >= settings.TRIAGE_ESCALATION_THRESHOLD:
                    suspicious_count += len(rows)

                logger.debug(
                    "Triage batch %d-%d: risk=%.1f verdict=%s",
                    offset, offset + len(rows) - 1, risk, triage.verdict,
                )

            except Exception as e:
                # Record the failure as a TriageResult so coverage gaps
                # are visible in the results instead of silently missing.
                logger.error("Triage batch %d failed: %s", offset, e)
                triage = TriageResult(
                    dataset_id=dataset_id,
                    row_start=offset,
                    row_end=offset + len(rows) - 1,
                    risk_score=0.0,
                    verdict="error",
                    findings=[f"Error: {e}"],
                    model_used=DEFAULT_FAST_MODEL,
                    node_used="roadrunner",
                )
                db.add(triage)
                await db.commit()

            offset += batch_size

    logger.info("Triage complete for dataset %s", dataset_id)
"""Ad-hoc debug script: dump sample rows from selected datasets via the local API.

Improvements over the original: the HTTP response is closed via a context
manager (the original leaked the urlopen handle), and the three
copy-pasted dataset branches are replaced by a name -> sample-size table.
"""
import json
import urllib.request

BASE = "http://localhost:8000"
# dataset name -> number of sample rows to print
TARGETS = {
    "ip_to_hostname_mapping": 5,
    "Netstat": 3,
    "netstat_enrich2": 3,
}


def get(path):
    """GET BASE + path and decode the JSON body, closing the connection."""
    with urllib.request.urlopen(BASE + path) as resp:
        return json.loads(resp.read())


def main():
    ds_list = get("/api/datasets?skip=0&limit=20&hunt_id=fd8ba3fb45de4d65bea072f73d80544d")
    for d in ds_list["datasets"]:
        limit = TARGETS.get(d["name"])
        if limit is None:
            continue
        rows = get(f"/api/datasets/{d['id']}/rows?offset=0&limit={limit}")
        print(f"=== {d['name']} ===")
        for r in rows["rows"]:
            print(r)


if __name__ == "__main__":
    main()
959445f..432f927 100644 --- a/frontend/nginx.conf +++ b/frontend/nginx.conf @@ -15,10 +15,10 @@ server { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; - proxy_read_timeout 120s; + proxy_read_timeout 300s; } - # SPA fallback — serve index.html for all non-file routes + # SPA fallback serve index.html for all non-file routes location / { try_files $uri $uri/ /index.html; } @@ -28,4 +28,4 @@ server { expires 1y; add_header Cache-Control "public, immutable"; } -} +} \ No newline at end of file diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 1574641..5463bd9 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -75,7 +75,6 @@ "resolved": "https://registry.npmjs.org/@babel/core/-/core-7.29.0.tgz", "integrity": "sha512-CGOfOJqWjg2qW/Mb6zNsDm+u5vFQ8DxXfbM09z69p5Z6+mE1ikP2jUXw+j42Pf1XTYED2Rni5f95npYeuwMDQA==", "license": "MIT", - "peer": true, "dependencies": { "@babel/code-frame": "^7.29.0", "@babel/generator": "^7.29.0", @@ -731,7 +730,6 @@ "resolved": "https://registry.npmjs.org/@babel/plugin-syntax-flow/-/plugin-syntax-flow-7.28.6.tgz", "integrity": "sha512-D+OrJumc9McXNEBI/JmFnc/0uCM2/Y3PEBG3gfV3QIYkKv5pvnpzFrl1kYCrcHJP8nOeFB/SHi1IHz29pNGuew==", "license": "MIT", - "peer": true, "dependencies": { "@babel/helper-plugin-utils": "^7.28.6" }, @@ -1615,7 +1613,6 @@ "resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-jsx/-/plugin-transform-react-jsx-7.28.6.tgz", "integrity": "sha512-61bxqhiRfAACulXSLd/GxqmAedUSrRZIu/cbaT18T1CetkTmtDN15it7i80ru4DVqRK1WMxQhXs+Lf9kajm5Ow==", "license": "MIT", - "peer": true, "dependencies": { "@babel/helper-annotate-as-pure": "^7.27.3", "@babel/helper-module-imports": "^7.28.6", @@ -2457,7 +2454,6 @@ "resolved": "https://registry.npmjs.org/@emotion/react/-/react-11.14.0.tgz", "integrity": 
"sha512-O000MLDBDdk/EohJPFUqvnp4qnHeYkVP5B0xEG0D/L7cOKP9kefu2DXn8dj74cQfsEzUqh+sr1RzFqiL1o+PpA==", "license": "MIT", - "peer": true, "dependencies": { "@babel/runtime": "^7.18.3", "@emotion/babel-plugin": "^11.13.5", @@ -2501,7 +2497,6 @@ "resolved": "https://registry.npmjs.org/@emotion/styled/-/styled-11.14.1.tgz", "integrity": "sha512-qEEJt42DuToa3gurlH4Qqc1kVpNq8wO8cJtDzU46TjlzWjDlsVyevtYCRijVq3SrHsROS+gVQ8Fnea108GnKzw==", "license": "MIT", - "peer": true, "dependencies": { "@babel/runtime": "^7.18.3", "@emotion/babel-plugin": "^11.13.5", @@ -3083,7 +3078,6 @@ "resolved": "https://registry.npmjs.org/@mui/material/-/material-7.3.8.tgz", "integrity": "sha512-QKd1RhDXE1hf2sQDNayA9ic9jGkEgvZOf0tTkJxlBPG8ns8aS4rS8WwYURw2x5y3739p0HauUXX9WbH7UufFLw==", "license": "MIT", - "peer": true, "dependencies": { "@babel/runtime": "^7.28.6", "@mui/core-downloads-tracker": "^7.3.8", @@ -3194,7 +3188,6 @@ "resolved": "https://registry.npmjs.org/@mui/system/-/system-7.3.8.tgz", "integrity": "sha512-hoFRj4Zw2Km8DPWZp/nKG+ao5Jw5LSk2m/e4EGc6M3RRwXKEkMSG4TgtfVJg7dS2homRwtdXSMW+iRO0ZJ4+IA==", "license": "MIT", - "peer": true, "dependencies": { "@babel/runtime": "^7.28.6", "@mui/private-theming": "^7.3.8", @@ -4296,7 +4289,6 @@ "resolved": "https://registry.npmjs.org/@types/react/-/react-18.3.28.tgz", "integrity": "sha512-z9VXpC7MWrhfWipitjNdgCauoMLRdIILQsAEV+ZesIzBq/oUlxk0m3ApZuMFCXdnS4U7KrI+l3WRUEGQ8K1QKw==", "license": "MIT", - "peer": true, "dependencies": { "@types/prop-types": "*", "csstype": "^3.2.2" @@ -4464,7 +4456,6 @@ "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.62.0.tgz", "integrity": "sha512-TiZzBSJja/LbhNPvk6yc0JrX9XqhQ0hdh6M2svYfsHGejaKFIAGd9MQ+ERIMzLGlN/kZoYIgdxFV0PuljTKXag==", "license": "MIT", - "peer": true, "dependencies": { "@eslint-community/regexpp": "^4.4.0", "@typescript-eslint/scope-manager": "5.62.0", @@ -4518,7 +4509,6 @@ "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-5.62.0.tgz", 
"integrity": "sha512-VlJEV0fOQ7BExOsHYAGrgbEiZoi8D+Bl2+f6V2RrXerRSylnp+ZBHmPvaIa8cz0Ajx7WO7Z5RqfgYg7ED1nRhA==", "license": "BSD-2-Clause", - "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "5.62.0", "@typescript-eslint/types": "5.62.0", @@ -4888,7 +4878,6 @@ "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.16.0.tgz", "integrity": "sha512-UVJyE9MttOsBQIDKw1skb9nAwQuR5wuGD3+82K6JgJlm/Y+KI92oNsMNGZCYdDsVtRHSak0pcV5Dno5+4jh9sw==", "license": "MIT", - "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -4987,7 +4976,6 @@ "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", "license": "MIT", - "peer": true, "dependencies": { "fast-deep-equal": "^3.1.1", "fast-json-stable-stringify": "^2.0.0", @@ -5910,7 +5898,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "baseline-browser-mapping": "^2.9.0", "caniuse-lite": "^1.0.30001759", @@ -7037,8 +7024,7 @@ "version": "3.2.3", "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.2.3.tgz", "integrity": "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/cytoscape": { "version": "3.33.1", @@ -8069,7 +8055,6 @@ "integrity": "sha512-ypowyDxpVSYpkXr9WPv2PAZCtNip1Mv5KTW0SCurXv/9iOpcrH9PaqUElksqEB6pChqHGDRCFTyrZlGhnLNGiA==", "deprecated": "This version is no longer supported. 
Please see https://eslint.org/version-support for other options.", "license": "MIT", - "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.2.0", "@eslint-community/regexpp": "^4.6.1", @@ -10977,7 +10962,6 @@ "resolved": "https://registry.npmjs.org/jest/-/jest-27.5.1.tgz", "integrity": "sha512-Yn0mADZB89zTtjkPJEXwrac3LHudkQMR+Paqa8uxJHCBr9agxztUifWCyiYrjhMPBoUVBjyny0I7XH6ozDr7QQ==", "license": "MIT", - "peer": true, "dependencies": { "@jest/core": "^27.5.1", "import-local": "^3.0.2", @@ -11875,7 +11859,6 @@ "resolved": "https://registry.npmjs.org/jiti/-/jiti-1.21.7.tgz", "integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==", "license": "MIT", - "peer": true, "bin": { "jiti": "bin/jiti.js" } @@ -14143,7 +14126,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "nanoid": "^3.3.11", "picocolors": "^1.1.1", @@ -15278,7 +15260,6 @@ "resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.1.2.tgz", "integrity": "sha512-Q8qQfPiZ+THO/3ZrOrO0cJJKfpYCagtMUkXbnEfmgUjwXg6z/WBeOyS9APBBPCTSiDV+s4SwQGu8yFsiMRIudg==", "license": "MIT", - "peer": true, "dependencies": { "cssesc": "^3.0.0", "util-deprecate": "^1.0.2" @@ -15654,7 +15635,6 @@ "resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz", "integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==", "license": "MIT", - "peer": true, "dependencies": { "loose-envify": "^1.1.0" }, @@ -15789,7 +15769,6 @@ "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz", "integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==", "license": "MIT", - "peer": true, "dependencies": { "loose-envify": "^1.1.0", "scheduler": "^0.23.2" @@ -15867,7 +15846,6 @@ "resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.11.0.tgz", "integrity": 
"sha512-F27qZr8uUqwhWZboondsPx8tnC3Ct3SxZA3V5WyEvujRyyNv0VYPhoBg1gZ8/MV5tubQp76Trw8lTv9hzRBa+A==", "license": "MIT", - "peer": true, "engines": { "node": ">=0.10.0" } @@ -16492,7 +16470,6 @@ "resolved": "https://registry.npmjs.org/rollup/-/rollup-2.79.2.tgz", "integrity": "sha512-fS6iqSPZDs3dr/y7Od6y5nha8dW1YnbgtsyotCVvoFGKbERG++CVRFv1meyGDE1SNItQA8BrnCw7ScdAhRJ3XQ==", "license": "MIT", - "peer": true, "bin": { "rollup": "dist/bin/rollup" }, @@ -16738,7 +16715,6 @@ "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz", "integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==", "license": "MIT", - "peer": true, "dependencies": { "fast-deep-equal": "^3.1.3", "fast-uri": "^3.0.1", @@ -18152,7 +18128,6 @@ "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "license": "MIT", - "peer": true, "engines": { "node": ">=12" }, @@ -18450,7 +18425,6 @@ "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", "license": "Apache-2.0", - "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -18959,7 +18933,6 @@ "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.105.2.tgz", "integrity": "sha512-dRXm0a2qcHPUBEzVk8uph0xWSjV/xZxenQQbLwnwP7caQCYpqG1qddwlyEkIDkYn0K8tvmcrZ+bOrzoQ3HxCDw==", "license": "MIT", - "peer": true, "dependencies": { "@types/eslint-scope": "^3.7.7", "@types/estree": "^1.0.8", @@ -19445,7 +19418,6 @@ "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz", "integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==", "license": "MIT", - "peer": true, "dependencies": { "fast-deep-equal": "^3.1.3", "fast-uri": "^3.0.1", diff --git 
a/frontend/src/components/AnalysisDashboard.tsx b/frontend/src/components/AnalysisDashboard.tsx new file mode 100644 index 0000000..f79253d --- /dev/null +++ b/frontend/src/components/AnalysisDashboard.tsx @@ -0,0 +1,818 @@ +/** + * AnalysisDashboard -- 6-tab view covering the full AI analysis pipeline: + * 1. Triage results + * 2. Host profiles + * 3. Reports + * 4. Anomalies + * 5. Ask Data (natural language query with SSE streaming) -- Phase 9 + * 6. Jobs & Load Balancer status -- Phase 10 + */ + +import React, { useState, useEffect, useCallback, useRef } from 'react'; +import { + Box, Typography, Tabs, Tab, Paper, Button, Chip, Stack, CircularProgress, + Table, TableBody, TableCell, TableContainer, TableHead, TableRow, + Accordion, AccordionSummary, AccordionDetails, Alert, Select, MenuItem, + FormControl, InputLabel, LinearProgress, Tooltip, IconButton, Divider, + Card, CardContent, CardActions, Grid, TextField, ToggleButton, + ToggleButtonGroup, +} from '@mui/material'; +import ExpandMoreIcon from '@mui/icons-material/ExpandMore'; +import PlayArrowIcon from '@mui/icons-material/PlayArrow'; +import RefreshIcon from '@mui/icons-material/Refresh'; +import AssessmentIcon from '@mui/icons-material/Assessment'; +import SecurityIcon from '@mui/icons-material/Security'; +import PersonIcon from '@mui/icons-material/Person'; +import WarningAmberIcon from '@mui/icons-material/WarningAmber'; +import ShieldIcon from '@mui/icons-material/Shield'; +import BubbleChartIcon from '@mui/icons-material/BubbleChart'; +import QuestionAnswerIcon from '@mui/icons-material/QuestionAnswer'; +import WorkIcon from '@mui/icons-material/Work'; +import SendIcon from '@mui/icons-material/Send'; +import StopIcon from '@mui/icons-material/Stop'; +import DeleteIcon from '@mui/icons-material/Delete'; +import CheckCircleIcon from '@mui/icons-material/CheckCircle'; +import ErrorIcon from '@mui/icons-material/Error'; +import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty'; +import 
CancelIcon from '@mui/icons-material/Cancel'; +import { useSnackbar } from 'notistack'; +import { + analysis, hunts, datasets, + type Hunt, type DatasetSummary, + type TriageResultData, type HostProfileData, type HuntReportData, + type AnomalyResultData, type JobData, type JobStats, type LBNodeStatus, +} from '../api/client'; + +/* helpers */ + +function riskColor(score: number): 'error' | 'warning' | 'info' | 'success' | 'default' { + if (score >= 8) return 'error'; + if (score >= 5) return 'warning'; + if (score >= 2) return 'info'; + return 'success'; +} + +function riskLabel(level: string): 'error' | 'warning' | 'info' | 'success' | 'default' { + if (level === 'critical' || level === 'high') return 'error'; + if (level === 'medium') return 'warning'; + if (level === 'low') return 'success'; + return 'default'; +} + +function fmtMs(ms: number): string { + if (ms < 1000) return `${ms}ms`; + return `${(ms / 1000).toFixed(1)}s`; +} + +function fmtTime(ts: number): string { + if (!ts) return '--'; + return new Date(ts * 1000).toLocaleTimeString(); +} + +const statusIcon = (s: string) => { + switch (s) { + case 'completed': return ; + case 'failed': return ; + case 'running': return ; + case 'queued': return ; + case 'cancelled': return ; + default: return null; + } +}; + +/* TabPanel */ + +function TabPanel({ children, value, index }: { children: React.ReactNode; value: number; index: number }) { + return value === index ? 
{children} : null; +} + +/* Main component */ + +export default function AnalysisDashboard() { + const { enqueueSnackbar } = useSnackbar(); + const [tab, setTab] = useState(0); + + // Selectors + const [huntList, setHuntList] = useState([]); + const [dsList, setDsList] = useState([]); + const [huntId, setHuntId] = useState(''); + const [dsId, setDsId] = useState(''); + + // Data tabs 0-3 + const [triageResults, setTriageResults] = useState([]); + const [profiles, setProfiles] = useState([]); + const [reports, setReports] = useState([]); + const [anomalies, setAnomalies] = useState([]); + + // Loading states + const [loadingTriage, setLoadingTriage] = useState(false); + const [loadingProfiles, setLoadingProfiles] = useState(false); + const [loadingReports, setLoadingReports] = useState(false); + const [loadingAnomalies, setLoadingAnomalies] = useState(false); + const [triggering, setTriggering] = useState(false); + + // Phase 9: Ask Data + const [queryText, setQueryText] = useState(''); + const [queryMode, setQueryMode] = useState('quick'); + const [queryAnswer, setQueryAnswer] = useState(''); + const [queryStreaming, setQueryStreaming] = useState(false); + const [queryMeta, setQueryMeta] = useState | null>(null); + const [queryDone, setQueryDone] = useState | null>(null); + const abortRef = useRef(null); + const answerRef = useRef(null); + + // Phase 10: Jobs + const [jobs, setJobs] = useState([]); + const [jobStats, setJobStats] = useState(null); + const [lbStatus, setLbStatus] = useState | null>(null); + const [loadingJobs, setLoadingJobs] = useState(false); + + // Load hunts and datasets + useEffect(() => { + hunts.list(0, 200).then(r => setHuntList(r.hunts)).catch(() => {}); + datasets.list(0, 200).then(r => setDsList(r.datasets)).catch(() => {}); + }, []); + + useEffect(() => { + if (huntList.length > 0 && !huntId) setHuntId(huntList[0].id); + }, [huntList, huntId]); + + useEffect(() => { + if (dsList.length > 0 && !dsId) setDsId(dsList[0].id); + }, [dsList, 
dsId]); + + /* Fetch triage results */ + const fetchTriage = useCallback(async () => { + if (!dsId) return; + setLoadingTriage(true); + try { + const data = await analysis.triageResults(dsId); + setTriageResults(data); + } catch (e: any) { + enqueueSnackbar(`Triage fetch failed: ${e.message}`, { variant: 'error' }); + } finally { setLoadingTriage(false); } + }, [dsId, enqueueSnackbar]); + + const fetchProfiles = useCallback(async () => { + if (!huntId) return; + setLoadingProfiles(true); + try { + const data = await analysis.hostProfiles(huntId); + setProfiles(data); + } catch (e: any) { + enqueueSnackbar(`Profiles fetch failed: ${e.message}`, { variant: 'error' }); + } finally { setLoadingProfiles(false); } + }, [huntId, enqueueSnackbar]); + + const fetchReports = useCallback(async () => { + if (!huntId) return; + setLoadingReports(true); + try { + const data = await analysis.listReports(huntId); + setReports(data); + } catch (e: any) { + enqueueSnackbar(`Reports fetch failed: ${e.message}`, { variant: 'error' }); + } finally { setLoadingReports(false); } + }, [huntId, enqueueSnackbar]); + + const fetchAnomalies = useCallback(async () => { + if (!dsId) return; + setLoadingAnomalies(true); + try { + const data = await analysis.anomalies(dsId); + setAnomalies(data); + } catch (e: any) { + enqueueSnackbar('Anomaly fetch failed: ' + e.message, { variant: 'error' }); + } finally { setLoadingAnomalies(false); } + }, [dsId, enqueueSnackbar]); + + const fetchJobs = useCallback(async () => { + setLoadingJobs(true); + try { + const data = await analysis.listJobs(); + setJobs(data.jobs); + setJobStats(data.stats); + } catch (e: any) { + enqueueSnackbar('Jobs fetch failed: ' + e.message, { variant: 'error' }); + } finally { setLoadingJobs(false); } + }, [enqueueSnackbar]); + + const fetchLbStatus = useCallback(async () => { + try { + const data = await analysis.lbStatus(); + setLbStatus(data); + } catch {} + }, []); + + // Load data when selectors change + useEffect(() => { 
if (dsId) fetchTriage(); }, [dsId, fetchTriage]); + useEffect(() => { if (huntId) { fetchProfiles(); fetchReports(); } }, [huntId, fetchProfiles, fetchReports]); + + // Auto-refresh jobs when on jobs tab + useEffect(() => { + if (tab === 5) { + fetchJobs(); + fetchLbStatus(); + const iv = setInterval(() => { fetchJobs(); fetchLbStatus(); }, 5000); + return () => clearInterval(iv); + } + }, [tab, fetchJobs, fetchLbStatus]); + + /* Trigger actions */ + const doTriggerTriage = useCallback(async () => { + if (!dsId) return; + setTriggering(true); + try { + await analysis.triggerTriage(dsId); + enqueueSnackbar('Triage started', { variant: 'info' }); + setTimeout(fetchTriage, 5000); + } catch (e: any) { + enqueueSnackbar(`Triage trigger failed: ${e.message}`, { variant: 'error' }); + } finally { setTriggering(false); } + }, [dsId, enqueueSnackbar, fetchTriage]); + + const doTriggerProfiles = useCallback(async () => { + if (!huntId) return; + setTriggering(true); + try { + await analysis.triggerAllProfiles(huntId); + enqueueSnackbar('Host profiling started', { variant: 'info' }); + setTimeout(fetchProfiles, 10000); + } catch (e: any) { + enqueueSnackbar(`Profile trigger failed: ${e.message}`, { variant: 'error' }); + } finally { setTriggering(false); } + }, [huntId, enqueueSnackbar, fetchProfiles]); + + const doGenerateReport = useCallback(async () => { + if (!huntId) return; + setTriggering(true); + try { + await analysis.generateReport(huntId); + enqueueSnackbar('Report generation started', { variant: 'info' }); + setTimeout(fetchReports, 15000); + } catch (e: any) { + enqueueSnackbar(`Report generation failed: ${e.message}`, { variant: 'error' }); + } finally { setTriggering(false); } + }, [huntId, enqueueSnackbar, fetchReports]); + + const doTriggerAnomalies = useCallback(async () => { + if (!dsId) return; + setTriggering(true); + try { + await analysis.triggerAnomalyDetection(dsId); + enqueueSnackbar('Anomaly detection started', { variant: 'info' }); + 
setTimeout(fetchAnomalies, 20000); + } catch (e: any) { + enqueueSnackbar('Anomaly trigger failed: ' + e.message, { variant: 'error' }); + } finally { setTriggering(false); } + }, [dsId, enqueueSnackbar, fetchAnomalies]); + + /* Phase 9: Streaming data query */ + const doQuery = useCallback(async () => { + if (!dsId || !queryText.trim()) return; + setQueryStreaming(true); + setQueryAnswer(''); + setQueryMeta(null); + setQueryDone(null); + + const controller = new AbortController(); + abortRef.current = controller; + + try { + const resp = await analysis.queryStream(dsId, queryText.trim(), queryMode); + if (!resp.body) throw new Error('No response body'); + + const reader = resp.body.getReader(); + const decoder = new TextDecoder(); + let buf = ''; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + buf += decoder.decode(value, { stream: true }); + + const lines = buf.split('\n'); + buf = lines.pop() || ''; + + for (const line of lines) { + if (!line.startsWith('data: ')) continue; + try { + const evt = JSON.parse(line.slice(6)); + switch (evt.type) { + case 'token': + setQueryAnswer(prev => prev + evt.content); + if (answerRef.current) { + answerRef.current.scrollTop = answerRef.current.scrollHeight; + } + break; + case 'metadata': + setQueryMeta(evt.dataset); + break; + case 'done': + setQueryDone(evt); + break; + case 'error': + enqueueSnackbar(`Query error: ${evt.message}`, { variant: 'error' }); + break; + } + } catch {} + } + } + } catch (e: any) { + if (e.name !== 'AbortError') { + enqueueSnackbar('Query failed: ' + e.message, { variant: 'error' }); + } + } finally { + setQueryStreaming(false); + abortRef.current = null; + } + }, [dsId, queryText, queryMode, enqueueSnackbar]); + + const stopQuery = useCallback(() => { + if (abortRef.current) { + abortRef.current.abort(); + setQueryStreaming(false); + } + }, []); + + /* Phase 10: Cancel job */ + const doCancelJob = useCallback(async (jobId: string) => { + try { + await 
analysis.cancelJob(jobId); + enqueueSnackbar('Job cancelled', { variant: 'info' }); + fetchJobs(); + } catch (e: any) { + enqueueSnackbar('Cancel failed: ' + e.message, { variant: 'error' }); + } + }, [enqueueSnackbar, fetchJobs]); + + return ( + + + + AI Analysis + {triggering && } + + + {/* Selectors */} + + + + Hunt + + + + Dataset + + + + + + {/* Tabs */} + setTab(v)} variant="scrollable" scrollButtons="auto" sx={{ mb: 1 }}> + } iconPosition="start" label={`Triage (${triageResults.length})`} /> + } iconPosition="start" label={`Host Profiles (${profiles.length})`} /> + } iconPosition="start" label={`Reports (${reports.length})`} /> + } iconPosition="start" label={`Anomalies (${anomalies.filter(a => a.is_outlier).length})`} /> + } iconPosition="start" label="Ask Data" /> + } iconPosition="start" label={`Jobs${jobStats ? ` (${jobStats.active_workers})` : ''}`} /> + + + {/* Tab 0: Triage */} + + + + + + {loadingTriage && } + {triageResults.length === 0 && !loadingTriage ? ( + No triage results yet. Select a dataset and click "Run Triage". + ) : ( + + + + + RowsRiskVerdict + FindingsMITREModel + + + + {triageResults.map(tr => ( + + {tr.row_start}-{tr.row_end} + + + + {tr.findings?.join('; ') || ''} + + + + {tr.mitre_techniques?.map((t: string, i: number) => ( + + ))} + + + {tr.model_used || ''} + + ))} + +
+
+ )} +
+ + {/* Tab 1: Host Profiles */} + + + + + + {loadingProfiles && } + {profiles.length === 0 && !loadingProfiles ? ( + No host profiles yet. Select a hunt and click "Profile All Hosts". + ) : ( + + {profiles.map(hp => ( + + + + + {hp.hostname} + + + {hp.fqdn && {hp.fqdn}} + + {hp.timeline_summary && ( + + {hp.timeline_summary.slice(0, 300)}{hp.timeline_summary.length > 300 ? '...' : ''} + + )} + {hp.suspicious_findings && hp.suspicious_findings.length > 0 && ( + + + + {hp.suspicious_findings.length} suspicious finding(s) + + + )} + {hp.mitre_techniques && hp.mitre_techniques.length > 0 && ( + + {hp.mitre_techniques.map((t: string, i: number) => ( + + ))} + + )} + + + Model: {hp.model_used || 'N/A'} + + + + ))} + + )} + + + {/* Tab 2: Reports */} + + + + + + {loadingReports && } + {reports.length === 0 && !loadingReports ? ( + No reports yet. Select a hunt and click "Generate Report". + ) : ( + reports.map(rpt => ( + + }> + + + Report - {rpt.status} + {rpt.generation_time_ms && ( + + )} + + + + {rpt.exec_summary && ( + + Executive Summary + {rpt.exec_summary} + + )} + {rpt.findings && rpt.findings.length > 0 && ( + + Findings +
    + {rpt.findings.map((f: any, i: number) => ( +
  • + {typeof f === 'string' ? f : JSON.stringify(f)} +
  • + ))} +
+
+ )} + {rpt.recommendations && rpt.recommendations.length > 0 && ( + + Recommendations +
    + {rpt.recommendations.map((r: any, i: number) => ( +
  • + {typeof r === 'string' ? r : JSON.stringify(r)} +
  • + ))} +
+
+ )} + {rpt.ioc_table && rpt.ioc_table.length > 0 && ( + + IOC Table + + + + + {Object.keys(rpt.ioc_table[0]).map(k => ( + {k} + ))} + + + + {rpt.ioc_table.map((row: any, i: number) => ( + + {Object.values(row).map((v: any, j: number) => ( + {String(v)} + ))} + + ))} + +
+
+
+ )} + {rpt.full_report && ( + + }> + Full Report + + + + {rpt.full_report} + + + + )} + + {rpt.models_used?.map((m: string, i: number) => ( + + ))} + +
+
+ )) + )} +
+ + {/* Tab 3: Anomalies */} + + + + + + {loadingAnomalies && } + {anomalies.length === 0 && !loadingAnomalies ? ( + No anomaly results yet. Select a dataset and click "Detect Anomalies". + ) : ( + <> + + {anomalies.filter(a => a.is_outlier).length} outlier(s) detected out of {anomalies.length} rows + + + + + + RowScore + DistanceClusterOutlier + + + + {anomalies.filter(a => a.is_outlier).concat(anomalies.filter(a => !a.is_outlier).slice(0, 20)).map((a, i) => ( + + {a.row_id ?? ''} + + 0.5 ? 'error' : a.anomaly_score > 0.35 ? 'warning' : 'success'} /> + + {a.distance_from_centroid?.toFixed(4) ?? ''} + + + {a.is_outlier + ? + : } + + + ))} + +
+
+ + )} +
+ + {/* Tab 4: Ask Data (Phase 9) */} + + + + Ask a question about the selected dataset in plain English + + + setQueryText(e.target.value)} + onKeyDown={e => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); doQuery(); } }} + disabled={queryStreaming} + /> + { if (v) setQueryMode(v); }} + > + + Quick + + + Deep + + + {queryStreaming ? ( + + ) : ( + + + + )} + + + + {queryMeta && ( + + Querying {queryMeta.name} ({queryMeta.row_count} rows,{' '} + {queryMeta.sample_rows_shown} sampled) | Mode: {queryMode} + + )} + + {queryStreaming && } + + {queryAnswer && ( + + {queryAnswer} + + )} + + {queryDone && ( + + + + + + + )} + + + {/* Tab 5: Jobs & Load Balancer (Phase 10) */} + + {/* LB Status Cards */} + {lbStatus && ( + + {Object.entries(lbStatus).map(([name, st]) => ( + + + + + {name} + + + + Active: {st.active_jobs} + Done: {st.total_completed} + Errors: {st.total_errors} + Avg: {st.avg_latency_ms.toFixed(0)}ms + + + + + ))} + + )} + + {/* Job queue stats */} + {jobStats && ( + + + + {Object.entries(jobStats.by_status).map(([s, c]) => ( + + ))} + + )} + + + + + + {loadingJobs && } + + {jobs.length === 0 && !loadingJobs ? ( + No jobs yet. Jobs appear here when you trigger triage, profiling, reports, anomaly detection, or data queries. + ) : ( + + + + + Status + Type + Progress + Message + Time + Created + Actions + + + + {jobs.map(j => ( + + + + {statusIcon(j.status)} + {j.status} + + + + + {j.status === 'running' ? ( + + ) : j.status === 'completed' ? ( + 100% + ) : null} + + + {j.error || j.message} + + {fmtMs(j.elapsed_ms)} + {fmtTime(j.created_at)} + + {(j.status === 'queued' || j.status === 'running') && ( + doCancelJob(j.id)}> + + + )} + + + ))} + +
+
+ )} +
+
+ ); +} \ No newline at end of file diff --git a/frontend/src/components/DatasetViewer.tsx b/frontend/src/components/DatasetViewer.tsx index 59420f5..d2f6797 100644 --- a/frontend/src/components/DatasetViewer.tsx +++ b/frontend/src/components/DatasetViewer.tsx @@ -146,6 +146,12 @@ export default function DatasetViewer() { {selected.source_tool && } + {selected.artifact_type && } + {selected.processing_status && selected.processing_status !== 'ready' && ( + + )} {selected.ioc_columns && Object.keys(selected.ioc_columns).length > 0 && ( )} diff --git a/update.md b/update.md new file mode 100644 index 0000000..c278fe2 --- /dev/null +++ b/update.md @@ -0,0 +1,41 @@ +# ThreatHunt Update Log + +## 2026-02-20: Host-Centric Network Map & Analysis Platform + +### Network Map Overhaul +- **Problem**: Network Map showed 409 misclassified "domain" nodes (mostly process names like svchost.exe) and 0 hosts. No deduplication same host counted once per dataset. +- **Root Cause**: IOC column detection misclassified `Fqdn` as "domain" instead of "hostname"; `Name` column (process names) wrongly tagged as "domain" IOC; `ClientId` was in `normalized_columns` as "hostname" but not in `ioc_columns`. +- **Solution**: Created a new host-centric inventory system that scans all datasets, groups by `Fqdn`/`ClientId`, and extracts IPs, users, OS, and network connections. + +#### New Backend Files +- `backend/app/services/host_inventory.py` Deduplicated host inventory builder. Scans all datasets in a hunt, identifies unique hosts via regex-based column detection (`ClientId`, `Fqdn`, `User`/`Username`, `Laddr.IP`/`Raddr.IP`), groups rows, extracts metadata. Filters system accounts (DWM-*, UMFD-*, LOCAL SERVICE, NETWORK SERVICE). Infers OS from hostname patterns (W10-* Windows 10). Builds network connection graph from netstat remote IPs. +- `backend/app/api/routes/network.py` `GET /api/network/host-inventory?hunt_id=X` endpoint returning `{hosts, connections, stats}`. 
+- `backend/app/services/ioc_extractor.py` IOC extraction service (IP, domain, hash, email, URL patterns). +- `backend/app/services/anomaly_detector.py` Statistical anomaly detection across datasets. +- `backend/app/services/data_query.py` Natural language to structured query translation. +- `backend/app/services/load_balancer.py` Round-robin load balancer for Ollama LLM nodes. +- `backend/app/services/job_queue.py` Async job queue for long-running analysis tasks. +- `backend/app/api/routes/analysis.py` 16 analysis endpoints (IOC extraction, anomaly detection, host profiling, triage, reports, job management). + +#### Modified Backend Files +- `backend/app/main.py` Added `network_router` and `analysis_router` includes. +- `backend/app/db/models.py` Added 4 AI/analysis ORM models (`ProcessingJob`, `AnalysisResult`, `HostProfile`, `IOCEntry`). +- `backend/app/db/engine.py` Connection pool tuning for SQLite async. + +#### Frontend Changes +- `frontend/src/components/NetworkMap.tsx` Complete rewrite: host-centric force-directed graph using Canvas 2D. Two node types (Host / External IP). Shows hostname, IP, OS in labels. Click popover shows FQDN, IPs, OS, logged-in users, datasets, connections. Search across hostname/IP/user/OS. Stats cards showing host counts. +- `frontend/src/components/AnalysisDashboard.tsx` New 6-tab analysis dashboard (IOC Extraction, Anomaly Detection, Host Profiling, Query, Triage, Reports). +- `frontend/src/api/client.ts` Added `network.hostInventory()` method + `InventoryHost`, `InventoryConnection`, `InventoryStats` types. Added analysis API namespace with 16 endpoint methods. +- `frontend/src/App.tsx` Added Analysis Dashboard route and navigation. 
+ +### Results (Radio Hunt 20 Velociraptor datasets, 394K rows) + +| Metric | Before | After | +|--------|--------|-------| +| Nodes shown | 409 misclassified "domains" | **163 unique hosts** | +| Hosts identified | 0 | **163** | +| With IP addresses | N/A | **48** (172.17.x.x LAN) | +| With logged-in users | N/A | **43** (real names only) | +| OS detected | None | **Windows 10** (inferred from hostnames) | +| Deduplication | None (same host 20 datasets) | **Full** (by FQDN/ClientId) | +| System account filtering | None | **DWM-*, UMFD-*, LOCAL/NETWORK SERVICE removed** | \ No newline at end of file