mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 05:50:21 -05:00
Compare commits
2 Commits
9b98ab9614
...
bb562a91ca
| Author | SHA1 | Date | |
|---|---|---|---|
| bb562a91ca | |||
| 04a9946891 |
@@ -0,0 +1,112 @@
|
||||
"""add processing_status and AI analysis tables
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: 98ab619418bc
|
||||
Create Date: 2026-02-19 18:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision: str = "a1b2c3d4e5f6"
|
||||
down_revision: Union[str, Sequence[str], None] = "98ab619418bc"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Apply the Phase 1-2 schema additions.

    Adds processing/ingest bookkeeping columns to ``datasets`` and creates
    four AI-analysis tables: ``triage_results``, ``host_profiles``,
    ``hunt_reports``, and ``anomaly_results``.
    """
    # Add columns to datasets table.  batch_alter_table is used so the
    # migration also works on SQLite, which cannot ALTER TABLE in place.
    with op.batch_alter_table("datasets") as batch_op:
        # server_default="ready" backfills existing rows as already-processed.
        batch_op.add_column(sa.Column("processing_status", sa.String(20), server_default="ready"))
        batch_op.add_column(sa.Column("artifact_type", sa.String(128), nullable=True))
        batch_op.add_column(sa.Column("error_message", sa.Text(), nullable=True))
        batch_op.add_column(sa.Column("file_path", sa.String(512), nullable=True))
        batch_op.create_index("ix_datasets_status", ["processing_status"])

    # Create triage_results table: one row per triaged row-window of a
    # dataset.  ondelete="CASCADE" removes results with their dataset.
    op.create_table(
        "triage_results",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("row_start", sa.Integer(), nullable=False),
        sa.Column("row_end", sa.Integer(), nullable=False),
        sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column("verdict", sa.String(20), nullable=False, server_default="pending"),
        sa.Column("findings", sa.JSON(), nullable=True),
        sa.Column("suspicious_indicators", sa.JSON(), nullable=True),
        sa.Column("mitre_techniques", sa.JSON(), nullable=True),
        sa.Column("model_used", sa.String(128), nullable=True),
        sa.Column("node_used", sa.String(64), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # Create host_profiles table: per-host AI risk profile within a hunt.
    op.create_table(
        "host_profiles",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("hostname", sa.String(256), nullable=False),
        sa.Column("fqdn", sa.String(512), nullable=True),
        sa.Column("client_id", sa.String(64), nullable=True),
        sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column("risk_level", sa.String(20), nullable=False, server_default="unknown"),
        sa.Column("artifact_summary", sa.JSON(), nullable=True),
        sa.Column("timeline_summary", sa.Text(), nullable=True),
        sa.Column("suspicious_findings", sa.JSON(), nullable=True),
        sa.Column("mitre_techniques", sa.JSON(), nullable=True),
        sa.Column("llm_analysis", sa.Text(), nullable=True),
        sa.Column("model_used", sa.String(128), nullable=True),
        sa.Column("node_used", sa.String(64), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # Create hunt_reports table: generated hunt-level reports.
    op.create_table(
        "hunt_reports",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("status", sa.String(20), nullable=False, server_default="pending"),
        sa.Column("exec_summary", sa.Text(), nullable=True),
        sa.Column("full_report", sa.Text(), nullable=True),
        sa.Column("findings", sa.JSON(), nullable=True),
        sa.Column("recommendations", sa.JSON(), nullable=True),
        sa.Column("mitre_mapping", sa.JSON(), nullable=True),
        sa.Column("ioc_table", sa.JSON(), nullable=True),
        sa.Column("host_risk_summary", sa.JSON(), nullable=True),
        sa.Column("models_used", sa.JSON(), nullable=True),
        sa.Column("generation_time_ms", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # Create anomaly_results table: per-row embedding anomaly scores.
    op.create_table(
        "anomaly_results",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("row_id", sa.String(32), sa.ForeignKey("dataset_rows.id", ondelete="CASCADE"), nullable=True),
        sa.Column("anomaly_score", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column("distance_from_centroid", sa.Float(), nullable=True),
        sa.Column("cluster_id", sa.Integer(), nullable=True),
        # server_default="0" — SQLite stores booleans as integers.
        sa.Column("is_outlier", sa.Boolean(), nullable=False, server_default="0"),
        sa.Column("explanation", sa.Text(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Undo the Phase 1-2 schema: drop the AI tables and dataset columns."""
    # Drop tables in reverse order of creation.
    for table_name in ("anomaly_results", "hunt_reports", "host_profiles", "triage_results"):
        op.drop_table(table_name)

    with op.batch_alter_table("datasets") as batch_op:
        batch_op.drop_index("ix_datasets_status")
        # Columns removed in reverse order of the upgrade's additions.
        for column_name in ("file_path", "error_message", "artifact_type", "processing_status"):
            batch_op.drop_column(column_name)
|
||||
402
backend/app/api/routes/analysis.py
Normal file
402
backend/app/api/routes/analysis.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""Analysis API routes - triage, host profiles, reports, IOC extraction,
|
||||
host grouping, anomaly detection, data query (SSE), and job management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import HostProfile, HuntReport, TriageResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/analysis", tags=["analysis"])
|
||||
|
||||
|
||||
# --- Response models ---
|
||||
|
||||
class TriageResultResponse(BaseModel):
    """Serialized triage result for one row-window of a dataset."""

    id: str
    dataset_id: str
    row_start: int  # first row index of the triaged window
    row_end: int  # last row index of the triaged window
    risk_score: float
    verdict: str
    findings: list | None = None
    suspicious_indicators: list | None = None
    mitre_techniques: list | None = None
    model_used: str | None = None
    node_used: str | None = None

    class Config:
        # Allow construction directly from ORM objects (response_model use).
        from_attributes = True
|
||||
|
||||
|
||||
class HostProfileResponse(BaseModel):
    """Serialized per-host AI risk profile within a hunt."""

    id: str
    hunt_id: str
    hostname: str
    fqdn: str | None = None
    risk_score: float
    risk_level: str
    artifact_summary: dict | None = None
    timeline_summary: str | None = None
    suspicious_findings: list | None = None
    mitre_techniques: list | None = None
    llm_analysis: str | None = None
    model_used: str | None = None

    class Config:
        # Allow construction directly from ORM objects (response_model use).
        from_attributes = True
|
||||
|
||||
|
||||
class HuntReportResponse(BaseModel):
    """Serialized generated hunt report."""

    id: str
    hunt_id: str
    status: str
    exec_summary: str | None = None
    full_report: str | None = None
    findings: list | None = None
    recommendations: list | None = None
    mitre_mapping: dict | None = None
    ioc_table: list | None = None
    host_risk_summary: list | None = None
    models_used: list | None = None
    generation_time_ms: int | None = None

    class Config:
        # Allow construction directly from ORM objects (response_model use).
        from_attributes = True
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
    """Request body for the natural-language data-query endpoints."""

    question: str
    mode: str = "quick"  # quick or deep
|
||||
|
||||
|
||||
# --- Triage endpoints ---
|
||||
|
||||
@router.get("/triage/{dataset_id}", response_model=list[TriageResultResponse])
|
||||
async def get_triage_results(
|
||||
dataset_id: str,
|
||||
min_risk: float = Query(0.0, ge=0.0, le=10.0),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(TriageResult)
|
||||
.where(TriageResult.dataset_id == dataset_id)
|
||||
.where(TriageResult.risk_score >= min_risk)
|
||||
.order_by(TriageResult.risk_score.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("/triage/{dataset_id}")
|
||||
async def trigger_triage(
|
||||
dataset_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
async def _run():
|
||||
from app.services.triage import triage_dataset
|
||||
await triage_dataset(dataset_id)
|
||||
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "triage_started", "dataset_id": dataset_id}
|
||||
|
||||
|
||||
# --- Host profile endpoints ---
|
||||
|
||||
@router.get("/profiles/{hunt_id}", response_model=list[HostProfileResponse])
|
||||
async def get_host_profiles(
|
||||
hunt_id: str,
|
||||
min_risk: float = Query(0.0, ge=0.0, le=10.0),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(HostProfile)
|
||||
.where(HostProfile.hunt_id == hunt_id)
|
||||
.where(HostProfile.risk_score >= min_risk)
|
||||
.order_by(HostProfile.risk_score.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("/profiles/{hunt_id}")
|
||||
async def trigger_all_profiles(
|
||||
hunt_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
async def _run():
|
||||
from app.services.host_profiler import profile_all_hosts
|
||||
await profile_all_hosts(hunt_id)
|
||||
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "profiling_started", "hunt_id": hunt_id}
|
||||
|
||||
|
||||
@router.post("/profiles/{hunt_id}/{hostname}")
|
||||
async def trigger_single_profile(
|
||||
hunt_id: str,
|
||||
hostname: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
async def _run():
|
||||
from app.services.host_profiler import profile_host
|
||||
await profile_host(hunt_id, hostname)
|
||||
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "profiling_started", "hunt_id": hunt_id, "hostname": hostname}
|
||||
|
||||
|
||||
# --- Report endpoints ---
|
||||
|
||||
@router.get("/reports/{hunt_id}", response_model=list[HuntReportResponse])
|
||||
async def list_reports(
|
||||
hunt_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(HuntReport)
|
||||
.where(HuntReport.hunt_id == hunt_id)
|
||||
.order_by(HuntReport.created_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.get("/reports/{hunt_id}/{report_id}", response_model=HuntReportResponse)
|
||||
async def get_report(
|
||||
hunt_id: str,
|
||||
report_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(HuntReport)
|
||||
.where(HuntReport.id == report_id)
|
||||
.where(HuntReport.hunt_id == hunt_id)
|
||||
)
|
||||
report = result.scalar_one_or_none()
|
||||
if not report:
|
||||
raise HTTPException(status_code=404, detail="Report not found")
|
||||
return report
|
||||
|
||||
|
||||
@router.post("/reports/{hunt_id}/generate")
|
||||
async def trigger_report(
|
||||
hunt_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
):
|
||||
async def _run():
|
||||
from app.services.report_generator import generate_report
|
||||
await generate_report(hunt_id)
|
||||
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "report_generation_started", "hunt_id": hunt_id}
|
||||
|
||||
|
||||
# --- IOC extraction endpoints ---
|
||||
|
||||
@router.get("/iocs/{dataset_id}")
|
||||
async def extract_iocs(
|
||||
dataset_id: str,
|
||||
max_rows: int = Query(5000, ge=1, le=50000),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Extract IOCs (IPs, domains, hashes, etc.) from dataset rows."""
|
||||
from app.services.ioc_extractor import extract_iocs_from_dataset
|
||||
iocs = await extract_iocs_from_dataset(dataset_id, db, max_rows=max_rows)
|
||||
total = sum(len(v) for v in iocs.values())
|
||||
return {"dataset_id": dataset_id, "iocs": iocs, "total": total}
|
||||
|
||||
|
||||
# --- Host grouping endpoints ---
|
||||
|
||||
@router.get("/hosts/{hunt_id}")
|
||||
async def get_host_groups(
|
||||
hunt_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Group data by hostname across all datasets in a hunt."""
|
||||
from app.services.ioc_extractor import extract_host_groups
|
||||
groups = await extract_host_groups(hunt_id, db)
|
||||
return {"hunt_id": hunt_id, "hosts": groups}
|
||||
|
||||
|
||||
# --- Anomaly detection endpoints ---
|
||||
|
||||
@router.get("/anomalies/{dataset_id}")
|
||||
async def get_anomalies(
|
||||
dataset_id: str,
|
||||
outliers_only: bool = Query(False),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get anomaly detection results for a dataset."""
|
||||
from app.db.models import AnomalyResult
|
||||
stmt = select(AnomalyResult).where(AnomalyResult.dataset_id == dataset_id)
|
||||
if outliers_only:
|
||||
stmt = stmt.where(AnomalyResult.is_outlier == True)
|
||||
stmt = stmt.order_by(AnomalyResult.anomaly_score.desc())
|
||||
result = await db.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": r.id,
|
||||
"dataset_id": r.dataset_id,
|
||||
"row_id": r.row_id,
|
||||
"anomaly_score": r.anomaly_score,
|
||||
"distance_from_centroid": r.distance_from_centroid,
|
||||
"cluster_id": r.cluster_id,
|
||||
"is_outlier": r.is_outlier,
|
||||
"explanation": r.explanation,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
@router.post("/anomalies/{dataset_id}")
|
||||
async def trigger_anomaly_detection(
|
||||
dataset_id: str,
|
||||
k: int = Query(3, ge=2, le=20),
|
||||
threshold: float = Query(0.35, ge=0.1, le=0.9),
|
||||
background_tasks: BackgroundTasks = None,
|
||||
):
|
||||
"""Trigger embedding-based anomaly detection on a dataset."""
|
||||
async def _run():
|
||||
from app.services.anomaly_detector import detect_anomalies
|
||||
await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
|
||||
|
||||
if background_tasks:
|
||||
background_tasks.add_task(_run)
|
||||
return {"status": "anomaly_detection_started", "dataset_id": dataset_id}
|
||||
else:
|
||||
from app.services.anomaly_detector import detect_anomalies
|
||||
results = await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
|
||||
return {"status": "complete", "dataset_id": dataset_id, "count": len(results)}
|
||||
|
||||
|
||||
# --- Natural language data query (SSE streaming) ---
|
||||
|
||||
@router.post("/query/{dataset_id}")
|
||||
async def query_dataset_endpoint(
|
||||
dataset_id: str,
|
||||
body: QueryRequest,
|
||||
):
|
||||
"""Ask a natural language question about a dataset.
|
||||
|
||||
Returns an SSE stream with token-by-token LLM response.
|
||||
Event types: status, metadata, token, error, done
|
||||
"""
|
||||
from app.services.data_query import query_dataset_stream
|
||||
|
||||
return StreamingResponse(
|
||||
query_dataset_stream(dataset_id, body.question, body.mode),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/query/{dataset_id}/sync")
|
||||
async def query_dataset_sync(
|
||||
dataset_id: str,
|
||||
body: QueryRequest,
|
||||
):
|
||||
"""Non-streaming version of data query."""
|
||||
from app.services.data_query import query_dataset
|
||||
|
||||
try:
|
||||
answer = await query_dataset(dataset_id, body.question, body.mode)
|
||||
return {"dataset_id": dataset_id, "question": body.question, "answer": answer, "mode": body.mode}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Query failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# --- Job queue endpoints ---
|
||||
|
||||
@router.get("/jobs")
|
||||
async def list_jobs(
|
||||
status: str | None = Query(None),
|
||||
job_type: str | None = Query(None),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
):
|
||||
"""List all tracked jobs."""
|
||||
from app.services.job_queue import job_queue, JobStatus, JobType
|
||||
|
||||
s = JobStatus(status) if status else None
|
||||
t = JobType(job_type) if job_type else None
|
||||
jobs = job_queue.list_jobs(status=s, job_type=t, limit=limit)
|
||||
stats = job_queue.get_stats()
|
||||
return {"jobs": jobs, "stats": stats}
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}")
|
||||
async def get_job(job_id: str):
|
||||
"""Get status of a specific job."""
|
||||
from app.services.job_queue import job_queue
|
||||
|
||||
job = job_queue.get_job(job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
return job.to_dict()
|
||||
|
||||
|
||||
@router.delete("/jobs/{job_id}")
|
||||
async def cancel_job(job_id: str):
|
||||
"""Cancel a running or queued job."""
|
||||
from app.services.job_queue import job_queue
|
||||
|
||||
if job_queue.cancel_job(job_id):
|
||||
return {"status": "cancelled", "job_id": job_id}
|
||||
raise HTTPException(status_code=400, detail="Job cannot be cancelled (already complete or not found)")
|
||||
|
||||
|
||||
@router.post("/jobs/submit/{job_type}")
|
||||
async def submit_job(
|
||||
job_type: str,
|
||||
params: dict = {},
|
||||
):
|
||||
"""Submit a new job to the queue.
|
||||
|
||||
Job types: triage, host_profile, report, anomaly, query
|
||||
Params vary by type (e.g., dataset_id, hunt_id, question, mode).
|
||||
"""
|
||||
from app.services.job_queue import job_queue, JobType
|
||||
|
||||
try:
|
||||
jt = JobType(job_type)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid job_type: {job_type}. Valid: {[t.value for t in JobType]}",
|
||||
)
|
||||
|
||||
job = job_queue.submit(jt, **params)
|
||||
return {"job_id": job.id, "status": job.status.value, "job_type": job_type}
|
||||
|
||||
|
||||
# --- Load balancer status ---
|
||||
|
||||
@router.get("/lb/status")
|
||||
async def lb_status():
|
||||
"""Get load balancer status for both nodes."""
|
||||
from app.services.load_balancer import lb
|
||||
return lb.get_status()
|
||||
|
||||
|
||||
@router.post("/lb/check")
|
||||
async def lb_health_check():
|
||||
"""Force a health check of both nodes."""
|
||||
from app.services.load_balancer import lb
|
||||
await lb.check_health()
|
||||
return lb.get_status()
|
||||
28
backend/app/api/routes/network.py
Normal file
28
backend/app/api/routes/network.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Network topology API - host inventory endpoint."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.services.host_inventory import build_host_inventory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/network", tags=["network"])
|
||||
|
||||
|
||||
@router.get("/host-inventory")
|
||||
async def get_host_inventory(
|
||||
hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Build a deduplicated host inventory from all datasets in a hunt.
|
||||
|
||||
Returns unique hosts with hostname, IPs, OS, logged-in users, and
|
||||
network connections derived from netstat/connection data.
|
||||
"""
|
||||
result = await build_host_inventory(hunt_id, db)
|
||||
if result["stats"]["total_hosts"] == 0:
|
||||
return result
|
||||
return result
|
||||
@@ -3,6 +3,7 @@
|
||||
Uses async SQLAlchemy with aiosqlite for local dev and asyncpg for production PostgreSQL.
|
||||
"""
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
@@ -12,12 +13,32 @@ from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config import settings
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
_is_sqlite = settings.DATABASE_URL.startswith("sqlite")
|
||||
|
||||
_engine_kwargs: dict = dict(
|
||||
echo=settings.DEBUG,
|
||||
future=True,
|
||||
)
|
||||
|
||||
if _is_sqlite:
|
||||
_engine_kwargs["connect_args"] = {"timeout": 30}
|
||||
_engine_kwargs["pool_size"] = 1
|
||||
_engine_kwargs["max_overflow"] = 0
|
||||
|
||||
engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)
|
||||
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
def _set_sqlite_pragmas(dbapi_conn, connection_record):
    """Enable WAL mode and tune busy-timeout for SQLite connections.

    Registered on every engine; the guard makes it a no-op for non-SQLite
    backends since these PRAGMAs are SQLite-specific.
    """
    if _is_sqlite:
        cursor = dbapi_conn.cursor()
        # WAL allows concurrent readers alongside a single writer.
        cursor.execute("PRAGMA journal_mode=WAL")
        # Wait up to 5s on a locked database before raising "database is locked".
        cursor.execute("PRAGMA busy_timeout=5000")
        # NORMAL is safe with WAL and cheaper than FULL fsync-per-commit.
        cursor.execute("PRAGMA synchronous=NORMAL")
        cursor.close()
|
||||
|
||||
|
||||
async_session_factory = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
@@ -51,4 +72,4 @@ async def init_db() -> None:
|
||||
|
||||
async def dispose_db() -> None:
|
||||
"""Dispose of the engine connection pool."""
|
||||
await engine.dispose()
|
||||
await engine.dispose()
|
||||
@@ -1,7 +1,7 @@
|
||||
"""SQLAlchemy ORM models for ThreatHunt.
|
||||
|
||||
All persistent entities: datasets, hunts, conversations, annotations,
|
||||
hypotheses, enrichment results, and users.
|
||||
hypotheses, enrichment results, users, and AI analysis tables.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
@@ -32,8 +32,7 @@ def _new_id() -> str:
|
||||
return uuid.uuid4().hex
|
||||
|
||||
|
||||
# ── Users ──────────────────────────────────────────────────────────────
|
||||
|
||||
# -- Users ---
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
@@ -42,17 +41,15 @@ class User(Base):
|
||||
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
email: Mapped[str] = mapped_column(String(256), unique=True, nullable=False)
|
||||
hashed_password: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
role: Mapped[str] = mapped_column(String(16), default="analyst") # analyst | admin | viewer
|
||||
role: Mapped[str] = mapped_column(String(16), default="analyst")
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
# relationships
|
||||
hunts: Mapped[list["Hunt"]] = relationship(back_populates="owner", lazy="selectin")
|
||||
annotations: Mapped[list["Annotation"]] = relationship(back_populates="author", lazy="selectin")
|
||||
|
||||
|
||||
# ── Hunts ──────────────────────────────────────────────────────────────
|
||||
|
||||
# -- Hunts ---
|
||||
|
||||
class Hunt(Base):
|
||||
__tablename__ = "hunts"
|
||||
@@ -60,7 +57,7 @@ class Hunt(Base):
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
name: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(32), default="active") # active | closed | archived
|
||||
status: Mapped[str] = mapped_column(String(32), default="active")
|
||||
owner_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("users.id"), nullable=True
|
||||
)
|
||||
@@ -69,15 +66,15 @@ class Hunt(Base):
|
||||
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||
)
|
||||
|
||||
# relationships
|
||||
owner: Mapped[Optional["User"]] = relationship(back_populates="hunts", lazy="selectin")
|
||||
datasets: Mapped[list["Dataset"]] = relationship(back_populates="hunt", lazy="selectin")
|
||||
conversations: Mapped[list["Conversation"]] = relationship(back_populates="hunt", lazy="selectin")
|
||||
hypotheses: Mapped[list["Hypothesis"]] = relationship(back_populates="hunt", lazy="selectin")
|
||||
host_profiles: Mapped[list["HostProfile"]] = relationship(back_populates="hunt", lazy="noload")
|
||||
reports: Mapped[list["HuntReport"]] = relationship(back_populates="hunt", lazy="noload")
|
||||
|
||||
|
||||
# ── Datasets ───────────────────────────────────────────────────────────
|
||||
|
||||
# -- Datasets ---
|
||||
|
||||
class Dataset(Base):
|
||||
__tablename__ = "datasets"
|
||||
@@ -85,36 +82,44 @@ class Dataset(Base):
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
|
||||
filename: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
source_tool: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # velociraptor, etc.
|
||||
source_tool: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
row_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
column_schema: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
normalized_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
ioc_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) # auto-detected IOC columns
|
||||
ioc_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
file_size_bytes: Mapped[int] = mapped_column(Integer, default=0)
|
||||
encoding: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
|
||||
delimiter: Mapped[Optional[str]] = mapped_column(String(4), nullable=True)
|
||||
time_range_start: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
time_range_end: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# New Phase 1-2 columns
|
||||
processing_status: Mapped[str] = mapped_column(String(20), default="ready")
|
||||
artifact_type: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
file_path: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
|
||||
|
||||
hunt_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("hunts.id"), nullable=True
|
||||
)
|
||||
uploaded_by: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
# relationships
|
||||
hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="datasets", lazy="selectin")
|
||||
rows: Mapped[list["DatasetRow"]] = relationship(
|
||||
back_populates="dataset", lazy="noload", cascade="all, delete-orphan"
|
||||
)
|
||||
triage_results: Mapped[list["TriageResult"]] = relationship(
|
||||
back_populates="dataset", lazy="noload", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_datasets_hunt", "hunt_id"),
|
||||
Index("ix_datasets_status", "processing_status"),
|
||||
)
|
||||
|
||||
|
||||
class DatasetRow(Base):
|
||||
"""Individual row from a CSV dataset, stored as JSON blob."""
|
||||
__tablename__ = "dataset_rows"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
@@ -125,7 +130,6 @@ class DatasetRow(Base):
|
||||
data: Mapped[dict] = mapped_column(JSON, nullable=False)
|
||||
normalized_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
|
||||
# relationships
|
||||
dataset: Mapped["Dataset"] = relationship(back_populates="rows")
|
||||
annotations: Mapped[list["Annotation"]] = relationship(
|
||||
back_populates="row", lazy="noload"
|
||||
@@ -137,8 +141,7 @@ class DatasetRow(Base):
|
||||
)
|
||||
|
||||
|
||||
# ── Conversations ─────────────────────────────────────────────────────
|
||||
|
||||
# -- Conversations ---
|
||||
|
||||
class Conversation(Base):
|
||||
__tablename__ = "conversations"
|
||||
@@ -156,7 +159,6 @@ class Conversation(Base):
|
||||
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||
)
|
||||
|
||||
# relationships
|
||||
hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="conversations", lazy="selectin")
|
||||
messages: Mapped[list["Message"]] = relationship(
|
||||
back_populates="conversation", lazy="selectin", cascade="all, delete-orphan",
|
||||
@@ -171,16 +173,15 @@ class Message(Base):
|
||||
conversation_id: Mapped[str] = mapped_column(
|
||||
String(32), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
role: Mapped[str] = mapped_column(String(16), nullable=False) # user | agent | system
|
||||
role: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # wile | roadrunner | cluster
|
||||
node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
token_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
latency_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
response_meta: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
# relationships
|
||||
conversation: Mapped["Conversation"] = relationship(back_populates="messages")
|
||||
|
||||
__table_args__ = (
|
||||
@@ -188,8 +189,7 @@ class Message(Base):
|
||||
)
|
||||
|
||||
|
||||
# ── Annotations ───────────────────────────────────────────────────────
|
||||
|
||||
# -- Annotations ---
|
||||
|
||||
class Annotation(Base):
|
||||
__tablename__ = "annotations"
|
||||
@@ -205,19 +205,14 @@ class Annotation(Base):
|
||||
String(32), ForeignKey("users.id"), nullable=True
|
||||
)
|
||||
text: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
severity: Mapped[str] = mapped_column(
|
||||
String(16), default="info"
|
||||
) # info | low | medium | high | critical
|
||||
tag: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), nullable=True
|
||||
) # suspicious | benign | needs-review
|
||||
severity: Mapped[str] = mapped_column(String(16), default="info")
|
||||
tag: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
|
||||
highlight_color: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||
)
|
||||
|
||||
# relationships
|
||||
row: Mapped[Optional["DatasetRow"]] = relationship(back_populates="annotations")
|
||||
author: Mapped[Optional["User"]] = relationship(back_populates="annotations")
|
||||
|
||||
@@ -227,8 +222,7 @@ class Annotation(Base):
|
||||
)
|
||||
|
||||
|
||||
# ── Hypotheses ────────────────────────────────────────────────────────
|
||||
|
||||
# -- Hypotheses ---
|
||||
|
||||
class Hypothesis(Base):
|
||||
__tablename__ = "hypotheses"
|
||||
@@ -240,9 +234,7 @@ class Hypothesis(Base):
|
||||
title: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
mitre_technique: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(16), default="draft"
|
||||
) # draft | active | confirmed | rejected
|
||||
status: Mapped[str] = mapped_column(String(16), default="draft")
|
||||
evidence_row_ids: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
|
||||
evidence_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
@@ -250,7 +242,6 @@ class Hypothesis(Base):
|
||||
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||
)
|
||||
|
||||
# relationships
|
||||
hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="hypotheses", lazy="selectin")
|
||||
|
||||
__table_args__ = (
|
||||
@@ -258,21 +249,16 @@ class Hypothesis(Base):
|
||||
)
|
||||
|
||||
|
||||
# ── Enrichment Results ────────────────────────────────────────────────
|
||||
|
||||
# -- Enrichment Results ---
|
||||
|
||||
class EnrichmentResult(Base):
|
||||
__tablename__ = "enrichment_results"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
ioc_value: Mapped[str] = mapped_column(String(512), nullable=False, index=True)
|
||||
ioc_type: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False
|
||||
) # ip | hash_md5 | hash_sha1 | hash_sha256 | domain | url
|
||||
source: Mapped[str] = mapped_column(String(32), nullable=False) # virustotal | abuseipdb | shodan | ai
|
||||
verdict: Mapped[Optional[str]] = mapped_column(
|
||||
String(16), nullable=True
|
||||
) # clean | suspicious | malicious | unknown
|
||||
ioc_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
source: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
verdict: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
|
||||
confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
raw_result: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
@@ -287,28 +273,24 @@ class EnrichmentResult(Base):
|
||||
)
|
||||
|
||||
|
||||
# ── AUP Keyword Themes & Keywords ────────────────────────────────────
|
||||
|
||||
# -- AUP Keyword Themes & Keywords ---
|
||||
|
||||
class KeywordTheme(Base):
    """A named category of keywords for AUP scanning (e.g. gambling, gaming)."""
    __tablename__ = "keyword_themes"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    name: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
    color: Mapped[str] = mapped_column(String(16), default="#9e9e9e")  # hex chip color
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)  # presumably toggles the theme in scans — confirm in scanner
    is_builtin: Mapped[bool] = mapped_column(Boolean, default=False)  # seed-provided
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    # relationships — keywords belong to the theme and are deleted with it
    keywords: Mapped[list["Keyword"]] = relationship(
        back_populates="theme", lazy="selectin", cascade="all, delete-orphan"
    )
|
||||
|
||||
|
||||
class Keyword(Base):
|
||||
"""Individual keyword / pattern belonging to a theme."""
|
||||
__tablename__ = "keywords"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
@@ -319,10 +301,102 @@ class Keyword(Base):
|
||||
is_regex: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
# relationships
|
||||
theme: Mapped["KeywordTheme"] = relationship(back_populates="keywords")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_keywords_theme", "theme_id"),
|
||||
Index("ix_keywords_value", "value"),
|
||||
)
|
||||
|
||||
|
||||
# -- AI Analysis Tables (Phase 2) ---
|
||||
|
||||
class TriageResult(Base):
    """AI triage verdict for one chunk of dataset rows (Phase 2 analysis)."""
    __tablename__ = "triage_results"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    # Parent dataset; results are removed with it (ondelete CASCADE).
    dataset_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True
    )
    # Row range covered by this verdict (inclusive? — TODO confirm against the triage service)
    row_start: Mapped[int] = mapped_column(Integer, nullable=False)
    row_end: Mapped[int] = mapped_column(Integer, nullable=False)
    risk_score: Mapped[float] = mapped_column(Float, default=0.0)
    verdict: Mapped[str] = mapped_column(String(20), default="pending")
    findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    suspicious_indicators: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    mitre_techniques: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    # Which LLM and cluster node produced this result.
    model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    dataset: Mapped["Dataset"] = relationship(back_populates="triage_results")
|
||||
|
||||
|
||||
class HostProfile(Base):
    """Per-host risk summary aggregated across a hunt (Phase 2 AI analysis)."""
    __tablename__ = "host_profiles"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    # Owning hunt; profiles are removed with it (ondelete CASCADE).
    hunt_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True
    )
    hostname: Mapped[str] = mapped_column(String(256), nullable=False)
    fqdn: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
    client_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)  # presumably the Velociraptor client id — confirm
    risk_score: Mapped[float] = mapped_column(Float, default=0.0)
    risk_level: Mapped[str] = mapped_column(String(20), default="unknown")
    # JSON/text blobs filled in by the analysis pipeline.
    artifact_summary: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    timeline_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    suspicious_findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    mitre_techniques: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    llm_analysis: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # Which LLM and cluster node produced the analysis.
    model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    hunt: Mapped["Hunt"] = relationship(back_populates="host_profiles")
|
||||
|
||||
|
||||
class HuntReport(Base):
    """Generated hunt-level report: summary, findings, recommendations, IOCs."""
    __tablename__ = "hunt_reports"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    # Owning hunt; reports are removed with it (ondelete CASCADE).
    hunt_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True
    )
    status: Mapped[str] = mapped_column(String(20), default="pending")  # generation lifecycle state
    exec_summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    full_report: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    findings: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    recommendations: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    mitre_mapping: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    ioc_table: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    host_risk_summary: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    models_used: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    generation_time_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    hunt: Mapped["Hunt"] = relationship(back_populates="reports")
|
||||
|
||||
|
||||
class AnomalyResult(Base):
|
||||
__tablename__ = "anomaly_results"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
dataset_id: Mapped[str] = mapped_column(
|
||||
String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
row_id: Mapped[Optional[int]] = mapped_column(
|
||||
Integer, ForeignKey("dataset_rows.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
anomaly_score: Mapped[float] = mapped_column(Float, default=0.0)
|
||||
distance_from_centroid: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
cluster_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
is_outlier: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
explanation: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
@@ -1,10 +1,12 @@
|
||||
"""ThreatHunt backend application.
|
||||
|
||||
Wires together: database, CORS, agent routes, dataset routes, hunt routes,
|
||||
annotation/hypothesis routes. DB tables are auto-created on startup.
|
||||
annotation/hypothesis routes, analysis routes, network routes, job queue,
|
||||
load balancer. DB tables are auto-created on startup.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
@@ -21,6 +23,8 @@ from app.api.routes.correlation import router as correlation_router
|
||||
from app.api.routes.reports import router as reports_router
|
||||
from app.api.routes.auth import router as auth_router
|
||||
from app.api.routes.keywords import router as keywords_router
|
||||
from app.api.routes.analysis import router as analysis_router
|
||||
from app.api.routes.network import router as network_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,17 +32,45 @@ logger = logging.getLogger(__name__)
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Startup / shutdown lifecycle."""
|
||||
logger.info("Starting ThreatHunt API …")
|
||||
logger.info("Starting ThreatHunt API ...")
|
||||
await init_db()
|
||||
logger.info("Database initialised")
|
||||
|
||||
# Ensure uploads directory exists
|
||||
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
|
||||
logger.info("Upload dir: %s", os.path.abspath(settings.UPLOAD_DIR))
|
||||
|
||||
# Seed default AUP keyword themes
|
||||
from app.db import async_session_factory
|
||||
from app.services.keyword_defaults import seed_defaults
|
||||
async with async_session_factory() as seed_db:
|
||||
await seed_defaults(seed_db)
|
||||
logger.info("AUP keyword defaults checked")
|
||||
|
||||
# Start job queue (Phase 10)
|
||||
from app.services.job_queue import job_queue, register_all_handlers
|
||||
register_all_handlers()
|
||||
await job_queue.start()
|
||||
logger.info("Job queue started (%d workers)", job_queue._max_workers)
|
||||
|
||||
# Start load balancer health loop (Phase 10)
|
||||
from app.services.load_balancer import lb
|
||||
await lb.start_health_loop(interval=30.0)
|
||||
logger.info("Load balancer health loop started")
|
||||
|
||||
yield
|
||||
logger.info("Shutting down …")
|
||||
|
||||
logger.info("Shutting down ...")
|
||||
# Stop job queue
|
||||
from app.services.job_queue import job_queue as jq
|
||||
await jq.stop()
|
||||
logger.info("Job queue stopped")
|
||||
|
||||
# Stop load balancer
|
||||
from app.services.load_balancer import lb as _lb
|
||||
await _lb.stop_health_loop()
|
||||
logger.info("Load balancer stopped")
|
||||
|
||||
from app.agents.providers_v2 import cleanup_client
|
||||
from app.services.enrichment import enrichment_engine
|
||||
await cleanup_client()
|
||||
@@ -46,15 +78,13 @@ async def lifespan(app: FastAPI):
|
||||
await dispose_db()
|
||||
|
||||
|
||||
# Create FastAPI application
|
||||
app = FastAPI(
|
||||
title="ThreatHunt API",
|
||||
description="Analyst-assist threat hunting platform powered by Wile & Roadrunner LLM cluster",
|
||||
version="0.3.0",
|
||||
version=settings.APP_VERSION,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Configure CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
@@ -74,11 +104,12 @@ app.include_router(enrichment_router)
|
||||
app.include_router(correlation_router)
|
||||
app.include_router(reports_router)
|
||||
app.include_router(keywords_router)
|
||||
app.include_router(analysis_router)
|
||||
app.include_router(network_router)
|
||||
|
||||
|
||||
@app.get("/", tags=["health"])
|
||||
async def root():
|
||||
"""API health check."""
|
||||
return {
|
||||
"service": "ThreatHunt API",
|
||||
"version": settings.APP_VERSION,
|
||||
@@ -89,4 +120,4 @@ async def root():
|
||||
"roadrunner": settings.roadrunner_url,
|
||||
"openwebui": settings.OPENWEBUI_URL,
|
||||
},
|
||||
}
|
||||
}
|
||||
199
backend/app/services/anomaly_detector.py
Normal file
199
backend/app/services/anomaly_detector.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Embedding-based anomaly detection using Roadrunner's bge-m3 model.
|
||||
|
||||
Converts dataset rows to embeddings, clusters them, and flags outliers
|
||||
that deviate significantly from the cluster centroids. Uses cosine
|
||||
distance and simple k-means-like centroid computation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import async_session_factory
|
||||
from app.db.models import AnomalyResult, Dataset, DatasetRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EMBED_URL = f"{settings.roadrunner_url}/api/embed"
|
||||
EMBED_MODEL = "bge-m3"
|
||||
BATCH_SIZE = 32 # rows per embedding batch
|
||||
MAX_ROWS = 2000 # cap for anomaly detection
|
||||
|
||||
# --- math helpers (no numpy required) ---
|
||||
|
||||
def _dot(a: list[float], b: list[float]) -> float:
|
||||
return sum(x * y for x, y in zip(a, b))
|
||||
|
||||
|
||||
def _norm(v: list[float]) -> float:
|
||||
return math.sqrt(sum(x * x for x in v))
|
||||
|
||||
|
||||
def _cosine_distance(a: list[float], b: list[float]) -> float:
    """Cosine distance (1 - cosine similarity) between two vectors.

    A zero vector has no direction, so the distance is defined as 1.0
    whenever either input has zero norm.
    """
    norm_a = _norm(a)
    norm_b = _norm(b)
    if norm_a == 0 or norm_b == 0:
        return 1.0
    return 1.0 - _dot(a, b) / (norm_a * norm_b)
|
||||
|
||||
|
||||
def _mean_vector(vectors: list[list[float]]) -> list[float]:
|
||||
if not vectors:
|
||||
return []
|
||||
dim = len(vectors[0])
|
||||
n = len(vectors)
|
||||
return [sum(v[i] for v in vectors) / n for i in range(dim)]
|
||||
|
||||
|
||||
def _row_to_text(data: dict) -> str:
|
||||
"""Flatten a row dict to a single string for embedding."""
|
||||
parts = []
|
||||
for k, v in data.items():
|
||||
sv = str(v).strip()
|
||||
if sv and sv.lower() not in ('none', 'null', ''):
|
||||
parts.append(f"{k}={sv}")
|
||||
return " | ".join(parts)[:2000] # cap length
|
||||
|
||||
|
||||
async def _embed_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]:
    """POST one batch of texts to Roadrunner's Ollama embed endpoint.

    Returns the embedding vectors, or [] when the response body has no
    "embeddings" key. Raises httpx.HTTPStatusError on a non-2xx reply.
    """
    payload = {"model": EMBED_MODEL, "input": texts}
    resp = await client.post(EMBED_URL, json=payload, timeout=120.0)
    resp.raise_for_status()
    # Ollama returns {"embeddings": [[...], ...]}
    return resp.json().get("embeddings", [])
|
||||
|
||||
|
||||
def _simple_cluster(
    embeddings: list[list[float]],
    k: int = 3,
    max_iter: int = 20,
) -> tuple[list[int], list[list[float]]]:
    """Lightweight k-means over cosine distance (pure Python, no numpy).

    Returns (assignments, centroids). When there are no more points than
    clusters, every point becomes its own singleton cluster.
    """
    n = len(embeddings)
    if n <= k:
        return list(range(n)), list(embeddings)

    # Seed centroids at evenly spaced indices through the data.
    stride = max(n // k, 1)
    centroids = [embeddings[(ci * stride) % n] for ci in range(k)]
    labels = [0] * n

    for _ in range(max_iter):
        # Re-assign every point to its nearest centroid; ties go to the
        # lowest cluster index, matching argmin-of-first-minimum.
        proposed = [
            min(range(k), key=lambda c: _cosine_distance(point, centroids[c]))
            for point in embeddings
        ]

        # Converged: no label changed since the previous pass.
        if proposed == labels:
            break
        labels = proposed

        # Move each centroid to the mean of its members; an empty cluster
        # keeps its previous centroid.
        for ci in range(k):
            members = [embeddings[j] for j in range(n) if labels[j] == ci]
            if members:
                centroids[ci] = _mean_vector(members)

    return labels, centroids
|
||||
|
||||
|
||||
async def detect_anomalies(
    dataset_id: str,
    k: int = 3,
    outlier_threshold: float = 0.35,
) -> list[dict]:
    """Run embedding-based anomaly detection on a dataset.

    Pipeline: 1. Load rows  2. Embed via bge-m3  3. Cluster  4. Flag outliers.

    Args:
        dataset_id: Dataset primary key; at most MAX_ROWS rows are analysed.
        k: Number of k-means clusters to form (reduced when fewer rows exist).
        outlier_threshold: Cosine distance from the assigned centroid above
            which a row is flagged as an outlier.

    Returns:
        One dict per analysed row (row_id, row_index, anomaly_score,
        distance_from_centroid, cluster_id, is_outlier) sorted by
        anomaly_score descending; [] when the dataset has no rows or the
        embedder did not return one vector per row. Results are also
        persisted as AnomalyResult rows in the same session.
    """
    async with async_session_factory() as db:
        # Load rows
        result = await db.execute(
            select(DatasetRow.id, DatasetRow.row_index, DatasetRow.data)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .limit(MAX_ROWS)
        )
        rows = result.all()
        if not rows:
            logger.info("No rows for anomaly detection in dataset %s", dataset_id)
            return []

        row_ids = [r[0] for r in rows]
        row_indices = [r[1] for r in rows]
        texts = [_row_to_text(r[2]) for r in rows]

        logger.info("Anomaly detection: %d rows, embedding with %s", len(texts), EMBED_MODEL)

        # Embed in batches
        all_embeddings: list[list[float]] = []
        async with httpx.AsyncClient() as client:
            for i in range(0, len(texts), BATCH_SIZE):
                batch = texts[i : i + BATCH_SIZE]
                try:
                    embs = await _embed_batch(batch, client)
                    all_embeddings.extend(embs)
                except Exception as e:
                    logger.error("Embedding batch %d failed: %s", i, e)
                    # Fill with zeros so indices stay aligned
                    # (1024 presumably matches bge-m3's vector width — confirm if model changes)
                    all_embeddings.extend([[0.0] * 1024] * len(batch))

        if not all_embeddings or len(all_embeddings) != len(texts):
            logger.error("Embedding count mismatch")
            return []

        # Cluster
        actual_k = min(k, len(all_embeddings))
        assignments, centroids = _simple_cluster(all_embeddings, k=actual_k)

        # Compute distances from centroid
        anomalies: list[dict] = []
        for idx, (emb, cluster_id) in enumerate(zip(all_embeddings, assignments)):
            dist = _cosine_distance(emb, centroids[cluster_id])
            is_outlier = dist > outlier_threshold
            anomalies.append({
                "row_id": row_ids[idx],
                "row_index": row_indices[idx],
                "anomaly_score": round(dist, 4),
                "distance_from_centroid": round(dist, 4),
                "cluster_id": cluster_id,
                "is_outlier": is_outlier,
            })

        # Save to DB
        outlier_count = 0
        for a in anomalies:
            ar = AnomalyResult(
                dataset_id=dataset_id,
                row_id=a["row_id"],
                anomaly_score=a["anomaly_score"],
                distance_from_centroid=a["distance_from_centroid"],
                cluster_id=a["cluster_id"],
                is_outlier=a["is_outlier"],
            )
            db.add(ar)
            if a["is_outlier"]:
                outlier_count += 1

        await db.commit()
        logger.info(
            "Anomaly detection complete: %d rows, %d outliers (threshold=%.2f)",
            len(anomalies), outlier_count, outlier_threshold,
        )

        return sorted(anomalies, key=lambda x: x["anomaly_score"], reverse=True)
|
||||
81
backend/app/services/artifact_classifier.py
Normal file
81
backend/app/services/artifact_classifier.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Artifact classifier - identify Velociraptor artifact types from CSV headers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Each fingerprint is (required_columns, artifact_type): a CSV header that
# contains all of the required columns classifies as that artifact type.
# Order matters — classify_artifact returns the first matching entry.
FINGERPRINTS: list[tuple[set[str], str]] = [
    ({"Pid", "Name", "CommandLine", "Exe"}, "Windows.System.Pslist"),
    ({"Pid", "Name", "Ppid", "CommandLine"}, "Windows.System.Pslist"),
    ({"Laddr.IP", "Raddr.IP", "Status", "Pid"}, "Windows.Network.Netstat"),
    ({"Laddr", "Raddr", "Status", "Pid"}, "Windows.Network.Netstat"),
    ({"FamilyString", "TypeString", "Status", "Pid"}, "Windows.Network.Netstat"),
    ({"ServiceName", "DisplayName", "StartMode", "PathName"}, "Windows.System.Services"),
    ({"DisplayName", "PathName", "ServiceDll", "StartMode"}, "Windows.System.Services"),
    ({"OSPath", "Size", "Mtime", "Hash"}, "Windows.Search.FileFinder"),
    ({"FullPath", "Size", "Mtime"}, "Windows.Search.FileFinder"),
    ({"PrefetchFileName", "RunCount", "LastRunTimes"}, "Windows.Forensics.Prefetch"),
    ({"Executable", "RunCount", "LastRunTimes"}, "Windows.Forensics.Prefetch"),
    ({"KeyPath", "Type", "Data"}, "Windows.Registry.Finder"),
    ({"Key", "Type", "Value"}, "Windows.Registry.Finder"),
    ({"EventTime", "Channel", "EventID", "EventData"}, "Windows.EventLogs.EvtxHunter"),
    ({"TimeCreated", "Channel", "EventID", "Provider"}, "Windows.EventLogs.EvtxHunter"),
    ({"Entry", "Category", "Profile", "Launch String"}, "Windows.Sys.Autoruns"),
    ({"Entry", "Category", "LaunchString"}, "Windows.Sys.Autoruns"),
    ({"Name", "Record", "Type", "TTL"}, "Windows.Network.DNS"),
    ({"QueryName", "QueryType", "QueryResults"}, "Windows.Network.DNS"),
    ({"Path", "MD5", "SHA1", "SHA256"}, "Windows.Analysis.Hash"),
    ({"Md5", "Sha256", "FullPath"}, "Windows.Analysis.Hash"),
    ({"Name", "Actions", "NextRunTime", "Path"}, "Windows.System.TaskScheduler"),
    ({"Name", "Uid", "Gid", "Description"}, "Windows.Sys.Users"),
    ({"os_info.hostname", "os_info.system"}, "Server.Information.Client"),
    ({"ClientId", "os_info.fqdn"}, "Server.Information.Client"),
    ({"Pid", "Name", "Cmdline", "Exe"}, "Linux.Sys.Pslist"),
    ({"Laddr", "Raddr", "Status", "FamilyString"}, "Linux.Network.Netstat"),
    ({"Namespace", "ClassName", "PropertyName"}, "Windows.System.WMI"),
    ({"RemoteAddress", "RemoteMACAddress", "InterfaceAlias"}, "Windows.Network.ArpCache"),
    ({"URL", "Title", "VisitCount", "LastVisitTime"}, "Windows.Applications.BrowserHistory"),
    ({"Url", "Title", "Visits"}, "Windows.Applications.BrowserHistory"),
]

# Metadata columns present in Velociraptor exports; any of these marks a
# file as Velociraptor output even when no fingerprint above matches.
VELOCIRAPTOR_META = {"_Source", "ClientId", "FlowId", "Fqdn", "HuntId"}

# Substring of the artifact type name -> broad forensic category.
# Used by get_artifact_category; the first matching key wins.
CATEGORY_MAP = {
    "Pslist": "process",
    "Netstat": "network",
    "Services": "persistence",
    "FileFinder": "filesystem",
    "Prefetch": "execution",
    "Registry": "persistence",
    "EvtxHunter": "eventlog",
    "EventLogs": "eventlog",
    "Autoruns": "persistence",
    "DNS": "network",
    "Hash": "filesystem",
    "TaskScheduler": "persistence",
    "Users": "account",
    "Client": "system",
    "WMI": "persistence",
    "ArpCache": "network",
    "BrowserHistory": "application",
}
|
||||
|
||||
|
||||
def classify_artifact(columns: list[str]) -> str:
    """Identify the Velociraptor artifact type from a CSV header list.

    Each fingerprint's required column set is tested against the header;
    the first fully-contained fingerprint wins. Headers carrying
    Velociraptor metadata columns but matching no fingerprint classify as
    "Velociraptor.Unknown"; anything else is "Unknown".
    """
    available = set(columns)
    for required, artifact_type in FINGERPRINTS:
        if required <= available:
            return artifact_type
    if not VELOCIRAPTOR_META.isdisjoint(available):
        return "Velociraptor.Unknown"
    return "Unknown"
|
||||
|
||||
|
||||
def get_artifact_category(artifact_type: str) -> str:
    """Map an artifact type name to a broad category via case-insensitive substring match."""
    haystack = artifact_type.lower()
    for marker, category in CATEGORY_MAP.items():
        if marker.lower() in haystack:
            return category
    return "unknown"
|
||||
238
backend/app/services/data_query.py
Normal file
238
backend/app/services/data_query.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Natural-language data query service with SSE streaming.
|
||||
|
||||
Lets analysts ask questions about dataset rows in plain English.
|
||||
Routes to fast model (Roadrunner) for quick queries, heavy model (Wile)
|
||||
for deep analysis. Supports streaming via OllamaProvider.generate_stream().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import AsyncIterator
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import async_session_factory
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum rows to include in context window
|
||||
MAX_CONTEXT_ROWS = 60
|
||||
MAX_ROW_TEXT_CHARS = 300
|
||||
|
||||
|
||||
def _rows_to_text(rows: list[dict], columns: list[str]) -> str:
    """Render rows as a compact pipe-separated text table for LLM context.

    At most 20 columns and MAX_CONTEXT_ROWS rows are shown; individual
    cells are truncated to 80 chars and whole lines to MAX_ROW_TEXT_CHARS.
    """
    if not rows:
        return "(no rows)"
    shown_cols = columns[:20]  # cap columns to avoid overflow
    header = " | ".join(shown_cols)
    out = [header, "-" * min(len(header), 120)]
    for record in rows[:MAX_CONTEXT_ROWS]:
        cells = [
            cell if len(cell) <= 80 else cell[:77] + "..."
            for cell in (str(record.get(col, "")) for col in shown_cols)
        ]
        rendered = " | ".join(cells)
        if len(rendered) > MAX_ROW_TEXT_CHARS:
            rendered = rendered[:MAX_ROW_TEXT_CHARS] + "..."
        out.append(rendered)
    return "\n".join(out)
|
||||
|
||||
|
||||
QUERY_SYSTEM_PROMPT = """You are a cybersecurity data analyst assistant for ThreatHunt.
|
||||
You have been given a sample of rows from a forensic artifact dataset (Velociraptor, etc.).
|
||||
|
||||
Your job:
|
||||
- Answer the analyst's question about this data accurately and concisely
|
||||
- Point out suspicious patterns, anomalies, or indicators of compromise
|
||||
- Reference MITRE ATT&CK techniques when relevant
|
||||
- Suggest follow-up queries or pivots
|
||||
- If you cannot answer from the data provided, say so clearly
|
||||
|
||||
Rules:
|
||||
- Be factual - only reference data you can see
|
||||
- Use forensic terminology appropriate for SOC/DFIR analysts
|
||||
- Format your answer with clear sections using markdown
|
||||
- If the data seems benign, say so - do not fabricate threats"""
|
||||
|
||||
|
||||
async def _load_dataset_context(
    dataset_id: str,
    db: AsyncSession,
    sample_size: int = MAX_CONTEXT_ROWS,
) -> tuple[dict, str, int]:
    """Load dataset metadata + sample rows for context.

    Samples half of *sample_size* rows from the start of the dataset and
    the remainder from the middle (or immediately after the first batch
    for small datasets), so the LLM sees more than just the file head.

    Args:
        dataset_id: Primary key of the dataset to load.
        db: Open async session used for all queries.
        sample_size: Maximum number of rows included in the context.

    Returns:
        (metadata_dict, rows_text, total_row_count).

    Raises:
        ValueError: If the dataset does not exist.
    """
    ds = await db.get(Dataset, dataset_id)
    if not ds:
        raise ValueError(f"Dataset {dataset_id} not found")

    # Get total count
    count_q = await db.execute(
        select(func.count()).where(DatasetRow.dataset_id == dataset_id)
    )
    total = count_q.scalar() or 0

    # Sample rows - get first batch + some from the middle
    half = sample_size // 2
    result = await db.execute(
        select(DatasetRow)
        .where(DatasetRow.dataset_id == dataset_id)
        .order_by(DatasetRow.row_index)
        .limit(half)
    )
    first_rows = result.scalars().all()

    # If dataset is large, also sample from the middle
    middle_rows = []
    if total > sample_size:
        mid_offset = total // 2
        result2 = await db.execute(
            select(DatasetRow)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(mid_offset)
            .limit(sample_size - half)
        )
        middle_rows = result2.scalars().all()
    else:
        # Small dataset: just continue sequentially after the first batch.
        result2 = await db.execute(
            select(DatasetRow)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(half)
            .limit(sample_size - half)
        )
        middle_rows = result2.scalars().all()

    all_rows = first_rows + middle_rows
    # Non-dict row payloads are replaced with {} so rendering never fails.
    row_dicts = [r.data if isinstance(r.data, dict) else {} for r in all_rows]

    # Prefer the stored column schema; fall back to the first row's keys.
    columns = list(ds.column_schema.keys()) if ds.column_schema else []
    if not columns and row_dicts:
        columns = list(row_dicts[0].keys())

    rows_text = _rows_to_text(row_dicts, columns)

    metadata = {
        "name": ds.name,
        "filename": ds.filename,
        "source_tool": ds.source_tool,
        # getattr guard — presumably tolerates schemas predating the artifact_type column; confirm
        "artifact_type": getattr(ds, "artifact_type", None),
        "row_count": total,
        "columns": columns[:30],
        "sample_rows_shown": len(all_rows),
    }
    return metadata, rows_text, total
|
||||
|
||||
|
||||
async def query_dataset(
    dataset_id: str,
    question: str,
    mode: str = "quick",
) -> str:
    """Non-streaming query: returns full answer text.

    Args:
        dataset_id: Dataset whose sampled rows form the LLM context.
        question: Analyst's natural-language question.
        mode: "deep" routes to the heavy model on Wile with a 4096-token
            budget; any other value uses the fast model on Roadrunner
            with 2048 tokens.

    Raises:
        ValueError: If the dataset does not exist (from _load_dataset_context).
    """
    # Imported lazily — presumably to avoid an import cycle; confirm before hoisting.
    from app.agents.providers_v2 import OllamaProvider, Node

    async with async_session_factory() as db:
        meta, rows_text, total = await _load_dataset_context(dataset_id, db)

        prompt = _build_prompt(question, meta, rows_text, total)

        if mode == "deep":
            provider = OllamaProvider(settings.DEFAULT_HEAVY_MODEL, Node.WILE)
            max_tokens = 4096
        else:
            provider = OllamaProvider(settings.DEFAULT_FAST_MODEL, Node.ROADRUNNER)
            max_tokens = 2048

        result = await provider.generate(
            prompt,
            system=QUERY_SYSTEM_PROMPT,
            max_tokens=max_tokens,
            temperature=0.3,
        )
        return result.get("response", "No response generated.")
|
||||
|
||||
|
||||
async def query_dataset_stream(
    dataset_id: str,
    question: str,
    mode: str = "quick",
) -> AsyncIterator[str]:
    """Streaming query: yields SSE-formatted events.

    Yields "data: {json}" frames in sequence: status, metadata, status,
    one token frame per generated token, an optional error frame, and a
    final "done" frame carrying token count, elapsed_ms, model and node.
    """
    # Imported lazily — presumably to avoid an import cycle; confirm before hoisting.
    from app.agents.providers_v2 import OllamaProvider, Node

    start = time.monotonic()

    # Send initial metadata event
    yield f"data: {json.dumps({'type': 'status', 'message': 'Loading dataset...'})}\n\n"

    async with async_session_factory() as db:
        meta, rows_text, total = await _load_dataset_context(dataset_id, db)

        yield f"data: {json.dumps({'type': 'metadata', 'dataset': meta})}\n\n"
        yield f"data: {json.dumps({'type': 'status', 'message': f'Querying LLM ({mode} mode)...'})}\n\n"

        prompt = _build_prompt(question, meta, rows_text, total)

        # "deep" -> heavy model on Wile; anything else -> fast model on Roadrunner.
        if mode == "deep":
            provider = OllamaProvider(settings.DEFAULT_HEAVY_MODEL, Node.WILE)
            max_tokens = 4096
            model_name = settings.DEFAULT_HEAVY_MODEL
            node_name = "wile"
        else:
            provider = OllamaProvider(settings.DEFAULT_FAST_MODEL, Node.ROADRUNNER)
            max_tokens = 2048
            model_name = settings.DEFAULT_FAST_MODEL
            node_name = "roadrunner"

        # Stream tokens
        token_count = 0
        try:
            async for token in provider.generate_stream(
                prompt,
                system=QUERY_SYSTEM_PROMPT,
                max_tokens=max_tokens,
                temperature=0.3,
            ):
                token_count += 1
                yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"

        elapsed_ms = int((time.monotonic() - start) * 1000)
        yield f"data: {json.dumps({'type': 'done', 'tokens': token_count, 'elapsed_ms': elapsed_ms, 'model': model_name, 'node': node_name})}\n\n"
|
||||
|
||||
|
||||
def _build_prompt(question: str, meta: dict, rows_text: str, total: int) -> str:
|
||||
"""Construct the full prompt with data context."""
|
||||
parts = [
|
||||
f"## Dataset: {meta['name']}",
|
||||
f"- Source: {meta.get('source_tool', 'unknown')}",
|
||||
f"- Artifact type: {meta.get('artifact_type', 'unknown')}",
|
||||
f"- Total rows: {total}",
|
||||
f"- Columns: {', '.join(meta.get('columns', []))}",
|
||||
f"- Showing {meta['sample_rows_shown']} sample rows below",
|
||||
"",
|
||||
"## Sample Data",
|
||||
"```",
|
||||
rows_text,
|
||||
"```",
|
||||
"",
|
||||
f"## Analyst Question",
|
||||
question,
|
||||
]
|
||||
return "\n".join(parts)
|
||||
290
backend/app/services/host_inventory.py
Normal file
290
backend/app/services/host_inventory.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""Host Inventory Service - builds a deduplicated host-centric network view.
|
||||
|
||||
Scans all datasets in a hunt to identify unique hosts, their IPs, OS,
|
||||
logged-in users, and network connections between them.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Column-name patterns (Velociraptor + generic forensic tools) ---
|
||||
|
||||
_HOST_ID_RE = re.compile(
|
||||
r'^(client_?id|clientid|agent_?id|endpoint_?id|host_?id|sensor_?id)$', re.I)
|
||||
_FQDN_RE = re.compile(
|
||||
r'^(fqdn|fully_?qualified|computer_?name|hostname|host_?name|host|'
|
||||
r'system_?name|machine_?name|nodename|workstation)$', re.I)
|
||||
_USERNAME_RE = re.compile(
|
||||
r'^(user|username|user_?name|logon_?name|account_?name|owner|'
|
||||
r'logged_?in_?user|sam_?account_?name|samaccountname)$', re.I)
|
||||
_LOCAL_IP_RE = re.compile(
|
||||
r'^(laddr\.?ip|laddr|local_?addr(ess)?|src_?ip|source_?ip)$', re.I)
|
||||
_REMOTE_IP_RE = re.compile(
|
||||
r'^(raddr\.?ip|raddr|remote_?addr(ess)?|dst_?ip|dest_?ip)$', re.I)
|
||||
_REMOTE_PORT_RE = re.compile(
|
||||
r'^(raddr\.?port|rport|remote_?port|dst_?port|dest_?port)$', re.I)
|
||||
_OS_RE = re.compile(
|
||||
r'^(os|operating_?system|os_?version|os_?name|platform|os_?type|os_?build)$', re.I)
|
||||
_IP_VALID_RE = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$')
|
||||
|
||||
_IGNORE_IPS = frozenset({
|
||||
'0.0.0.0', '::', '::1', '127.0.0.1', '', '-', '*', 'None', 'null',
|
||||
})
|
||||
_SYSTEM_DOMAINS = frozenset({
|
||||
'NT AUTHORITY', 'NT SERVICE', 'FONT DRIVER HOST', 'WINDOW MANAGER',
|
||||
})
|
||||
_SYSTEM_USERS = frozenset({
|
||||
'SYSTEM', 'LOCAL SERVICE', 'NETWORK SERVICE',
|
||||
'UMFD-0', 'UMFD-1', 'DWM-1', 'DWM-2', 'DWM-3',
|
||||
})
|
||||
|
||||
|
||||
def _is_valid_ip(v: str) -> bool:
    """Return True if *v* is a usable dotted-quad IPv4 address.

    Rejects empty/placeholder values and the wildcard/loopback addresses in
    _IGNORE_IPS, then validates the dotted-quad form. The shape regex alone
    accepts octets up to 999 (e.g. '999.1.1.1'), so each octet is also
    range-checked against 0-255.
    """
    if not v or v in _IGNORE_IPS:
        return False
    if not _IP_VALID_RE.match(v):
        return False
    # Regex guarantees exactly four 1-3 digit groups; enforce the numeric range.
    return all(int(octet) <= 255 for octet in v.split('.'))
|
||||
|
||||
|
||||
def _clean(v: Any) -> str:
|
||||
s = str(v or '').strip()
|
||||
return s if s and s not in ('-', 'None', 'null', '') else ''
|
||||
|
||||
|
||||
_SYSTEM_USER_RE = re.compile(
|
||||
r'^(SYSTEM|LOCAL SERVICE|NETWORK SERVICE|DWM-\d+|UMFD-\d+)$', re.I)
|
||||
|
||||
|
||||
def _extract_username(raw: str) -> str:
    """Normalize a username: strip DOMAIN\\ prefixes and drop built-in service accounts."""
    if not raw:
        return ''
    candidate = raw.strip()
    if '\\' in candidate:
        domain, _, candidate = candidate.rpartition('\\')
        candidate = candidate.strip()
        # Accounts under NT AUTHORITY / NT SERVICE etc. are never real users.
        if domain.strip().upper() in _SYSTEM_DOMAINS and (
                not candidate or _SYSTEM_USER_RE.match(candidate)):
            return ''
    # Well-known local service accounts are filtered regardless of domain.
    return '' if _SYSTEM_USER_RE.match(candidate) else (candidate or '')
|
||||
|
||||
|
||||
def _infer_os(fqdn: str) -> str:
|
||||
u = fqdn.upper()
|
||||
if 'W10-' in u or 'WIN10' in u:
|
||||
return 'Windows 10'
|
||||
if 'W11-' in u or 'WIN11' in u:
|
||||
return 'Windows 11'
|
||||
if 'W7-' in u or 'WIN7' in u:
|
||||
return 'Windows 7'
|
||||
if 'SRV' in u or 'SERVER' in u or 'DC-' in u:
|
||||
return 'Windows Server'
|
||||
if any(k in u for k in ('LINUX', 'UBUNTU', 'CENTOS', 'RHEL', 'DEBIAN')):
|
||||
return 'Linux'
|
||||
if 'MAC' in u or 'DARWIN' in u:
|
||||
return 'macOS'
|
||||
return 'Windows'
|
||||
|
||||
|
||||
def _identify_columns(ds: Dataset) -> dict:
    """Map a dataset's columns onto host-inventory roles.

    Returns a dict keyed host_id/fqdn/username/local_ip/remote_ip/remote_port/os,
    each value a list of raw column names matching that role. A column may land
    in several buckets; local_ip and remote_ip are mutually exclusive per column
    (local wins). Matching uses both the raw column name and its canonical name
    from the dataset's normalization map.
    """
    norm = ds.normalized_columns or {}
    schema = ds.column_schema or {}
    # Prefer the declared schema's columns; fall back to the normalization map's keys.
    raw_cols = list(schema.keys()) if schema else list(norm.keys())

    result = {
        'host_id': [], 'fqdn': [], 'username': [],
        'local_ip': [], 'remote_ip': [], 'remote_port': [], 'os': [],
    }

    for col in raw_cols:
        canonical = (norm.get(col) or '').lower()
        lower = col.lower()

        # A column normalized to 'hostname' but not literally named hostname/host
        # is treated as an opaque host identifier (e.g. an agent/client id).
        if _HOST_ID_RE.match(lower) or (canonical == 'hostname' and lower not in ('hostname', 'host_name', 'host')):
            result['host_id'].append(col)

        if _FQDN_RE.match(lower) or canonical == 'fqdn':
            result['fqdn'].append(col)

        if _USERNAME_RE.match(lower) or canonical in ('username', 'user'):
            result['username'].append(col)

        # local vs remote IP are mutually exclusive for a single column.
        if _LOCAL_IP_RE.match(lower):
            result['local_ip'].append(col)
        elif _REMOTE_IP_RE.match(lower):
            result['remote_ip'].append(col)

        if _REMOTE_PORT_RE.match(lower):
            result['remote_port'].append(col)

        if _OS_RE.match(lower) or canonical == 'os':
            result['os'].append(col)

    return result
|
||||
|
||||
|
||||
async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
    """Build a deduplicated host inventory from all datasets in a hunt.

    Returns dict with 'hosts', 'connections', and 'stats'.
    Each host has: id, hostname, fqdn, client_id, ips, os, users, datasets, row_count.

    Hosts are keyed by FQDN when available, otherwise by the opaque client id.
    Connections are undirected, deduplicated edges built from remote-IP columns,
    with the target resolved back to a known host when its local IP was seen.
    """
    ds_result = await db.execute(
        select(Dataset).where(Dataset.hunt_id == hunt_id)
    )
    all_datasets = ds_result.scalars().all()

    if not all_datasets:
        return {"hosts": [], "connections": [], "stats": {
            "total_hosts": 0, "total_datasets_scanned": 0,
            "total_rows_scanned": 0,
        }}

    hosts: dict[str, dict] = {}  # fqdn -> host record
    ip_to_host: dict[str, str] = {}  # local-ip -> fqdn
    # (source host key, remote ip, remote port) -> occurrence count
    connections: dict[tuple, int] = defaultdict(int)
    total_rows = 0
    ds_with_hosts = 0

    for ds in all_datasets:
        cols = _identify_columns(ds)
        # Skip datasets with no way to attribute rows to a host.
        if not cols['fqdn'] and not cols['host_id']:
            continue
        ds_with_hosts += 1

        # Page through rows to bound memory on large datasets.
        batch_size = 5000
        offset = 0
        while True:
            rr = await db.execute(
                select(DatasetRow)
                .where(DatasetRow.dataset_id == ds.id)
                .order_by(DatasetRow.row_index)
                .offset(offset).limit(batch_size)
            )
            rows = rr.scalars().all()
            if not rows:
                break

            for ro in rows:
                data = ro.data or {}
                total_rows += 1

                # First non-empty value among the candidate columns wins.
                fqdn = ''
                for c in cols['fqdn']:
                    fqdn = _clean(data.get(c))
                    if fqdn:
                        break
                client_id = ''
                for c in cols['host_id']:
                    client_id = _clean(data.get(c))
                    if client_id:
                        break

                if not fqdn and not client_id:
                    continue

                # Prefer FQDN as the dedup key; fall back to the client id.
                host_key = fqdn or client_id

                if host_key not in hosts:
                    # Short hostname = first FQDN label (e.g. 'pc1' from 'pc1.corp').
                    short = fqdn.split('.')[0] if fqdn and '.' in fqdn else fqdn
                    hosts[host_key] = {
                        'id': host_key,
                        'hostname': short or client_id,
                        'fqdn': fqdn,
                        'client_id': client_id,
                        'ips': set(),
                        'os': '',
                        'users': set(),
                        'datasets': set(),
                        'row_count': 0,
                    }

                h = hosts[host_key]
                h['datasets'].add(ds.name)
                h['row_count'] += 1
                # Backfill the client id if this row has one and the record doesn't.
                if client_id and not h['client_id']:
                    h['client_id'] = client_id

                for c in cols['username']:
                    u = _extract_username(_clean(data.get(c)))
                    if u:
                        h['users'].add(u)

                for c in cols['local_ip']:
                    ip = _clean(data.get(c))
                    if _is_valid_ip(ip):
                        h['ips'].add(ip)
                        # Remember the owner so remote IPs can resolve to hosts later.
                        ip_to_host[ip] = host_key

                # First non-empty OS value wins.
                for c in cols['os']:
                    ov = _clean(data.get(c))
                    if ov and not h['os']:
                        h['os'] = ov

                for c in cols['remote_ip']:
                    rip = _clean(data.get(c))
                    if _is_valid_ip(rip):
                        rport = ''
                        for pc in cols['remote_port']:
                            rport = _clean(data.get(pc))
                            if rport:
                                break
                        connections[(host_key, rip, rport)] += 1

            offset += batch_size
            if len(rows) < batch_size:
                break

    # Post-process hosts: fill OS from hostname heuristics; sets -> sorted lists
    # so the result is JSON-serializable and deterministic.
    for h in hosts.values():
        if not h['os'] and h['fqdn']:
            h['os'] = _infer_os(h['fqdn'])
        h['ips'] = sorted(h['ips'])
        h['users'] = sorted(h['users'])
        h['datasets'] = sorted(h['datasets'])

    # Build connections, resolving IPs to host keys
    conn_list = []
    seen = set()  # undirected endpoint pairs already emitted
    for (src, dst_ip, dst_port), cnt in connections.items():
        if dst_ip in _IGNORE_IPS:
            continue
        dst_host = ip_to_host.get(dst_ip, '')
        # Drop self-connections.
        if dst_host == src:
            continue
        # Deduplicate the two directions of the same edge.
        key = tuple(sorted([src, dst_host or dst_ip]))
        if key in seen:
            continue
        seen.add(key)
        conn_list.append({
            'source': src,
            'target': dst_host or dst_ip,
            'target_ip': dst_ip,
            'port': dst_port,
            'count': cnt,
        })

    # Busiest hosts first.
    host_list = sorted(hosts.values(), key=lambda x: x['row_count'], reverse=True)

    return {
        "hosts": host_list,
        "connections": conn_list,
        "stats": {
            "total_hosts": len(host_list),
            "total_datasets_scanned": len(all_datasets),
            "datasets_with_hosts": ds_with_hosts,
            "total_rows_scanned": total_rows,
            "hosts_with_ips": sum(1 for h in host_list if h['ips']),
            "hosts_with_users": sum(1 for h in host_list if h['users']),
        },
    }
|
||||
198
backend/app/services/host_profiler.py
Normal file
198
backend/app/services/host_profiler.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Host profiler - per-host deep threat analysis via Wile heavy models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config import settings
|
||||
from app.db.engine import async_session
|
||||
from app.db.models import Dataset, DatasetRow, HostProfile, TriageResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
|
||||
WILE_URL = f"{settings.wile_url}/api/generate"
|
||||
|
||||
|
||||
async def _get_triage_summary(db, dataset_id: str) -> str:
    """Summarize a dataset's top triage findings (risk >= 3.0) as bullet lines."""
    result = await db.execute(
        select(TriageResult)
        .where(TriageResult.dataset_id == dataset_id)
        .where(TriageResult.risk_score >= 3.0)
        .order_by(TriageResult.risk_score.desc())
        .limit(10)
    )
    triages = result.scalars().all()
    if not triages:
        return "No significant triage findings."
    return "\n".join(
        f"- Rows {t.row_start}-{t.row_end}: risk={t.risk_score:.1f} "
        f"verdict={t.verdict} findings={json.dumps(t.findings, default=str)[:300]}"
        for t in triages
    )
|
||||
|
||||
|
||||
async def _collect_host_data(db, hunt_id: str, hostname: str, fqdn: str | None = None) -> dict:
    """Gather per-host rows and triage context across every dataset in a hunt.

    Scans up to 500 rows per dataset and keeps rows whose host column contains
    *hostname* (or *fqdn*) as a case-insensitive substring. Returns a dict:
      artifacts      - {artifact_type: up to 50 matching row dicts}
      triage_summary - concatenated triage findings per matching artifact
      artifact_count - total matching rows kept across artifacts
    """
    result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
    datasets = result.scalars().all()

    host_data: dict[str, list[dict]] = {}
    triage_parts: list[str] = []

    for ds in datasets:
        artifact_type = getattr(ds, "artifact_type", None) or "Unknown"
        rows_result = await db.execute(
            select(DatasetRow).where(DatasetRow.dataset_id == ds.id).limit(500)
        )
        rows = rows_result.scalars().all()

        matching = []
        for r in rows:
            # Prefer the normalized view of the row when present.
            data = r.normalized_data or r.data
            # Common host-identifier columns, in priority order.
            row_host = (
                data.get("hostname", "") or data.get("Fqdn", "")
                or data.get("ClientId", "") or data.get("client_id", "")
            )
            # Substring match tolerates short names vs full FQDNs.
            if hostname.lower() in str(row_host).lower():
                matching.append(data)
            elif fqdn and fqdn.lower() in str(row_host).lower():
                matching.append(data)

        if matching:
            host_data[artifact_type] = matching[:50]
            triage_info = await _get_triage_summary(db, ds.id)
            triage_parts.append(f"\n### {artifact_type} ({len(matching)} rows)\n{triage_info}")

    return {
        "artifacts": host_data,
        "triage_summary": "\n".join(triage_parts) or "No triage data.",
        # Note: counts the capped (<=50 per artifact) rows actually kept.
        "artifact_count": sum(len(v) for v in host_data.values()),
    }
|
||||
|
||||
|
||||
async def profile_host(
    hunt_id: str, hostname: str, fqdn: str | None = None, client_id: str | None = None,
) -> None:
    """Deep-profile one host with the heavy LLM and persist a HostProfile row.

    Collects the host's rows and prior triage findings, asks the Wile-hosted
    heavy model for a JSON threat assessment, parses it, and stores the result.
    Hosts with no matching data are skipped. On any failure an error
    placeholder profile (risk_level="unknown") is stored so the failure is
    visible in query results.
    """
    logger.info("Profiling host %s in hunt %s", hostname, hunt_id)

    async with async_session() as db:
        host_data = await _collect_host_data(db, hunt_id, hostname, fqdn)
        if host_data["artifact_count"] == 0:
            logger.info("No data found for host %s, skipping", hostname)
            return

        system_prompt = (
            "You are a senior threat hunting analyst performing deep host analysis.\n"
            "You receive consolidated forensic artifacts and prior triage results for a single host.\n\n"
            "Provide a comprehensive host threat profile as JSON:\n"
            "- risk_score: 0.0 (clean) to 10.0 (actively compromised)\n"
            "- risk_level: low/medium/high/critical\n"
            "- suspicious_findings: list of specific concerns\n"
            "- mitre_techniques: list of MITRE ATT&CK technique IDs\n"
            "- timeline_summary: brief timeline of suspicious activity\n"
            "- analysis: detailed narrative assessment\n\n"
            "Consider: cross-artifact correlation, attack patterns, LOLBins, anomalies.\n"
            "Respond with valid JSON only."
        )

        # Bound prompt size: 20 rows per artifact, 150 chars per value, empty values dropped.
        artifact_summary = {}
        for art_type, rows in host_data["artifacts"].items():
            artifact_summary[art_type] = [
                {k: str(v)[:150] for k, v in row.items() if v} for row in rows[:20]
            ]

        prompt = (
            f"Host: {hostname}\nFQDN: {fqdn or 'unknown'}\n\n"
            f"## Prior Triage Results\n{host_data['triage_summary']}\n\n"
            f"## Artifact Data ({host_data['artifact_count']} total rows)\n"
            f"{json.dumps(artifact_summary, indent=1, default=str)[:8000]}\n\n"
            "Provide your comprehensive host threat profile as JSON."
        )

        try:
            # Heavy models are slow; allow up to 5 minutes per host.
            async with httpx.AsyncClient(timeout=300.0) as client:
                resp = await client.post(
                    WILE_URL,
                    json={
                        "model": HEAVY_MODEL,
                        "prompt": prompt,
                        "system": system_prompt,
                        "stream": False,
                        "options": {"temperature": 0.3, "num_predict": 4096},
                    },
                )
                resp.raise_for_status()
                llm_text = resp.json().get("response", "")

            # Reuse the triage service's LLM-response parser.
            from app.services.triage import _parse_llm_response
            parsed = _parse_llm_response(llm_text)

            profile = HostProfile(
                hunt_id=hunt_id,
                hostname=hostname,
                fqdn=fqdn,
                client_id=client_id,
                risk_score=float(parsed.get("risk_score", 0.0)),
                risk_level=parsed.get("risk_level", "low"),
                # Row counts per artifact type (post-cap), not the rows themselves.
                artifact_summary={a: len(r) for a, r in host_data["artifacts"].items()},
                timeline_summary=parsed.get("timeline_summary", ""),
                suspicious_findings=parsed.get("suspicious_findings", []),
                mitre_techniques=parsed.get("mitre_techniques", []),
                # Fall back to truncated raw LLM text when no structured analysis field.
                llm_analysis=parsed.get("analysis", llm_text[:5000]),
                model_used=HEAVY_MODEL,
                node_used="wile",
            )
            db.add(profile)
            await db.commit()
            logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level)

        except Exception as e:
            # Persist an error placeholder so the failed host still shows up.
            logger.error("Failed to profile host %s: %s", hostname, e)
            profile = HostProfile(
                hunt_id=hunt_id, hostname=hostname, fqdn=fqdn,
                risk_score=0.0, risk_level="unknown",
                llm_analysis=f"Error: {e}",
                model_used=HEAVY_MODEL, node_used="wile",
            )
            db.add(profile)
            await db.commit()
|
||||
|
||||
|
||||
async def profile_all_hosts(hunt_id: str) -> None:
    """Discover every unique host in a hunt, then profile each concurrently.

    Discovery samples up to 2000 rows per dataset and keys on common host
    columns. Profiling runs profile_host() per host with concurrency bounded
    by settings.HOST_PROFILE_CONCURRENCY; individual failures do not abort
    the batch (gather with return_exceptions=True).
    """
    logger.info("Starting host profiling for hunt %s", hunt_id)

    async with async_session() as db:
        result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
        datasets = result.scalars().all()

        # hostname -> fqdn (or None); first sighting wins.
        hostnames: dict[str, str | None] = {}
        for ds in datasets:
            rows_result = await db.execute(
                select(DatasetRow).where(DatasetRow.dataset_id == ds.id).limit(2000)
            )
            for r in rows_result.scalars().all():
                data = r.normalized_data or r.data
                host = data.get("hostname") or data.get("Fqdn") or data.get("Hostname")
                if host and str(host).strip():
                    h = str(host).strip()
                    if h not in hostnames:
                        hostnames[h] = data.get("fqdn") or data.get("Fqdn")

    logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id)

    # Cap concurrent heavy-model calls.
    semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY)

    async def _bounded(hostname: str, fqdn: str | None):
        # profile_host opens its own DB session, so none is passed here.
        async with semaphore:
            await profile_host(hunt_id, hostname, fqdn)

    tasks = [_bounded(h, f) for h, f in hostnames.items()]
    await asyncio.gather(*tasks, return_exceptions=True)
    logger.info("Host profiling complete for hunt %s (%d hosts)", hunt_id, len(hostnames))
|
||||
210
backend/app/services/ioc_extractor.py
Normal file
210
backend/app/services/ioc_extractor.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""IOC extraction service extract indicators of compromise from dataset rows.
|
||||
|
||||
Identifies: IPv4/IPv6 addresses, domain names, MD5/SHA1/SHA256 hashes,
|
||||
email addresses, URLs, and file paths that look suspicious.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Patterns
|
||||
|
||||
_IPV4 = re.compile(
|
||||
r'\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b'
|
||||
)
|
||||
_IPV6 = re.compile(r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b')
|
||||
_DOMAIN = re.compile(
|
||||
r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)'
|
||||
r'+(?:com|net|org|io|info|biz|co|us|uk|de|ru|cn|cc|tk|xyz|top|'
|
||||
r'online|site|club|win|work|download|stream|gdn|bid|review|racing|'
|
||||
r'loan|date|faith|accountant|cricket|science|trade|party|men)\b',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_MD5 = re.compile(r'\b[0-9a-fA-F]{32}\b')
|
||||
_SHA1 = re.compile(r'\b[0-9a-fA-F]{40}\b')
|
||||
_SHA256 = re.compile(r'\b[0-9a-fA-F]{64}\b')
|
||||
_EMAIL = re.compile(r'\b[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}\b')
|
||||
_URL = re.compile(r'https?://[^\s<>"\']+', re.IGNORECASE)
|
||||
|
||||
# Private / reserved IPs to skip
|
||||
_PRIVATE_NETS = re.compile(
|
||||
r'^(10\.|172\.(1[6-9]|2\d|3[01])\.|192\.168\.|127\.|0\.|255\.)'
|
||||
)
|
||||
|
||||
PATTERNS = {
|
||||
'ipv4': _IPV4,
|
||||
'ipv6': _IPV6,
|
||||
'domain': _DOMAIN,
|
||||
'md5': _MD5,
|
||||
'sha1': _SHA1,
|
||||
'sha256': _SHA256,
|
||||
'email': _EMAIL,
|
||||
'url': _URL,
|
||||
}
|
||||
|
||||
|
||||
def _is_private_ip(ip: str) -> bool:
    """True when *ip* starts with a private/loopback/reserved prefix (see _PRIVATE_NETS)."""
    return _PRIVATE_NETS.match(ip) is not None
|
||||
|
||||
|
||||
def extract_iocs_from_text(text: str, skip_private: bool = True) -> dict[str, set[str]]:
    """Extract all IOC types from a block of text.

    Runs every regex in PATTERNS over *text* and buckets the matches.
    Returns {ioc_type: set of values}. Values are lowercased except URLs,
    which keep their original case. Private/reserved IPv4 addresses are
    dropped when skip_private is True.
    """
    result: dict[str, set[str]] = defaultdict(set)
    for ioc_type, pattern in PATTERNS.items():
        for match in pattern.findall(text):
            # URLs keep case (paths/queries can be case-sensitive); others normalize.
            val = match.strip().lower() if ioc_type != 'url' else match.strip()
            # Filter private IPs
            if ioc_type == 'ipv4' and skip_private and _is_private_ip(val):
                continue
            # NOTE(review): no extra hex-length filtering is done here; the \b
            # anchors in the MD5/SHA1/SHA256 patterns prevent sub-matches
            # inside longer hex runs.
            result[ioc_type].add(val)
    return result
|
||||
|
||||
|
||||
async def extract_iocs_from_dataset(
    dataset_id: str,
    db: AsyncSession,
    max_rows: int = 5000,
    skip_private: bool = True,
) -> dict[str, list[str]]:
    """Extract IOCs from all rows of a dataset.

    Scans at most *max_rows* rows, paged in batches, flattening each row's
    values into a single string and running every IOC pattern over it.

    Returns {ioc_type: [sorted unique values]} with empty buckets dropped.
    """
    all_iocs: dict[str, set[str]] = defaultdict(set)
    offset = 0
    batch_size = 500

    while offset < max_rows:
        # Cap the final page so we never scan past max_rows (previously a
        # non-multiple max_rows could overshoot by up to a full batch).
        page = min(batch_size, max_rows - offset)
        result = await db.execute(
            select(DatasetRow.data)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(offset)
            .limit(page)
        )
        rows = result.scalars().all()
        if not rows:
            break

        for data in rows:
            # Flatten all values to a single string for scanning
            text = ' '.join(str(v) for v in data.values()) if isinstance(data, dict) else str(data)
            batch_iocs = extract_iocs_from_text(text, skip_private)
            for ioc_type, values in batch_iocs.items():
                all_iocs[ioc_type].update(values)

        # Advance by the rows actually read (last page may be short).
        offset += len(rows)

    # Convert sets to sorted lists
    return {k: sorted(v) for k, v in all_iocs.items() if v}
|
||||
|
||||
|
||||
async def extract_host_groups(
    hunt_id: str,
    db: AsyncSession,
) -> list[dict]:
    """Group all data by hostname across datasets in a hunt.

    Returns a list of host group dicts with dataset count, total rows,
    artifact types, and time range. The host column for each dataset is
    inferred once from a 5-row sample and assumed consistent for the rest
    of that dataset.
    """
    # Get all datasets for this hunt
    result = await db.execute(
        select(Dataset).where(Dataset.hunt_id == hunt_id)
    )
    ds_list = result.scalars().all()
    if not ds_list:
        return []

    # Known host columns (check normalized data first, then raw)
    HOST_COLS = [
        'hostname', 'host', 'computer_name', 'computername', 'system',
        'machine', 'device_name', 'devicename', 'endpoint',
        'ClientId', 'Fqdn', 'client_id', 'fqdn',
    ]

    hosts: dict[str, dict] = {}

    for ds in ds_list:
        # Sample first few rows to find host column
        sample_result = await db.execute(
            select(DatasetRow.data, DatasetRow.normalized_data)
            .where(DatasetRow.dataset_id == ds.id)
            .limit(5)
        )
        samples = sample_result.all()
        if not samples:
            continue

        # Find which host column exists (first non-empty match wins).
        host_col = None
        for row_data, norm_data in samples:
            check = norm_data if norm_data else row_data
            if not isinstance(check, dict):
                continue
            for col in HOST_COLS:
                if col in check and check[col]:
                    host_col = col
                    break
            if host_col:
                break

        if not host_col:
            continue

        # Count rows per host in this dataset.
        # NOTE(review): loads the entire dataset into memory (no paging) —
        # fine for modest datasets, worth paging for very large ones.
        all_rows_result = await db.execute(
            select(DatasetRow.data, DatasetRow.normalized_data)
            .where(DatasetRow.dataset_id == ds.id)
        )
        all_rows = all_rows_result.all()
        for row_data, norm_data in all_rows:
            check = norm_data if norm_data else row_data
            if not isinstance(check, dict):
                continue
            host_val = check.get(host_col, '')
            # Only non-empty string values are grouped.
            if not host_val or not isinstance(host_val, str):
                continue
            host_val = host_val.strip()
            if not host_val:
                continue

            if host_val not in hosts:
                hosts[host_val] = {
                    'hostname': host_val,
                    'dataset_ids': set(),
                    'total_rows': 0,
                    'artifact_types': set(),
                    'first_seen': None,
                    'last_seen': None,
                }
            hosts[host_val]['dataset_ids'].add(ds.id)
            hosts[host_val]['total_rows'] += 1
            if ds.artifact_type:
                hosts[host_val]['artifact_types'].add(ds.artifact_type)

    # Convert to output format, busiest hosts first.
    result_list = []
    for h in sorted(hosts.values(), key=lambda x: x['total_rows'], reverse=True):
        result_list.append({
            'hostname': h['hostname'],
            'dataset_count': len(h['dataset_ids']),
            'total_rows': h['total_rows'],
            'artifact_types': sorted(h['artifact_types']),
            'first_seen': None,  # TODO: extract from timestamp columns
            'last_seen': None,
            'risk_score': None,  # TODO: link to host profiles
        })

    return result_list
|
||||
316
backend/app/services/job_queue.py
Normal file
316
backend/app/services/job_queue.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""Async job queue for background AI tasks.
|
||||
|
||||
Manages triage, profiling, report generation, anomaly detection,
|
||||
and data queries as trackable jobs with status, progress, and
|
||||
cancellation support.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Coroutine, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
    """Lifecycle states of a background job (str-valued for JSON responses)."""
    QUEUED = "queued"        # submitted, waiting for a worker
    RUNNING = "running"      # a worker is executing the handler
    COMPLETED = "completed"  # handler returned normally
    FAILED = "failed"        # handler raised or no handler was registered
    CANCELLED = "cancelled"  # cancel() was requested before completion
|
||||
|
||||
|
||||
class JobType(str, Enum):
    """Kinds of background AI tasks the queue can run (str-valued for JSON)."""
    TRIAGE = "triage"              # dataset triage
    HOST_PROFILE = "host_profile"  # per-host profiling
    REPORT = "report"              # report generation
    ANOMALY = "anomaly"            # anomaly detection
    QUERY = "query"                # data query
|
||||
|
||||
|
||||
@dataclass
class Job:
    """A single tracked background task.

    Mutable state (status, progress, message, result, error) is updated by
    the worker loop and the job's handler; clients poll via to_dict().
    """
    id: str  # UUID string assigned at submit time
    job_type: JobType
    status: JobStatus = JobStatus.QUEUED
    progress: float = 0.0  # 0-100
    message: str = ""  # human-readable progress/status text
    result: Any = None  # handler return value once completed
    error: str | None = None  # error text when status == FAILED
    created_at: float = field(default_factory=time.time)  # epoch seconds
    started_at: float | None = None
    completed_at: float | None = None
    params: dict = field(default_factory=dict)  # handler inputs from submit()
    # Cooperative cancellation flag; handlers poll is_cancelled.
    _cancel_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)

    @property
    def elapsed_ms(self) -> int:
        """Milliseconds from start (or creation, if never started) to completion (or now)."""
        end = self.completed_at or time.time()
        start = self.started_at or self.created_at
        return int((end - start) * 1000)

    def to_dict(self) -> dict:
        """JSON-safe snapshot for API responses (excludes result and the cancel event)."""
        return {
            "id": self.id,
            "job_type": self.job_type.value,
            "status": self.status.value,
            "progress": round(self.progress, 1),
            "message": self.message,
            "error": self.error,
            "created_at": self.created_at,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "elapsed_ms": self.elapsed_ms,
            "params": self.params,
        }

    @property
    def is_cancelled(self) -> bool:
        """True once cancel() has been requested."""
        return self._cancel_event.is_set()

    def cancel(self):
        """Mark the job cancelled immediately; a running handler must notice via is_cancelled."""
        self._cancel_event.set()
        self.status = JobStatus.CANCELLED
        self.completed_at = time.time()
        self.message = "Cancelled by user"
|
||||
|
||||
|
||||
class JobQueue:
|
||||
"""In-memory async job queue with concurrency control.
|
||||
|
||||
Jobs are tracked by ID and can be listed, polled, or cancelled.
|
||||
A configurable number of workers process jobs from the queue.
|
||||
"""
|
||||
|
||||
def __init__(self, max_workers: int = 3):
|
||||
self._jobs: dict[str, Job] = {}
|
||||
self._queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
self._max_workers = max_workers
|
||||
self._workers: list[asyncio.Task] = []
|
||||
self._handlers: dict[JobType, Callable] = {}
|
||||
self._started = False
|
||||
|
||||
def register_handler(
|
||||
self,
|
||||
job_type: JobType,
|
||||
handler: Callable[[Job], Coroutine],
|
||||
):
|
||||
"""Register an async handler for a job type.
|
||||
|
||||
Handler signature: async def handler(job: Job) -> Any
|
||||
The handler can update job.progress and job.message during execution.
|
||||
It should check job.is_cancelled periodically and return early.
|
||||
"""
|
||||
self._handlers[job_type] = handler
|
||||
logger.info(f"Registered handler for {job_type.value}")
|
||||
|
||||
async def start(self):
|
||||
"""Start worker tasks."""
|
||||
if self._started:
|
||||
return
|
||||
self._started = True
|
||||
for i in range(self._max_workers):
|
||||
task = asyncio.create_task(self._worker(i))
|
||||
self._workers.append(task)
|
||||
logger.info(f"Job queue started with {self._max_workers} workers")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop all workers."""
|
||||
self._started = False
|
||||
for w in self._workers:
|
||||
w.cancel()
|
||||
await asyncio.gather(*self._workers, return_exceptions=True)
|
||||
self._workers.clear()
|
||||
logger.info("Job queue stopped")
|
||||
|
||||
def submit(self, job_type: JobType, **params) -> Job:
|
||||
"""Submit a new job. Returns the Job object immediately."""
|
||||
job = Job(
|
||||
id=str(uuid.uuid4()),
|
||||
job_type=job_type,
|
||||
params=params,
|
||||
)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put_nowait(job.id)
|
||||
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
|
||||
return job
|
||||
|
||||
def get_job(self, job_id: str) -> Job | None:
|
||||
return self._jobs.get(job_id)
|
||||
|
||||
def cancel_job(self, job_id: str) -> bool:
|
||||
job = self._jobs.get(job_id)
|
||||
if not job:
|
||||
return False
|
||||
if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
|
||||
return False
|
||||
job.cancel()
|
||||
return True
|
||||
|
||||
def list_jobs(
|
||||
self,
|
||||
status: JobStatus | None = None,
|
||||
job_type: JobType | None = None,
|
||||
limit: int = 50,
|
||||
) -> list[dict]:
|
||||
"""List jobs, newest first."""
|
||||
jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
|
||||
if status:
|
||||
jobs = [j for j in jobs if j.status == status]
|
||||
if job_type:
|
||||
jobs = [j for j in jobs if j.job_type == job_type]
|
||||
return [j.to_dict() for j in jobs[:limit]]
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get queue statistics."""
|
||||
by_status = {}
|
||||
for j in self._jobs.values():
|
||||
by_status[j.status.value] = by_status.get(j.status.value, 0) + 1
|
||||
return {
|
||||
"total": len(self._jobs),
|
||||
"queued": self._queue.qsize(),
|
||||
"by_status": by_status,
|
||||
"workers": self._max_workers,
|
||||
"active_workers": sum(
|
||||
1 for j in self._jobs.values() if j.status == JobStatus.RUNNING
|
||||
),
|
||||
}
|
||||
|
||||
def cleanup(self, max_age_seconds: float = 3600):
|
||||
"""Remove old completed/failed/cancelled jobs."""
|
||||
now = time.time()
|
||||
to_remove = [
|
||||
jid for jid, j in self._jobs.items()
|
||||
if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||
and (now - j.created_at) > max_age_seconds
|
||||
]
|
||||
for jid in to_remove:
|
||||
del self._jobs[jid]
|
||||
if to_remove:
|
||||
logger.info(f"Cleaned up {len(to_remove)} old jobs")
|
||||
|
||||
async def _worker(self, worker_id: int):
    """Worker loop: pull jobs from queue and execute handlers.

    Runs until self._started goes False. The 5s queue timeout exists so
    the loop periodically re-checks the shutdown flag instead of blocking
    forever on an empty queue.

    NOTE(review): queue items are consumed without a matching
    self._queue.task_done() call — confirm nothing awaits queue.join().
    Jobs that were cancelled (or evicted from self._jobs) before being
    dequeued are silently dropped.
    """
    logger.info(f"Worker {worker_id} started")
    while self._started:
        try:
            job_id = await asyncio.wait_for(self._queue.get(), timeout=5.0)
        except asyncio.TimeoutError:
            # No work within the window; loop again to re-check _started.
            continue
        except asyncio.CancelledError:
            break

        job = self._jobs.get(job_id)
        if not job or job.is_cancelled:
            continue

        handler = self._handlers.get(job.job_type)
        if not handler:
            # Misconfiguration: job type submitted without a registered handler.
            job.status = JobStatus.FAILED
            job.error = f"No handler for {job.job_type.value}"
            job.completed_at = time.time()
            logger.error(f"No handler for job type {job.job_type.value}")
            continue

        job.status = JobStatus.RUNNING
        job.started_at = time.time()
        job.message = "Running..."
        logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")

        try:
            result = await handler(job)
            # A handler may observe cancellation mid-run; in that case the
            # job keeps its CANCELLED state rather than being marked COMPLETED.
            if not job.is_cancelled:
                job.status = JobStatus.COMPLETED
                job.progress = 100.0
                job.result = result
                job.message = "Completed"
                job.completed_at = time.time()
                logger.info(
                    f"Worker {worker_id}: completed {job.id} "
                    f"in {job.elapsed_ms}ms"
                )
        except Exception as e:
            # Handler failures mark the job FAILED but never kill the worker.
            if not job.is_cancelled:
                job.status = JobStatus.FAILED
                job.error = str(e)
                job.message = f"Failed: {e}"
                job.completed_at = time.time()
                logger.error(
                    f"Worker {worker_id}: failed {job.id}: {e}",
                    exc_info=True,
                )
|
||||
|
||||
|
||||
# Singleton + job handlers

# Module-level queue shared by the whole app; max_workers=3 bounds how many
# jobs execute concurrently.
job_queue = JobQueue(max_workers=3)
|
||||
|
||||
|
||||
async def _handle_triage(job: Job):
    """Triage handler.

    Runs the triage pipeline for the dataset named in job.params.
    NOTE(review): triage_dataset is annotated ``-> None`` in
    app/services/triage.py, so ``results`` is likely always None and the
    reported count always 0 — confirm its actual return value.
    """
    from app.services.triage import triage_dataset
    dataset_id = job.params.get("dataset_id")
    job.message = f"Triaging dataset {dataset_id}"
    results = await triage_dataset(dataset_id)
    return {"count": len(results) if results else 0}
|
||||
|
||||
|
||||
async def _handle_host_profile(job: Job):
    """Host profiling handler.

    Profiles a single host when params contain ``hostname``; otherwise
    profiles every host in the hunt identified by ``hunt_id``.
    """
    from app.services.host_profiler import profile_all_hosts, profile_host
    hunt_id = job.params.get("hunt_id")
    hostname = job.params.get("hostname")
    if hostname:
        job.message = f"Profiling host {hostname}"
        await profile_host(hunt_id, hostname)
        return {"hostname": hostname}
    else:
        job.message = f"Profiling all hosts in hunt {hunt_id}"
        await profile_all_hosts(hunt_id)
        return {"hunt_id": hunt_id}
|
||||
|
||||
|
||||
async def _handle_report(job: Job):
    """Report generation handler.

    NOTE(review): generate_report is annotated ``-> None`` in
    app/services/report_generator.py, so ``report`` is likely always None
    here and report_id always None — confirm.
    """
    from app.services.report_generator import generate_report
    hunt_id = job.params.get("hunt_id")
    job.message = f"Generating report for hunt {hunt_id}"
    report = await generate_report(hunt_id)
    return {"report_id": report.id if report else None}
|
||||
|
||||
|
||||
async def _handle_anomaly(job: Job):
    """Anomaly detection handler.

    ``k`` and ``threshold`` tune the detector (defaults 3 and 0.35).
    """
    from app.services.anomaly_detector import detect_anomalies
    dataset_id = job.params.get("dataset_id")
    k = job.params.get("k", 3)
    threshold = job.params.get("threshold", 0.35)
    job.message = f"Detecting anomalies in dataset {dataset_id}"
    results = await detect_anomalies(dataset_id, k=k, outlier_threshold=threshold)
    return {"count": len(results) if results else 0}
|
||||
|
||||
|
||||
async def _handle_query(job: Job):
    """Data query handler (non-streaming).

    Forwards a natural-language question to the data_query service;
    ``mode`` defaults to "quick".
    """
    from app.services.data_query import query_dataset
    dataset_id = job.params.get("dataset_id")
    question = job.params.get("question", "")
    mode = job.params.get("mode", "quick")
    job.message = f"Querying dataset {dataset_id}"
    answer = await query_dataset(dataset_id, question, mode)
    return {"answer": answer}
|
||||
|
||||
|
||||
def register_all_handlers():
    """Register all job handlers.

    One handler per JobType; intended to be called once at startup before
    the queue begins dispatching.
    """
    job_queue.register_handler(JobType.TRIAGE, _handle_triage)
    job_queue.register_handler(JobType.HOST_PROFILE, _handle_host_profile)
    job_queue.register_handler(JobType.REPORT, _handle_report)
    job_queue.register_handler(JobType.ANOMALY, _handle_anomaly)
    job_queue.register_handler(JobType.QUERY, _handle_query)
|
||||
193
backend/app/services/load_balancer.py
Normal file
193
backend/app/services/load_balancer.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Smart load balancer for Wile & Roadrunner LLM nodes.
|
||||
|
||||
Tracks active jobs per node, health status, and routes new work
|
||||
to the least-busy healthy node. Periodically pings both nodes
|
||||
to maintain an up-to-date health map.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeId(str, Enum):
    """Identifiers for the two LLM nodes the balancer routes between."""
    WILE = "wile"
    ROADRUNNER = "roadrunner"
|
||||
|
||||
|
||||
class WorkloadTier(str, Enum):
    """What kind of workload is this?

    Tier determines which nodes are eligible (see LoadBalancer.TIER_NODES).
    """
    HEAVY = "heavy"       # 70B models, deep analysis, reports
    FAST = "fast"         # 7-14B models, triage, quick queries
    EMBEDDING = "embed"   # bge-m3 embeddings
    ANY = "any"
|
||||
|
||||
|
||||
@dataclass
class NodeStatus:
    """Mutable runtime bookkeeping for a single LLM node."""

    node_id: NodeId
    url: str
    healthy: bool = True          # last health-check outcome
    last_check: float = 0.0       # time.time() of the last health check
    active_jobs: int = 0          # jobs currently dispatched and not released
    total_completed: int = 0
    total_errors: int = 0
    avg_latency_ms: float = 0.0   # rolling mean over the last 50 completions
    _latencies: list[float] = field(default_factory=list)

    def record_completion(self, latency_ms: float):
        """Mark one job finished OK and fold its latency into the rolling mean."""
        self.active_jobs = max(0, self.active_jobs - 1)
        self.total_completed += 1
        self._latencies.append(latency_ms)
        # Rolling average of last 50
        if len(self._latencies) > 50:
            self._latencies = self._latencies[-50:]
        self.avg_latency_ms = sum(self._latencies) / len(self._latencies)

    def record_error(self):
        """Mark one job finished with an error (does not affect latency stats)."""
        self.active_jobs = max(0, self.active_jobs - 1)
        self.total_errors += 1

    def record_start(self):
        """Mark one job dispatched to this node."""
        self.active_jobs += 1
|
||||
|
||||
|
||||
class LoadBalancer:
    """Routes LLM work to the least-busy healthy node.

    Node capabilities:
    - Wile: Heavy models (70B), code models (32B)
    - Roadrunner: Fast models (7-14B), embeddings (bge-m3), vision

    NOTE(review): self._lock is created but never acquired anywhere in
    this class; select/acquire/release mutate counters unguarded —
    confirm whether locking was intended.
    """

    # Which nodes can handle which tiers
    TIER_NODES = {
        WorkloadTier.HEAVY: [NodeId.WILE],
        WorkloadTier.FAST: [NodeId.ROADRUNNER, NodeId.WILE],
        WorkloadTier.EMBEDDING: [NodeId.ROADRUNNER],
        WorkloadTier.ANY: [NodeId.ROADRUNNER, NodeId.WILE],
    }

    def __init__(self):
        # Both nodes start optimistically healthy until the first check runs.
        self._nodes: dict[NodeId, NodeStatus] = {
            NodeId.WILE: NodeStatus(
                node_id=NodeId.WILE,
                url=f"http://{settings.WILE_HOST}:{settings.WILE_OLLAMA_PORT}",
            ),
            NodeId.ROADRUNNER: NodeStatus(
                node_id=NodeId.ROADRUNNER,
                url=f"http://{settings.ROADRUNNER_HOST}:{settings.ROADRUNNER_OLLAMA_PORT}",
            ),
        }
        self._lock = asyncio.Lock()
        self._health_task: Optional[asyncio.Task] = None

    async def start_health_loop(self, interval: float = 30.0):
        """Start background health-check loop (idempotent while running)."""
        if self._health_task and not self._health_task.done():
            return
        self._health_task = asyncio.create_task(self._health_loop(interval))
        logger.info("Load balancer health loop started (%.0fs interval)", interval)

    async def stop_health_loop(self):
        """Cancel the health loop and wait for it to wind down."""
        if self._health_task:
            self._health_task.cancel()
            try:
                await self._health_task
            except asyncio.CancelledError:
                pass
            self._health_task = None

    async def _health_loop(self, interval: float):
        # Runs until cancelled; individual check failures are logged, not fatal.
        while True:
            try:
                await self.check_health()
            except Exception as e:
                logger.warning(f"Health check error: {e}")
            await asyncio.sleep(interval)

    async def check_health(self):
        """Ping both nodes and update status."""
        import httpx
        async with httpx.AsyncClient(timeout=5) as client:
            for nid, status in self._nodes.items():
                try:
                    # Ollama's /api/tags is a cheap liveness probe.
                    resp = await client.get(f"{status.url}/api/tags")
                    status.healthy = resp.status_code == 200
                except Exception:
                    status.healthy = False
                status.last_check = time.time()
                logger.debug(
                    f"Health: {nid.value} = {'OK' if status.healthy else 'DOWN'} "
                    f"(active={status.active_jobs})"
                )

    def select_node(self, tier: WorkloadTier = WorkloadTier.ANY) -> NodeId:
        """Select the best node for a workload tier.

        Strategy: among healthy nodes that support the tier,
        pick the one with fewest active jobs.
        Falls back to any node if none healthy.
        """
        candidates = self.TIER_NODES.get(tier, [NodeId.ROADRUNNER, NodeId.WILE])

        # Filter to healthy candidates
        healthy = [
            nid for nid in candidates
            if self._nodes[nid].healthy
        ]

        if not healthy:
            # Degraded mode: still pick least-busy among all tier candidates.
            logger.warning(f"No healthy nodes for tier {tier.value}, using first candidate")
            healthy = candidates

        # Pick least busy
        best = min(healthy, key=lambda nid: self._nodes[nid].active_jobs)
        return best

    def acquire(self, tier: WorkloadTier = WorkloadTier.ANY) -> NodeId:
        """Select node and mark a job started.

        Callers must pair this with release() when the job finishes.
        """
        node = self.select_node(tier)
        self._nodes[node].record_start()
        logger.info(
            f"LB: dispatched {tier.value} -> {node.value} "
            f"(active={self._nodes[node].active_jobs})"
        )
        return node

    def release(self, node: NodeId, latency_ms: float = 0, error: bool = False):
        """Mark a job completed on a node (error=True skips latency stats)."""
        status = self._nodes.get(node)
        if not status:
            return
        if error:
            status.record_error()
        else:
            status.record_completion(latency_ms)

    def get_status(self) -> dict:
        """Get current load balancer status, keyed by node id string."""
        return {
            nid.value: {
                "healthy": s.healthy,
                "active_jobs": s.active_jobs,
                "total_completed": s.total_completed,
                "total_errors": s.total_errors,
                "avg_latency_ms": round(s.avg_latency_ms, 1),
                "last_check": s.last_check,
            }
            for nid, s in self._nodes.items()
        }
|
||||
|
||||
|
||||
# Singleton
# Shared process-wide balancer instance.
lb = LoadBalancer()
|
||||
198
backend/app/services/report_generator.py
Normal file
198
backend/app/services/report_generator.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Report generator - debate-powered hunt report generation using Wile + Roadrunner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config import settings
|
||||
from app.db.engine import async_session
|
||||
from app.db.models import (
|
||||
Dataset, HostProfile, HuntReport, TriageResult,
|
||||
)
|
||||
from app.services.triage import _parse_llm_response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WILE_URL = f"{settings.wile_url}/api/generate"
|
||||
ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate"
|
||||
HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
|
||||
FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M"
|
||||
|
||||
|
||||
async def _llm_call(url: str, model: str, system: str, prompt: str, timeout: float = 300.0) -> str:
    """POST a non-streaming generate request to an Ollama node.

    Returns the model's response text (empty string when the field is
    absent). Raises httpx.HTTPStatusError on non-2xx replies.
    """
    payload = {
        "model": model,
        "prompt": prompt,
        "system": system,
        "stream": False,
        "options": {"temperature": 0.3, "num_predict": 8192},
    }
    async with httpx.AsyncClient(timeout=timeout) as http:
        response = await http.post(url, json=payload)
    response.raise_for_status()
    return response.json().get("response", "")
|
||||
|
||||
|
||||
async def _gather_evidence(db, hunt_id: str) -> dict:
    """Collect the evidence bundle for a hunt report.

    Pulls per-dataset summaries, the top triage hits (risk >= 3.0, up to 15
    per dataset, 30 overall), and all host profiles ordered by risk.
    Returns a plain dict suitable for json.dumps.
    """
    ds_result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
    datasets = ds_result.scalars().all()

    dataset_summary = []
    all_triage = []
    for ds in datasets:
        ds_info = {
            "name": ds.name,
            "artifact_type": getattr(ds, "artifact_type", "Unknown"),
            "row_count": ds.row_count or 0,
        }
        dataset_summary.append(ds_info)

        # Only the highest-risk triage batches per dataset make the report.
        triage_result = await db.execute(
            select(TriageResult)
            .where(TriageResult.dataset_id == ds.id)
            .where(TriageResult.risk_score >= 3.0)
            .order_by(TriageResult.risk_score.desc())
            .limit(15)
        )
        for t in triage_result.scalars().all():
            all_triage.append({
                "dataset": ds.name,
                "artifact_type": ds_info["artifact_type"],
                "rows": f"{t.row_start}-{t.row_end}",
                "risk_score": t.risk_score,
                "verdict": t.verdict,
                # Trim lists to keep the prompt payload small.
                "findings": t.findings[:5] if t.findings else [],
                "indicators": t.suspicious_indicators[:5] if t.suspicious_indicators else [],
                "mitre": t.mitre_techniques or [],
            })

    profile_result = await db.execute(
        select(HostProfile)
        .where(HostProfile.hunt_id == hunt_id)
        .order_by(HostProfile.risk_score.desc())
    )
    profiles = profile_result.scalars().all()
    host_summaries = []
    for p in profiles:
        host_summaries.append({
            "hostname": p.hostname,
            "risk_score": p.risk_score,
            "risk_level": p.risk_level,
            "findings": p.suspicious_findings[:5] if p.suspicious_findings else [],
            "mitre": p.mitre_techniques or [],
            "timeline": (p.timeline_summary or "")[:300],
        })

    return {
        "datasets": dataset_summary,
        "triage_findings": all_triage[:30],
        "host_profiles": host_summaries,
        "total_datasets": len(datasets),
        "total_rows": sum(d["row_count"] for d in dataset_summary),
        "high_risk_hosts": len([h for h in host_summaries if h["risk_score"] >= 7.0]),
    }
|
||||
|
||||
|
||||
async def generate_report(hunt_id: str) -> None:
    """Generate a three-phase debate-style hunt report and persist it.

    Phase 1: Wile (heavy model) writes the initial assessment.
    Phase 2: Roadrunner (fast model) critiques it.
    Phase 3: Wile synthesizes the final JSON report.

    The HuntReport row is created up front with status "generating" and
    updated to "complete" or "error".

    NOTE(review): annotated -> None, but _handle_report in job_queue.py
    reads ``report.id`` off this function's return value — confirm which
    side is right.
    NOTE(review): the except branch re-commits on the same session that
    may have just failed; confirm the session is usable after the error.
    """
    logger.info("Generating report for hunt %s", hunt_id)
    start = time.monotonic()

    async with async_session() as db:
        report = HuntReport(
            hunt_id=hunt_id,
            status="generating",
            models_used=[HEAVY_MODEL, FAST_MODEL],
        )
        db.add(report)
        await db.commit()
        await db.refresh(report)
        report_id = report.id

        try:
            evidence = await _gather_evidence(db, hunt_id)
            # Cap the serialized evidence so prompts stay within context limits.
            evidence_text = json.dumps(evidence, indent=1, default=str)[:12000]

            # Phase 1: Wile initial analysis
            logger.info("Report phase 1: Wile initial analysis")
            phase1 = await _llm_call(
                WILE_URL, HEAVY_MODEL,
                system=(
                    "You are a senior threat intelligence analyst writing a hunt report.\n"
                    "Analyze all evidence and produce a structured threat assessment.\n"
                    "Include: executive summary, detailed findings per host, MITRE mapping,\n"
                    "IOC table, risk rankings, and actionable recommendations.\n"
                    "Use markdown formatting. Be thorough and specific."
                ),
                prompt=f"Hunt evidence:\n{evidence_text}\n\nProduce your initial threat assessment.",
            )

            # Phase 2: Roadrunner critical review
            logger.info("Report phase 2: Roadrunner critical review")
            phase2 = await _llm_call(
                ROADRUNNER_URL, FAST_MODEL,
                system=(
                    "You are a critical reviewer of threat hunt reports.\n"
                    "Review the initial assessment and identify:\n"
                    "- Missing correlations or overlooked indicators\n"
                    "- False positive risks or overblown findings\n"
                    "- Additional MITRE techniques that should be mapped\n"
                    "- Gaps in recommendations\n"
                    "Be specific and constructive. Respond in markdown."
                ),
                prompt=f"Evidence:\n{evidence_text[:4000]}\n\nInitial Assessment:\n{phase1[:6000]}\n\nProvide your critical review.",
                timeout=120.0,
            )

            # Phase 3: Wile final synthesis
            logger.info("Report phase 3: Wile final synthesis")
            synthesis_prompt = (
                f"Original evidence:\n{evidence_text[:6000]}\n\n"
                f"Initial assessment:\n{phase1[:5000]}\n\n"
                f"Critical review:\n{phase2[:3000]}\n\n"
                "Produce the FINAL hunt report incorporating the review feedback.\n"
                "Return JSON with these keys:\n"
                "- executive_summary: 2-3 paragraph executive summary\n"
                "- findings: list of {title, severity, description, evidence, mitre_ids}\n"
                "- recommendations: list of {priority, action, rationale}\n"
                "- mitre_mapping: dict of technique_id -> {name, description, evidence}\n"
                "- ioc_table: list of {type, value, context, confidence}\n"
                "- host_risk_summary: list of {hostname, risk_score, risk_level, key_findings}\n"
                "Respond with valid JSON only."
            )
            phase3_text = await _llm_call(
                WILE_URL, HEAVY_MODEL,
                system="You are producing the final, definitive threat hunt report. Incorporate all feedback. Respond with valid JSON only.",
                prompt=synthesis_prompt,
            )

            parsed = _parse_llm_response(phase3_text)
            elapsed_ms = int((time.monotonic() - start) * 1000)

            # Keep the raw three-phase transcript alongside the parsed fields.
            full_report = f"# Threat Hunt Report\n\n{phase1}\n\n---\n## Review Notes\n{phase2}\n\n---\n## Final Synthesis\n{phase3_text}"

            report.status = "complete"
            # Fall back to the phase-1 text when the JSON parse missed a key.
            report.exec_summary = parsed.get("executive_summary", phase1[:2000])
            report.full_report = full_report
            report.findings = parsed.get("findings", [])
            report.recommendations = parsed.get("recommendations", [])
            report.mitre_mapping = parsed.get("mitre_mapping", {})
            report.ioc_table = parsed.get("ioc_table", [])
            report.host_risk_summary = parsed.get("host_risk_summary", [])
            report.generation_time_ms = elapsed_ms
            await db.commit()

            logger.info("Report %s complete in %dms", report_id, elapsed_ms)

        except Exception as e:
            logger.error("Report generation failed for hunt %s: %s", hunt_id, e)
            report.status = "error"
            report.exec_summary = f"Report generation failed: {e}"
            report.generation_time_ms = int((time.monotonic() - start) * 1000)
            await db.commit()
|
||||
170
backend/app/services/triage.py
Normal file
170
backend/app/services/triage.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from app.config import settings
|
||||
from app.db.engine import async_session
|
||||
from app.db.models import Dataset, DatasetRow, TriageResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M"
|
||||
ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate"
|
||||
|
||||
# Per-artifact triage guidance: appended to the system prompt so the fast
# model knows what "suspicious" means for each Velociraptor artifact type.
ARTIFACT_FOCUS = {
    "Windows.System.Pslist": "Look for: suspicious parent-child, LOLBins, unsigned, injection indicators, abnormal paths.",
    "Windows.Network.Netstat": "Look for: C2 beaconing, unusual ports, connections to rare IPs, non-browser high-port listeners.",
    "Windows.System.Services": "Look for: services in temp dirs, misspelled names, unsigned ServiceDll, unusual start modes.",
    "Windows.Forensics.Prefetch": "Look for: recon tools, lateral movement tools, rarely-run executables with high run counts.",
    "Windows.EventLogs.EvtxHunter": "Look for: logon type 10/3 anomalies, service installs, PowerShell script blocks, clearing.",
    "Windows.Sys.Autoruns": "Look for: recently added entries, entries in temp/user dirs, encoded commands, suspicious DLLs.",
    "Windows.Registry.Finder": "Look for: run keys, image file execution options, hidden services, encoded payloads.",
    "Windows.Search.FileFinder": "Look for: files in unusual locations, recently modified system files, known tool names.",
}
|
||||
|
||||
|
||||
def _parse_llm_response(text: str) -> dict:
|
||||
text = text.strip()
|
||||
fence = re.search(r"`(?:json)?\s*\n?(.*?)\n?\s*`", text, re.DOTALL)
|
||||
if fence:
|
||||
text = fence.group(1).strip()
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
brace = text.find("{")
|
||||
bracket = text.rfind("}")
|
||||
if brace != -1 and bracket != -1 and bracket > brace:
|
||||
try:
|
||||
return json.loads(text[brace : bracket + 1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return {"raw_response": text[:3000]}
|
||||
|
||||
|
||||
async def triage_dataset(dataset_id: str) -> None:
    """Run batched LLM triage over every row of a dataset.

    Walks the dataset in TRIAGE_BATCH_SIZE chunks, sends each batch to the
    Roadrunner node with artifact-specific guidance, and stores one
    TriageResult row per batch (including an "error" verdict row when a
    batch fails). Stops early once TRIAGE_MAX_SUSPICIOUS_ROWS rows have
    scored at or above the escalation threshold.
    """
    logger.info("Starting triage for dataset %s", dataset_id)

    async with async_session() as db:
        ds_result = await db.execute(
            select(Dataset).where(Dataset.id == dataset_id)
        )
        dataset = ds_result.scalar_one_or_none()
        if not dataset:
            logger.error("Dataset %s not found", dataset_id)
            return

        # artifact_type drives the prompt guidance; fall back to generic text.
        artifact_type = getattr(dataset, "artifact_type", None) or "Unknown"
        focus = ARTIFACT_FOCUS.get(artifact_type, "Analyze for any suspicious indicators.")

        count_result = await db.execute(
            select(func.count()).where(DatasetRow.dataset_id == dataset_id)
        )
        total_rows = count_result.scalar() or 0

        batch_size = settings.TRIAGE_BATCH_SIZE
        suspicious_count = 0
        offset = 0

        while offset < total_rows:
            if suspicious_count >= settings.TRIAGE_MAX_SUSPICIOUS_ROWS:
                # Enough high-risk material found; stop burning LLM time.
                logger.info("Reached suspicious row cap for dataset %s", dataset_id)
                break

            rows_result = await db.execute(
                select(DatasetRow)
                .where(DatasetRow.dataset_id == dataset_id)
                .order_by(DatasetRow.row_number)
                .offset(offset)
                .limit(batch_size)
            )
            rows = rows_result.scalars().all()
            if not rows:
                break

            # Compact each row: prefer normalized data, drop empty values,
            # truncate each field to 200 chars to bound the prompt size.
            batch_data = []
            for r in rows:
                data = r.normalized_data or r.data
                compact = {k: str(v)[:200] for k, v in data.items() if v}
                batch_data.append(compact)

            system_prompt = f"""You are a cybersecurity triage analyst. Analyze this batch of {artifact_type} forensic data.
{focus}

Return JSON with:
- risk_score: 0.0 (benign) to 10.0 (critical threat)
- verdict: "clean", "suspicious", "malicious", or "inconclusive"
- findings: list of key observations
- suspicious_indicators: list of specific IOCs or anomalies
- mitre_techniques: list of MITRE ATT&CK IDs if applicable

Be precise. Only flag genuinely suspicious items. Respond with valid JSON only."""

            prompt = f"Rows {offset+1}-{offset+len(rows)} of {total_rows}:\n{json.dumps(batch_data, default=str)[:6000]}"

            try:
                async with httpx.AsyncClient(timeout=120.0) as client:
                    resp = await client.post(
                        ROADRUNNER_URL,
                        json={
                            "model": DEFAULT_FAST_MODEL,
                            "prompt": prompt,
                            "system": system_prompt,
                            "stream": False,
                            "options": {"temperature": 0.2, "num_predict": 2048},
                        },
                    )
                    resp.raise_for_status()
                    result = resp.json()
                    llm_text = result.get("response", "")

                parsed = _parse_llm_response(llm_text)
                risk = float(parsed.get("risk_score", 0.0))

                triage = TriageResult(
                    dataset_id=dataset_id,
                    row_start=offset,
                    row_end=offset + len(rows) - 1,
                    risk_score=risk,
                    verdict=parsed.get("verdict", "inconclusive"),
                    findings=parsed.get("findings", []),
                    suspicious_indicators=parsed.get("suspicious_indicators", []),
                    mitre_techniques=parsed.get("mitre_techniques", []),
                    model_used=DEFAULT_FAST_MODEL,
                    node_used="roadrunner",
                )
                db.add(triage)
                await db.commit()

                if risk >= settings.TRIAGE_ESCALATION_THRESHOLD:
                    # Whole batch counts toward the suspicious-row cap.
                    suspicious_count += len(rows)

                logger.debug(
                    "Triage batch %d-%d: risk=%.1f verdict=%s",
                    offset, offset + len(rows) - 1, risk, triage.verdict,
                )

            except Exception as e:
                # Record the failure as its own result row so gaps are visible.
                logger.error("Triage batch %d failed: %s", offset, e)
                triage = TriageResult(
                    dataset_id=dataset_id,
                    row_start=offset,
                    row_end=offset + len(rows) - 1,
                    risk_score=0.0,
                    verdict="error",
                    findings=[f"Error: {e}"],
                    model_used=DEFAULT_FAST_MODEL,
                    node_used="roadrunner",
                )
                db.add(triage)
                await db.commit()

            offset += batch_size

        logger.info("Triage complete for dataset %s", dataset_id)
|
||||
@@ -13,5 +13,5 @@ if __name__ == "__main__":
|
||||
"app.main:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
reload=True,
|
||||
)
|
||||
reload=False,
|
||||
)
|
||||
8
backend/scan_cols.py
Normal file
8
backend/scan_cols.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Ad-hoc dev script: print IOC/normalized column mappings per dataset.

Hits the local API (hard-coded hunt id) and prints, per dataset:
name | row_count | ioc_columns | host-related normalized columns.
"""
import json, urllib.request
url = "http://localhost:8000/api/datasets?skip=0&limit=20&hunt_id=fd8ba3fb45de4d65bea072f73d80544d"
data = json.loads(urllib.request.urlopen(url).read())
for d in data["datasets"]:
    ioc = list((d["ioc_columns"] or {}).items())
    norm = d.get("normalized_columns") or {}
    # Keep only the host/identity-ish normalized columns.
    hc = {k: v for k, v in norm.items() if v in ("hostname", "fqdn", "username", "src_ip", "dst_ip", "ip_address", "os")}
    print(d["name"], "|", d["row_count"], "|", ioc, "|", hc)
|
||||
23
backend/scan_rows.py
Normal file
23
backend/scan_rows.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Ad-hoc dev script: dump sample rows from three specific datasets.

Fetches the dataset list for a hard-coded hunt id from the local API and
prints the first few rows of ip_to_hostname_mapping, Netstat, and
netstat_enrich2.
"""
import json, urllib.request

def get(path):
    """GET a local-API path and decode the JSON body."""
    return json.loads(urllib.request.urlopen("http://localhost:8000" + path).read())

# Check ip_to_hostname_mapping
ds_list = get("/api/datasets?skip=0&limit=20&hunt_id=fd8ba3fb45de4d65bea072f73d80544d")
for d in ds_list["datasets"]:
    if d["name"] == "ip_to_hostname_mapping":
        rows = get(f"/api/datasets/{d['id']}/rows?offset=0&limit=5")
        print("=== ip_to_hostname_mapping ===")
        for r in rows["rows"]:
            print(r)
    if d["name"] == "Netstat":
        rows = get(f"/api/datasets/{d['id']}/rows?offset=0&limit=3")
        print("=== Netstat ===")
        for r in rows["rows"]:
            print(r)
    if d["name"] == "netstat_enrich2":
        rows = get(f"/api/datasets/{d['id']}/rows?offset=0&limit=3")
        print("=== netstat_enrich2 ===")
        for r in rows["rows"]:
            print(r)
|
||||
BIN
backend/threathunt.db-shm
Normal file
BIN
backend/threathunt.db-shm
Normal file
Binary file not shown.
0
backend/threathunt.db-wal
Normal file
0
backend/threathunt.db-wal
Normal file
@@ -15,10 +15,10 @@ server {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_read_timeout 120s;
|
||||
proxy_read_timeout 300s;
|
||||
}
|
||||
|
||||
# SPA fallback — serve index.html for all non-file routes
|
||||
# SPA fallback serve index.html for all non-file routes
|
||||
location / {
|
||||
try_files $uri $uri/ /index.html;
|
||||
}
|
||||
@@ -28,4 +28,4 @@ server {
|
||||
expires 1y;
|
||||
add_header Cache-Control "public, immutable";
|
||||
}
|
||||
}
|
||||
}
|
||||
30
frontend/package-lock.json
generated
30
frontend/package-lock.json
generated
@@ -66,7 +66,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@babel/core/-/core-7.29.0.tgz",
|
||||
"integrity": "sha512-CGOfOJqWjg2qW/Mb6zNsDm+u5vFQ8DxXfbM09z69p5Z6+mE1ikP2jUXw+j42Pf1XTYED2Rni5f95npYeuwMDQA==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@babel/code-frame": "^7.29.0",
|
||||
"@babel/generator": "^7.29.0",
|
||||
@@ -722,7 +721,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@babel/plugin-syntax-flow/-/plugin-syntax-flow-7.28.6.tgz",
|
||||
"integrity": "sha512-D+OrJumc9McXNEBI/JmFnc/0uCM2/Y3PEBG3gfV3QIYkKv5pvnpzFrl1kYCrcHJP8nOeFB/SHi1IHz29pNGuew==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@babel/helper-plugin-utils": "^7.28.6"
|
||||
},
|
||||
@@ -1606,7 +1604,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@babel/plugin-transform-react-jsx/-/plugin-transform-react-jsx-7.28.6.tgz",
|
||||
"integrity": "sha512-61bxqhiRfAACulXSLd/GxqmAedUSrRZIu/cbaT18T1CetkTmtDN15it7i80ru4DVqRK1WMxQhXs+Lf9kajm5Ow==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@babel/helper-annotate-as-pure": "^7.27.3",
|
||||
"@babel/helper-module-imports": "^7.28.6",
|
||||
@@ -2448,7 +2445,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@emotion/react/-/react-11.14.0.tgz",
|
||||
"integrity": "sha512-O000MLDBDdk/EohJPFUqvnp4qnHeYkVP5B0xEG0D/L7cOKP9kefu2DXn8dj74cQfsEzUqh+sr1RzFqiL1o+PpA==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.18.3",
|
||||
"@emotion/babel-plugin": "^11.13.5",
|
||||
@@ -2492,7 +2488,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@emotion/styled/-/styled-11.14.1.tgz",
|
||||
"integrity": "sha512-qEEJt42DuToa3gurlH4Qqc1kVpNq8wO8cJtDzU46TjlzWjDlsVyevtYCRijVq3SrHsROS+gVQ8Fnea108GnKzw==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.18.3",
|
||||
"@emotion/babel-plugin": "^11.13.5",
|
||||
@@ -3074,7 +3069,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@mui/material/-/material-7.3.8.tgz",
|
||||
"integrity": "sha512-QKd1RhDXE1hf2sQDNayA9ic9jGkEgvZOf0tTkJxlBPG8ns8aS4rS8WwYURw2x5y3739p0HauUXX9WbH7UufFLw==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.28.6",
|
||||
"@mui/core-downloads-tracker": "^7.3.8",
|
||||
@@ -3185,7 +3179,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@mui/system/-/system-7.3.8.tgz",
|
||||
"integrity": "sha512-hoFRj4Zw2Km8DPWZp/nKG+ao5Jw5LSk2m/e4EGc6M3RRwXKEkMSG4TgtfVJg7dS2homRwtdXSMW+iRO0ZJ4+IA==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.28.6",
|
||||
"@mui/private-theming": "^7.3.8",
|
||||
@@ -4127,7 +4120,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@types/react/-/react-18.3.28.tgz",
|
||||
"integrity": "sha512-z9VXpC7MWrhfWipitjNdgCauoMLRdIILQsAEV+ZesIzBq/oUlxk0m3ApZuMFCXdnS4U7KrI+l3WRUEGQ8K1QKw==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@types/prop-types": "*",
|
||||
"csstype": "^3.2.2"
|
||||
@@ -4283,7 +4275,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.62.0.tgz",
|
||||
"integrity": "sha512-TiZzBSJja/LbhNPvk6yc0JrX9XqhQ0hdh6M2svYfsHGejaKFIAGd9MQ+ERIMzLGlN/kZoYIgdxFV0PuljTKXag==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@eslint-community/regexpp": "^4.4.0",
|
||||
"@typescript-eslint/scope-manager": "5.62.0",
|
||||
@@ -4337,7 +4328,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-5.62.0.tgz",
|
||||
"integrity": "sha512-VlJEV0fOQ7BExOsHYAGrgbEiZoi8D+Bl2+f6V2RrXerRSylnp+ZBHmPvaIa8cz0Ajx7WO7Z5RqfgYg7ED1nRhA==",
|
||||
"license": "BSD-2-Clause",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@typescript-eslint/scope-manager": "5.62.0",
|
||||
"@typescript-eslint/types": "5.62.0",
|
||||
@@ -4707,7 +4697,6 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.16.0.tgz",
|
||||
"integrity": "sha512-UVJyE9MttOsBQIDKw1skb9nAwQuR5wuGD3+82K6JgJlm/Y+KI92oNsMNGZCYdDsVtRHSak0pcV5Dno5+4jh9sw==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -4806,7 +4795,6 @@
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz",
|
||||
"integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"fast-deep-equal": "^3.1.1",
|
||||
"fast-json-stable-stringify": "^2.0.0",
|
||||
@@ -5719,7 +5707,6 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"baseline-browser-mapping": "^2.9.0",
|
||||
"caniuse-lite": "^1.0.30001759",
|
||||
@@ -6768,8 +6755,7 @@
|
||||
"version": "3.2.3",
|
||||
"resolved": "https://registry.npmjs.org/csstype/-/csstype-3.2.3.tgz",
|
||||
"integrity": "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==",
|
||||
"license": "MIT",
|
||||
"peer": true
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/damerau-levenshtein": {
|
||||
"version": "1.0.8",
|
||||
@@ -7562,7 +7548,6 @@
|
||||
"integrity": "sha512-ypowyDxpVSYpkXr9WPv2PAZCtNip1Mv5KTW0SCurXv/9iOpcrH9PaqUElksqEB6pChqHGDRCFTyrZlGhnLNGiA==",
|
||||
"deprecated": "This version is no longer supported. Please see https://eslint.org/version-support for other options.",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@eslint-community/eslint-utils": "^4.2.0",
|
||||
"@eslint-community/regexpp": "^4.6.1",
|
||||
@@ -10327,7 +10312,6 @@
|
||||
"resolved": "https://registry.npmjs.org/jest/-/jest-27.5.1.tgz",
|
||||
"integrity": "sha512-Yn0mADZB89zTtjkPJEXwrac3LHudkQMR+Paqa8uxJHCBr9agxztUifWCyiYrjhMPBoUVBjyny0I7XH6ozDr7QQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jest/core": "^27.5.1",
|
||||
"import-local": "^3.0.2",
|
||||
@@ -11225,7 +11209,6 @@
|
||||
"resolved": "https://registry.npmjs.org/jiti/-/jiti-1.21.7.tgz",
|
||||
"integrity": "sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"jiti": "bin/jiti.js"
|
||||
}
|
||||
@@ -12603,7 +12586,6 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"nanoid": "^3.3.11",
|
||||
"picocolors": "^1.1.1",
|
||||
@@ -13738,7 +13720,6 @@
|
||||
"resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.1.2.tgz",
|
||||
"integrity": "sha512-Q8qQfPiZ+THO/3ZrOrO0cJJKfpYCagtMUkXbnEfmgUjwXg6z/WBeOyS9APBBPCTSiDV+s4SwQGu8yFsiMRIudg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"cssesc": "^3.0.0",
|
||||
"util-deprecate": "^1.0.2"
|
||||
@@ -14104,7 +14085,6 @@
|
||||
"resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz",
|
||||
"integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"loose-envify": "^1.1.0"
|
||||
},
|
||||
@@ -14239,7 +14219,6 @@
|
||||
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz",
|
||||
"integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"loose-envify": "^1.1.0",
|
||||
"scheduler": "^0.23.2"
|
||||
@@ -14265,7 +14244,6 @@
|
||||
"resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.11.0.tgz",
|
||||
"integrity": "sha512-F27qZr8uUqwhWZboondsPx8tnC3Ct3SxZA3V5WyEvujRyyNv0VYPhoBg1gZ8/MV5tubQp76Trw8lTv9hzRBa+A==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
@@ -14762,7 +14740,6 @@
|
||||
"resolved": "https://registry.npmjs.org/rollup/-/rollup-2.79.2.tgz",
|
||||
"integrity": "sha512-fS6iqSPZDs3dr/y7Od6y5nha8dW1YnbgtsyotCVvoFGKbERG++CVRFv1meyGDE1SNItQA8BrnCw7ScdAhRJ3XQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"rollup": "dist/bin/rollup"
|
||||
},
|
||||
@@ -15008,7 +14985,6 @@
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
|
||||
"integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"fast-deep-equal": "^3.1.3",
|
||||
"fast-uri": "^3.0.1",
|
||||
@@ -16391,7 +16367,6 @@
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz",
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -16668,7 +16643,6 @@
|
||||
"resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz",
|
||||
"integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==",
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -16995,7 +16969,6 @@
|
||||
"resolved": "https://registry.npmjs.org/webpack/-/webpack-5.105.2.tgz",
|
||||
"integrity": "sha512-dRXm0a2qcHPUBEzVk8uph0xWSjV/xZxenQQbLwnwP7caQCYpqG1qddwlyEkIDkYn0K8tvmcrZ+bOrzoQ3HxCDw==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@types/eslint-scope": "^3.7.7",
|
||||
"@types/estree": "^1.0.8",
|
||||
@@ -17480,7 +17453,6 @@
|
||||
"resolved": "https://registry.npmjs.org/ajv/-/ajv-8.18.0.tgz",
|
||||
"integrity": "sha512-PlXPeEWMXMZ7sPYOHqmDyCJzcfNrUr3fGNKtezX14ykXOEIvyK81d+qydx89KY5O71FKMPaQ2vBfBFI5NHR63A==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"fast-deep-equal": "^3.1.3",
|
||||
"fast-uri": "^3.0.1",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* ThreatHunt — MUI-powered analyst-assist platform.
|
||||
* ThreatHunt MUI-powered analyst-assist platform.
|
||||
*/
|
||||
|
||||
import React, { useState, useCallback } from 'react';
|
||||
@@ -18,6 +18,7 @@ import ScienceIcon from '@mui/icons-material/Science';
|
||||
import CompareArrowsIcon from '@mui/icons-material/CompareArrows';
|
||||
import GppMaybeIcon from '@mui/icons-material/GppMaybe';
|
||||
import HubIcon from '@mui/icons-material/Hub';
|
||||
import AssessmentIcon from '@mui/icons-material/Assessment';
|
||||
import { SnackbarProvider } from 'notistack';
|
||||
import theme from './theme';
|
||||
|
||||
@@ -32,6 +33,7 @@ import HypothesisTracker from './components/HypothesisTracker';
|
||||
import CorrelationView from './components/CorrelationView';
|
||||
import AUPScanner from './components/AUPScanner';
|
||||
import NetworkMap from './components/NetworkMap';
|
||||
import AnalysisDashboard from './components/AnalysisDashboard';
|
||||
|
||||
const DRAWER_WIDTH = 240;
|
||||
|
||||
@@ -42,13 +44,14 @@ const NAV: NavItem[] = [
|
||||
{ label: 'Hunts', path: '/hunts', icon: <SearchIcon /> },
|
||||
{ label: 'Datasets', path: '/datasets', icon: <StorageIcon /> },
|
||||
{ label: 'Upload', path: '/upload', icon: <UploadFileIcon /> },
|
||||
{ label: 'AI Analysis', path: '/analysis', icon: <AssessmentIcon /> },
|
||||
{ label: 'Agent', path: '/agent', icon: <SmartToyIcon /> },
|
||||
{ label: 'Enrichment', path: '/enrichment', icon: <SecurityIcon /> },
|
||||
{ label: 'Annotations', path: '/annotations', icon: <BookmarkIcon /> },
|
||||
{ label: 'Hypotheses', path: '/hypotheses', icon: <ScienceIcon /> },
|
||||
{ label: 'Correlation', path: '/correlation', icon: <CompareArrowsIcon /> },
|
||||
{ label: 'Network Map', path: '/network', icon: <HubIcon /> },
|
||||
{ label: 'AUP Scanner', path: '/aup', icon: <GppMaybeIcon /> },
|
||||
{ label: 'Network Map', path: '/network', icon: <HubIcon /> },
|
||||
{ label: 'AUP Scanner', path: '/aup', icon: <GppMaybeIcon /> },
|
||||
];
|
||||
|
||||
function Shell() {
|
||||
@@ -109,6 +112,7 @@ function Shell() {
|
||||
<Route path="/hunts" element={<HuntManager />} />
|
||||
<Route path="/datasets" element={<DatasetViewer />} />
|
||||
<Route path="/upload" element={<FileUpload />} />
|
||||
<Route path="/analysis" element={<AnalysisDashboard />} />
|
||||
<Route path="/agent" element={<AgentPanel />} />
|
||||
<Route path="/enrichment" element={<EnrichmentPanel />} />
|
||||
<Route path="/annotations" element={<AnnotationPanel />} />
|
||||
@@ -135,4 +139,4 @@ function App() {
|
||||
);
|
||||
}
|
||||
|
||||
export default App;
|
||||
export default App;
|
||||
@@ -1,11 +1,11 @@
|
||||
/* ====================================================================
|
||||
ThreatHunt API Client — mirrors every backend endpoint.
|
||||
ThreatHunt API Client -- mirrors every backend endpoint.
|
||||
All requests go through the CRA proxy (see package.json "proxy").
|
||||
==================================================================== */
|
||||
|
||||
const BASE = ''; // proxied to http://localhost:8000 by CRA
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────
|
||||
// -- Helpers --
|
||||
|
||||
let authToken: string | null = localStorage.getItem('th_token');
|
||||
|
||||
@@ -36,7 +36,7 @@ async function api<T = any>(
|
||||
return res.text() as unknown as T;
|
||||
}
|
||||
|
||||
// ── Auth ─────────────────────────────────────────────────────────────
|
||||
// -- Auth --
|
||||
|
||||
export interface UserPayload {
|
||||
id: string; username: string; email: string;
|
||||
@@ -63,7 +63,7 @@ export const auth = {
|
||||
me: () => api<UserPayload>('/api/auth/me'),
|
||||
};
|
||||
|
||||
// ── Hunts ────────────────────────────────────────────────────────────
|
||||
// -- Hunts --
|
||||
|
||||
export interface Hunt {
|
||||
id: string; name: string; description: string | null; status: string;
|
||||
@@ -82,7 +82,7 @@ export const hunts = {
|
||||
delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),
|
||||
};
|
||||
|
||||
// ── Datasets ─────────────────────────────────────────────────────────
|
||||
// -- Datasets --
|
||||
|
||||
export interface DatasetSummary {
|
||||
id: string; name: string; filename: string; source_tool: string | null;
|
||||
@@ -92,6 +92,8 @@ export interface DatasetSummary {
|
||||
file_size_bytes: number; encoding: string | null; delimiter: string | null;
|
||||
time_range_start: string | null; time_range_end: string | null;
|
||||
hunt_id: string | null; created_at: string;
|
||||
processing_status?: string; artifact_type?: string | null;
|
||||
error_message?: string | null; file_path?: string | null;
|
||||
}
|
||||
|
||||
export interface UploadResult {
|
||||
@@ -155,7 +157,7 @@ export const datasets = {
|
||||
delete: (id: string) => api(`/api/datasets/${id}`, { method: 'DELETE' }),
|
||||
};
|
||||
|
||||
// ── Agent ────────────────────────────────────────────────────────────
|
||||
// -- Agent --
|
||||
|
||||
export interface AssistRequest {
|
||||
query: string;
|
||||
@@ -198,7 +200,7 @@ export const agent = {
|
||||
},
|
||||
};
|
||||
|
||||
// ── Annotations ──────────────────────────────────────────────────────
|
||||
// -- Annotations --
|
||||
|
||||
export interface AnnotationData {
|
||||
id: string; row_id: number | null; dataset_id: string | null;
|
||||
@@ -224,7 +226,7 @@ export const annotations = {
|
||||
delete: (id: string) => api(`/api/annotations/${id}`, { method: 'DELETE' }),
|
||||
};
|
||||
|
||||
// ── Hypotheses ───────────────────────────────────────────────────────
|
||||
// -- Hypotheses --
|
||||
|
||||
export interface HypothesisData {
|
||||
id: string; hunt_id: string | null; title: string; description: string | null;
|
||||
@@ -249,7 +251,7 @@ export const hypotheses = {
|
||||
delete: (id: string) => api(`/api/hypotheses/${id}`, { method: 'DELETE' }),
|
||||
};
|
||||
|
||||
// ── Enrichment ───────────────────────────────────────────────────────
|
||||
// -- Enrichment --
|
||||
|
||||
export interface EnrichmentResult {
|
||||
ioc_value: string; ioc_type: string; source: string; verdict: string;
|
||||
@@ -274,7 +276,7 @@ export const enrichment = {
|
||||
status: () => api<Record<string, any>>('/api/enrichment/status'),
|
||||
};
|
||||
|
||||
// ── Correlation ──────────────────────────────────────────────────────
|
||||
// -- Correlation --
|
||||
|
||||
export interface CorrelationResult {
|
||||
hunt_ids: string[]; summary: string; total_correlations: number;
|
||||
@@ -292,7 +294,7 @@ export const correlation = {
|
||||
api<{ ioc_value: string; occurrences: any[]; total: number }>(`/api/correlation/ioc/${encodeURIComponent(ioc_value)}`),
|
||||
};
|
||||
|
||||
// ── Reports ──────────────────────────────────────────────────────────
|
||||
// -- Reports --
|
||||
|
||||
export const reports = {
|
||||
json: (huntId: string) =>
|
||||
@@ -305,13 +307,13 @@ export const reports = {
|
||||
api<Record<string, any>>(`/api/reports/hunt/${huntId}/summary`),
|
||||
};
|
||||
|
||||
// ── Root / misc ──────────────────────────────────────────────────────
|
||||
// -- Root / misc --
|
||||
|
||||
export const misc = {
|
||||
root: () => api<{ name: string; version: string; status: string }>('/'),
|
||||
};
|
||||
|
||||
// ── AUP Keywords ─────────────────────────────────────────────────────
|
||||
// -- AUP Keywords --
|
||||
|
||||
export interface KeywordOut {
|
||||
id: number; theme_id: string; value: string; is_regex: boolean; created_at: string;
|
||||
@@ -368,3 +370,216 @@ export const keywords = {
|
||||
quickScan: (datasetId: string) =>
|
||||
api<ScanResponse>(`/api/keywords/scan/quick?dataset_id=${encodeURIComponent(datasetId)}`),
|
||||
};
|
||||
|
||||
|
||||
// -- Analysis (Phase 2+) --
|
||||
|
||||
export interface TriageResultData {
|
||||
id: string; dataset_id: string; row_start: number; row_end: number;
|
||||
risk_score: number; verdict: string;
|
||||
findings: any[] | null; suspicious_indicators: any[] | null;
|
||||
mitre_techniques: any[] | null;
|
||||
model_used: string | null; node_used: string | null;
|
||||
}
|
||||
|
||||
export interface HostProfileData {
|
||||
id: string; hunt_id: string; hostname: string; fqdn: string | null;
|
||||
risk_score: number; risk_level: string;
|
||||
artifact_summary: Record<string, any> | null;
|
||||
timeline_summary: string | null;
|
||||
suspicious_findings: any[] | null;
|
||||
mitre_techniques: any[] | null;
|
||||
llm_analysis: string | null;
|
||||
model_used: string | null;
|
||||
}
|
||||
|
||||
export interface HuntReportData {
|
||||
id: string; hunt_id: string; status: string;
|
||||
exec_summary: string | null; full_report: string | null;
|
||||
findings: any[] | null; recommendations: any[] | null;
|
||||
mitre_mapping: Record<string, any> | null;
|
||||
ioc_table: any[] | null; host_risk_summary: any[] | null;
|
||||
models_used: any[] | null; generation_time_ms: number | null;
|
||||
}
|
||||
|
||||
export interface AnomalyResultData {
|
||||
id: string; dataset_id: string; row_id: number | null;
|
||||
anomaly_score: number; distance_from_centroid: number | null;
|
||||
cluster_id: number | null; is_outlier: boolean;
|
||||
explanation: string | null;
|
||||
}
|
||||
|
||||
export interface HostGroupData {
|
||||
hostname: string;
|
||||
dataset_count: number;
|
||||
total_rows: number;
|
||||
artifact_types: string[];
|
||||
first_seen: string | null;
|
||||
last_seen: string | null;
|
||||
risk_score: number | null;
|
||||
}
|
||||
|
||||
// -- Job queue types (Phase 10) --
|
||||
|
||||
export interface JobData {
|
||||
id: string;
|
||||
job_type: string;
|
||||
status: 'queued' | 'running' | 'completed' | 'failed' | 'cancelled';
|
||||
progress: number;
|
||||
message: string;
|
||||
error: string | null;
|
||||
created_at: number;
|
||||
started_at: number | null;
|
||||
completed_at: number | null;
|
||||
elapsed_ms: number;
|
||||
params: Record<string, any>;
|
||||
}
|
||||
|
||||
export interface JobStats {
|
||||
total: number;
|
||||
queued: number;
|
||||
by_status: Record<string, number>;
|
||||
workers: number;
|
||||
active_workers: number;
|
||||
}
|
||||
|
||||
export interface LBNodeStatus {
|
||||
healthy: boolean;
|
||||
active_jobs: number;
|
||||
total_completed: number;
|
||||
total_errors: number;
|
||||
avg_latency_ms: number;
|
||||
last_check: number;
|
||||
}
|
||||
|
||||
export const analysis = {
|
||||
// Triage
|
||||
triageResults: (datasetId: string, minRisk = 0) =>
|
||||
api<TriageResultData[]>(`/api/analysis/triage/${datasetId}?min_risk=${minRisk}`),
|
||||
triggerTriage: (datasetId: string) =>
|
||||
api<{ status: string; dataset_id: string }>(`/api/analysis/triage/${datasetId}`, { method: 'POST' }),
|
||||
|
||||
// Host profiles
|
||||
hostProfiles: (huntId: string, minRisk = 0) =>
|
||||
api<HostProfileData[]>(`/api/analysis/profiles/${huntId}?min_risk=${minRisk}`),
|
||||
triggerAllProfiles: (huntId: string) =>
|
||||
api<{ status: string; hunt_id: string }>(`/api/analysis/profiles/${huntId}`, { method: 'POST' }),
|
||||
triggerHostProfile: (huntId: string, hostname: string) =>
|
||||
api<{ status: string }>(`/api/analysis/profiles/${huntId}/${encodeURIComponent(hostname)}`, { method: 'POST' }),
|
||||
|
||||
// Reports
|
||||
listReports: (huntId: string) =>
|
||||
api<HuntReportData[]>(`/api/analysis/reports/${huntId}`),
|
||||
getReport: (huntId: string, reportId: string) =>
|
||||
api<HuntReportData>(`/api/analysis/reports/${huntId}/${reportId}`),
|
||||
generateReport: (huntId: string) =>
|
||||
api<{ status: string; hunt_id: string }>(`/api/analysis/reports/${huntId}/generate`, { method: 'POST' }),
|
||||
|
||||
// Anomaly detection
|
||||
anomalies: (datasetId: string, outliersOnly = false) =>
|
||||
api<AnomalyResultData[]>(`/api/analysis/anomalies/${datasetId}?outliers_only=${outliersOnly}`),
|
||||
triggerAnomalyDetection: (datasetId: string, k = 3, threshold = 0.35) =>
|
||||
api<{ status: string; dataset_id: string }>(
|
||||
`/api/analysis/anomalies/${datasetId}?k=${k}&threshold=${threshold}`, { method: 'POST' },
|
||||
),
|
||||
|
||||
// IOC extraction
|
||||
extractIocs: (datasetId: string) =>
|
||||
api<{ dataset_id: string; iocs: Record<string, string[]>; total: number }>(
|
||||
`/api/analysis/iocs/${datasetId}`,
|
||||
),
|
||||
|
||||
// Host grouping
|
||||
hostGroups: (huntId: string) =>
|
||||
api<{ hunt_id: string; hosts: HostGroupData[] }>(
|
||||
`/api/analysis/hosts/${huntId}`,
|
||||
),
|
||||
|
||||
// Data query (Phase 9) - SSE streaming
|
||||
queryStream: async (datasetId: string, question: string, mode: string = 'quick'): Promise<Response> => {
|
||||
const headers: Record<string, string> = { 'Content-Type': 'application/json' };
|
||||
if (authToken) headers['Authorization'] = `Bearer ${authToken}`;
|
||||
return fetch(`${BASE}/api/analysis/query/${datasetId}`, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({ question, mode }),
|
||||
});
|
||||
},
|
||||
|
||||
// Data query (Phase 9) - sync
|
||||
querySync: (datasetId: string, question: string, mode: string = 'quick') =>
|
||||
api<{ dataset_id: string; question: string; answer: string; mode: string }>(
|
||||
`/api/analysis/query/${datasetId}/sync`, {
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ question, mode }),
|
||||
},
|
||||
),
|
||||
|
||||
// Job queue (Phase 10)
|
||||
listJobs: (status?: string, jobType?: string, limit = 50) => {
|
||||
const q = new URLSearchParams();
|
||||
if (status) q.set('status', status);
|
||||
if (jobType) q.set('job_type', jobType);
|
||||
q.set('limit', String(limit));
|
||||
return api<{ jobs: JobData[]; stats: JobStats }>(`/api/analysis/jobs?${q}`);
|
||||
},
|
||||
getJob: (jobId: string) =>
|
||||
api<JobData>(`/api/analysis/jobs/${jobId}`),
|
||||
cancelJob: (jobId: string) =>
|
||||
api<{ status: string; job_id: string }>(`/api/analysis/jobs/${jobId}`, { method: 'DELETE' }),
|
||||
submitJob: (jobType: string, params: Record<string, any> = {}) =>
|
||||
api<{ job_id: string; status: string; job_type: string }>(
|
||||
`/api/analysis/jobs/submit/${jobType}`, {
|
||||
method: 'POST',
|
||||
body: JSON.stringify(params),
|
||||
},
|
||||
),
|
||||
|
||||
// Load balancer (Phase 10)
|
||||
lbStatus: () =>
|
||||
api<Record<string, LBNodeStatus>>('/api/analysis/lb/status'),
|
||||
lbCheck: () =>
|
||||
api<Record<string, LBNodeStatus>>('/api/analysis/lb/check', { method: 'POST' }),
|
||||
};
|
||||
|
||||
// -- Network Topology --
|
||||
|
||||
export interface InventoryHost {
|
||||
id: string;
|
||||
hostname: string;
|
||||
fqdn: string;
|
||||
client_id: string;
|
||||
ips: string[];
|
||||
os: string;
|
||||
users: string[];
|
||||
datasets: string[];
|
||||
row_count: number;
|
||||
}
|
||||
|
||||
export interface InventoryConnection {
|
||||
source: string;
|
||||
target: string;
|
||||
target_ip: string;
|
||||
port: string;
|
||||
count: number;
|
||||
}
|
||||
|
||||
export interface InventoryStats {
|
||||
total_hosts: number;
|
||||
total_datasets_scanned: number;
|
||||
datasets_with_hosts: number;
|
||||
total_rows_scanned: number;
|
||||
hosts_with_ips: number;
|
||||
hosts_with_users: number;
|
||||
}
|
||||
|
||||
export interface HostInventory {
|
||||
hosts: InventoryHost[];
|
||||
connections: InventoryConnection[];
|
||||
stats: InventoryStats;
|
||||
}
|
||||
|
||||
export const network = {
|
||||
hostInventory: (huntId: string) =>
|
||||
api<HostInventory>(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}`),
|
||||
};
|
||||
818
frontend/src/components/AnalysisDashboard.tsx
Normal file
818
frontend/src/components/AnalysisDashboard.tsx
Normal file
@@ -0,0 +1,818 @@
|
||||
/**
|
||||
* AnalysisDashboard -- 6-tab view covering the full AI analysis pipeline:
|
||||
* 1. Triage results
|
||||
* 2. Host profiles
|
||||
* 3. Reports
|
||||
* 4. Anomalies
|
||||
* 5. Ask Data (natural language query with SSE streaming) -- Phase 9
|
||||
* 6. Jobs & Load Balancer status -- Phase 10
|
||||
*/
|
||||
|
||||
import React, { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import {
|
||||
Box, Typography, Tabs, Tab, Paper, Button, Chip, Stack, CircularProgress,
|
||||
Table, TableBody, TableCell, TableContainer, TableHead, TableRow,
|
||||
Accordion, AccordionSummary, AccordionDetails, Alert, Select, MenuItem,
|
||||
FormControl, InputLabel, LinearProgress, Tooltip, IconButton, Divider,
|
||||
Card, CardContent, CardActions, Grid, TextField, ToggleButton,
|
||||
ToggleButtonGroup,
|
||||
} from '@mui/material';
|
||||
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
|
||||
import PlayArrowIcon from '@mui/icons-material/PlayArrow';
|
||||
import RefreshIcon from '@mui/icons-material/Refresh';
|
||||
import AssessmentIcon from '@mui/icons-material/Assessment';
|
||||
import SecurityIcon from '@mui/icons-material/Security';
|
||||
import PersonIcon from '@mui/icons-material/Person';
|
||||
import WarningAmberIcon from '@mui/icons-material/WarningAmber';
|
||||
import ShieldIcon from '@mui/icons-material/Shield';
|
||||
import BubbleChartIcon from '@mui/icons-material/BubbleChart';
|
||||
import QuestionAnswerIcon from '@mui/icons-material/QuestionAnswer';
|
||||
import WorkIcon from '@mui/icons-material/Work';
|
||||
import SendIcon from '@mui/icons-material/Send';
|
||||
import StopIcon from '@mui/icons-material/Stop';
|
||||
import DeleteIcon from '@mui/icons-material/Delete';
|
||||
import CheckCircleIcon from '@mui/icons-material/CheckCircle';
|
||||
import ErrorIcon from '@mui/icons-material/Error';
|
||||
import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty';
|
||||
import CancelIcon from '@mui/icons-material/Cancel';
|
||||
import { useSnackbar } from 'notistack';
|
||||
import {
|
||||
analysis, hunts, datasets,
|
||||
type Hunt, type DatasetSummary,
|
||||
type TriageResultData, type HostProfileData, type HuntReportData,
|
||||
type AnomalyResultData, type JobData, type JobStats, type LBNodeStatus,
|
||||
} from '../api/client';
|
||||
|
||||
/* helpers */
|
||||
|
||||
function riskColor(score: number): 'error' | 'warning' | 'info' | 'success' | 'default' {
|
||||
if (score >= 8) return 'error';
|
||||
if (score >= 5) return 'warning';
|
||||
if (score >= 2) return 'info';
|
||||
return 'success';
|
||||
}
|
||||
|
||||
function riskLabel(level: string): 'error' | 'warning' | 'info' | 'success' | 'default' {
|
||||
if (level === 'critical' || level === 'high') return 'error';
|
||||
if (level === 'medium') return 'warning';
|
||||
if (level === 'low') return 'success';
|
||||
return 'default';
|
||||
}
|
||||
|
||||
function fmtMs(ms: number): string {
|
||||
if (ms < 1000) return `${ms}ms`;
|
||||
return `${(ms / 1000).toFixed(1)}s`;
|
||||
}
|
||||
|
||||
function fmtTime(ts: number): string {
|
||||
if (!ts) return '--';
|
||||
return new Date(ts * 1000).toLocaleTimeString();
|
||||
}
|
||||
|
||||
const statusIcon = (s: string) => {
|
||||
switch (s) {
|
||||
case 'completed': return <CheckCircleIcon color="success" sx={{ fontSize: 18 }} />;
|
||||
case 'failed': return <ErrorIcon color="error" sx={{ fontSize: 18 }} />;
|
||||
case 'running': return <CircularProgress size={16} />;
|
||||
case 'queued': return <HourglassEmptyIcon color="action" sx={{ fontSize: 18 }} />;
|
||||
case 'cancelled': return <CancelIcon color="disabled" sx={{ fontSize: 18 }} />;
|
||||
default: return null;
|
||||
}
|
||||
};
|
||||
|
||||
/* TabPanel */
|
||||
|
||||
function TabPanel({ children, value, index }: { children: React.ReactNode; value: number; index: number }) {
|
||||
return value === index ? <Box sx={{ pt: 2 }}>{children}</Box> : null;
|
||||
}
|
||||
|
||||
/* Main component */
|
||||
|
||||
export default function AnalysisDashboard() {
|
||||
const { enqueueSnackbar } = useSnackbar();
|
||||
const [tab, setTab] = useState(0);
|
||||
|
||||
// Selectors
|
||||
const [huntList, setHuntList] = useState<Hunt[]>([]);
|
||||
const [dsList, setDsList] = useState<DatasetSummary[]>([]);
|
||||
const [huntId, setHuntId] = useState('');
|
||||
const [dsId, setDsId] = useState('');
|
||||
|
||||
// Data tabs 0-3
|
||||
const [triageResults, setTriageResults] = useState<TriageResultData[]>([]);
|
||||
const [profiles, setProfiles] = useState<HostProfileData[]>([]);
|
||||
const [reports, setReports] = useState<HuntReportData[]>([]);
|
||||
const [anomalies, setAnomalies] = useState<AnomalyResultData[]>([]);
|
||||
|
||||
// Loading states
|
||||
const [loadingTriage, setLoadingTriage] = useState(false);
|
||||
const [loadingProfiles, setLoadingProfiles] = useState(false);
|
||||
const [loadingReports, setLoadingReports] = useState(false);
|
||||
const [loadingAnomalies, setLoadingAnomalies] = useState(false);
|
||||
const [triggering, setTriggering] = useState(false);
|
||||
|
||||
// Phase 9: Ask Data
|
||||
const [queryText, setQueryText] = useState('');
|
||||
const [queryMode, setQueryMode] = useState<string>('quick');
|
||||
const [queryAnswer, setQueryAnswer] = useState('');
|
||||
const [queryStreaming, setQueryStreaming] = useState(false);
|
||||
const [queryMeta, setQueryMeta] = useState<Record<string, any> | null>(null);
|
||||
const [queryDone, setQueryDone] = useState<Record<string, any> | null>(null);
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
const answerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Phase 10: Jobs
|
||||
const [jobs, setJobs] = useState<JobData[]>([]);
|
||||
const [jobStats, setJobStats] = useState<JobStats | null>(null);
|
||||
const [lbStatus, setLbStatus] = useState<Record<string, LBNodeStatus> | null>(null);
|
||||
const [loadingJobs, setLoadingJobs] = useState(false);
|
||||
|
||||
// Load hunts and datasets
|
||||
useEffect(() => {
|
||||
hunts.list(0, 200).then(r => setHuntList(r.hunts)).catch(() => {});
|
||||
datasets.list(0, 200).then(r => setDsList(r.datasets)).catch(() => {});
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (huntList.length > 0 && !huntId) setHuntId(huntList[0].id);
|
||||
}, [huntList, huntId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (dsList.length > 0 && !dsId) setDsId(dsList[0].id);
|
||||
}, [dsList, dsId]);
|
||||
|
||||
/* Fetch triage results */
|
||||
const fetchTriage = useCallback(async () => {
|
||||
if (!dsId) return;
|
||||
setLoadingTriage(true);
|
||||
try {
|
||||
const data = await analysis.triageResults(dsId);
|
||||
setTriageResults(data);
|
||||
} catch (e: any) {
|
||||
enqueueSnackbar(`Triage fetch failed: ${e.message}`, { variant: 'error' });
|
||||
} finally { setLoadingTriage(false); }
|
||||
}, [dsId, enqueueSnackbar]);
|
||||
|
||||
const fetchProfiles = useCallback(async () => {
|
||||
if (!huntId) return;
|
||||
setLoadingProfiles(true);
|
||||
try {
|
||||
const data = await analysis.hostProfiles(huntId);
|
||||
setProfiles(data);
|
||||
} catch (e: any) {
|
||||
enqueueSnackbar(`Profiles fetch failed: ${e.message}`, { variant: 'error' });
|
||||
} finally { setLoadingProfiles(false); }
|
||||
}, [huntId, enqueueSnackbar]);
|
||||
|
||||
const fetchReports = useCallback(async () => {
|
||||
if (!huntId) return;
|
||||
setLoadingReports(true);
|
||||
try {
|
||||
const data = await analysis.listReports(huntId);
|
||||
setReports(data);
|
||||
} catch (e: any) {
|
||||
enqueueSnackbar(`Reports fetch failed: ${e.message}`, { variant: 'error' });
|
||||
} finally { setLoadingReports(false); }
|
||||
}, [huntId, enqueueSnackbar]);
|
||||
|
||||
const fetchAnomalies = useCallback(async () => {
|
||||
if (!dsId) return;
|
||||
setLoadingAnomalies(true);
|
||||
try {
|
||||
const data = await analysis.anomalies(dsId);
|
||||
setAnomalies(data);
|
||||
} catch (e: any) {
|
||||
enqueueSnackbar('Anomaly fetch failed: ' + e.message, { variant: 'error' });
|
||||
} finally { setLoadingAnomalies(false); }
|
||||
}, [dsId, enqueueSnackbar]);
|
||||
|
||||
const fetchJobs = useCallback(async () => {
|
||||
setLoadingJobs(true);
|
||||
try {
|
||||
const data = await analysis.listJobs();
|
||||
setJobs(data.jobs);
|
||||
setJobStats(data.stats);
|
||||
} catch (e: any) {
|
||||
enqueueSnackbar('Jobs fetch failed: ' + e.message, { variant: 'error' });
|
||||
} finally { setLoadingJobs(false); }
|
||||
}, [enqueueSnackbar]);
|
||||
|
||||
const fetchLbStatus = useCallback(async () => {
|
||||
try {
|
||||
const data = await analysis.lbStatus();
|
||||
setLbStatus(data);
|
||||
} catch {}
|
||||
}, []);
|
||||
|
||||
// Load data when selectors change
|
||||
useEffect(() => { if (dsId) fetchTriage(); }, [dsId, fetchTriage]);
|
||||
useEffect(() => { if (huntId) { fetchProfiles(); fetchReports(); } }, [huntId, fetchProfiles, fetchReports]);
|
||||
|
||||
// Auto-refresh jobs when on jobs tab
|
||||
useEffect(() => {
|
||||
if (tab === 5) {
|
||||
fetchJobs();
|
||||
fetchLbStatus();
|
||||
const iv = setInterval(() => { fetchJobs(); fetchLbStatus(); }, 5000);
|
||||
return () => clearInterval(iv);
|
||||
}
|
||||
}, [tab, fetchJobs, fetchLbStatus]);
|
||||
|
||||
/* Trigger actions */
|
||||
const doTriggerTriage = useCallback(async () => {
|
||||
if (!dsId) return;
|
||||
setTriggering(true);
|
||||
try {
|
||||
await analysis.triggerTriage(dsId);
|
||||
enqueueSnackbar('Triage started', { variant: 'info' });
|
||||
setTimeout(fetchTriage, 5000);
|
||||
} catch (e: any) {
|
||||
enqueueSnackbar(`Triage trigger failed: ${e.message}`, { variant: 'error' });
|
||||
} finally { setTriggering(false); }
|
||||
}, [dsId, enqueueSnackbar, fetchTriage]);
|
||||
|
||||
const doTriggerProfiles = useCallback(async () => {
|
||||
if (!huntId) return;
|
||||
setTriggering(true);
|
||||
try {
|
||||
await analysis.triggerAllProfiles(huntId);
|
||||
enqueueSnackbar('Host profiling started', { variant: 'info' });
|
||||
setTimeout(fetchProfiles, 10000);
|
||||
} catch (e: any) {
|
||||
enqueueSnackbar(`Profile trigger failed: ${e.message}`, { variant: 'error' });
|
||||
} finally { setTriggering(false); }
|
||||
}, [huntId, enqueueSnackbar, fetchProfiles]);
|
||||
|
||||
const doGenerateReport = useCallback(async () => {
|
||||
if (!huntId) return;
|
||||
setTriggering(true);
|
||||
try {
|
||||
await analysis.generateReport(huntId);
|
||||
enqueueSnackbar('Report generation started', { variant: 'info' });
|
||||
setTimeout(fetchReports, 15000);
|
||||
} catch (e: any) {
|
||||
enqueueSnackbar(`Report generation failed: ${e.message}`, { variant: 'error' });
|
||||
} finally { setTriggering(false); }
|
||||
}, [huntId, enqueueSnackbar, fetchReports]);
|
||||
|
||||
const doTriggerAnomalies = useCallback(async () => {
|
||||
if (!dsId) return;
|
||||
setTriggering(true);
|
||||
try {
|
||||
await analysis.triggerAnomalyDetection(dsId);
|
||||
enqueueSnackbar('Anomaly detection started', { variant: 'info' });
|
||||
setTimeout(fetchAnomalies, 20000);
|
||||
} catch (e: any) {
|
||||
enqueueSnackbar('Anomaly trigger failed: ' + e.message, { variant: 'error' });
|
||||
} finally { setTriggering(false); }
|
||||
}, [dsId, enqueueSnackbar, fetchAnomalies]);
|
||||
|
||||
/* Phase 9: Streaming data query */
|
||||
const doQuery = useCallback(async () => {
|
||||
if (!dsId || !queryText.trim()) return;
|
||||
setQueryStreaming(true);
|
||||
setQueryAnswer('');
|
||||
setQueryMeta(null);
|
||||
setQueryDone(null);
|
||||
|
||||
const controller = new AbortController();
|
||||
abortRef.current = controller;
|
||||
|
||||
try {
|
||||
const resp = await analysis.queryStream(dsId, queryText.trim(), queryMode);
|
||||
if (!resp.body) throw new Error('No response body');
|
||||
|
||||
const reader = resp.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buf = '';
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
buf += decoder.decode(value, { stream: true });
|
||||
|
||||
const lines = buf.split('\n');
|
||||
buf = lines.pop() || '';
|
||||
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith('data: ')) continue;
|
||||
try {
|
||||
const evt = JSON.parse(line.slice(6));
|
||||
switch (evt.type) {
|
||||
case 'token':
|
||||
setQueryAnswer(prev => prev + evt.content);
|
||||
if (answerRef.current) {
|
||||
answerRef.current.scrollTop = answerRef.current.scrollHeight;
|
||||
}
|
||||
break;
|
||||
case 'metadata':
|
||||
setQueryMeta(evt.dataset);
|
||||
break;
|
||||
case 'done':
|
||||
setQueryDone(evt);
|
||||
break;
|
||||
case 'error':
|
||||
enqueueSnackbar(`Query error: ${evt.message}`, { variant: 'error' });
|
||||
break;
|
||||
}
|
||||
} catch {}
|
||||
}
|
||||
}
|
||||
} catch (e: any) {
|
||||
if (e.name !== 'AbortError') {
|
||||
enqueueSnackbar('Query failed: ' + e.message, { variant: 'error' });
|
||||
}
|
||||
} finally {
|
||||
setQueryStreaming(false);
|
||||
abortRef.current = null;
|
||||
}
|
||||
}, [dsId, queryText, queryMode, enqueueSnackbar]);
|
||||
|
||||
const stopQuery = useCallback(() => {
|
||||
if (abortRef.current) {
|
||||
abortRef.current.abort();
|
||||
setQueryStreaming(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
/* Phase 10: Cancel job */
|
||||
const doCancelJob = useCallback(async (jobId: string) => {
|
||||
try {
|
||||
await analysis.cancelJob(jobId);
|
||||
enqueueSnackbar('Job cancelled', { variant: 'info' });
|
||||
fetchJobs();
|
||||
} catch (e: any) {
|
||||
enqueueSnackbar('Cancel failed: ' + e.message, { variant: 'error' });
|
||||
}
|
||||
}, [enqueueSnackbar, fetchJobs]);
|
||||
|
||||
return (
|
||||
<Box>
|
||||
<Stack direction="row" alignItems="center" spacing={2} sx={{ mb: 2 }}>
|
||||
<AssessmentIcon color="primary" sx={{ fontSize: 32 }} />
|
||||
<Typography variant="h5">AI Analysis</Typography>
|
||||
{triggering && <CircularProgress size={20} />}
|
||||
</Stack>
|
||||
|
||||
{/* Selectors */}
|
||||
<Paper sx={{ p: 2, mb: 2 }}>
|
||||
<Stack direction="row" spacing={2} flexWrap="wrap">
|
||||
<FormControl size="small" sx={{ minWidth: 260 }}>
|
||||
<InputLabel>Hunt</InputLabel>
|
||||
<Select label="Hunt" value={huntId} onChange={e => setHuntId(e.target.value)}>
|
||||
{huntList.map(h => (
|
||||
<MenuItem key={h.id} value={h.id}>{h.name}</MenuItem>
|
||||
))}
|
||||
</Select>
|
||||
</FormControl>
|
||||
<FormControl size="small" sx={{ minWidth: 260 }}>
|
||||
<InputLabel>Dataset</InputLabel>
|
||||
<Select label="Dataset" value={dsId} onChange={e => setDsId(e.target.value)}>
|
||||
{dsList.map(d => (
|
||||
<MenuItem key={d.id} value={d.id}>{d.name} ({d.row_count} rows)</MenuItem>
|
||||
))}
|
||||
</Select>
|
||||
</FormControl>
|
||||
</Stack>
|
||||
</Paper>
|
||||
|
||||
{/* Tabs */}
|
||||
<Tabs value={tab} onChange={(_, v) => setTab(v)} variant="scrollable" scrollButtons="auto" sx={{ mb: 1 }}>
|
||||
<Tab icon={<SecurityIcon />} iconPosition="start" label={`Triage (${triageResults.length})`} />
|
||||
<Tab icon={<PersonIcon />} iconPosition="start" label={`Host Profiles (${profiles.length})`} />
|
||||
<Tab icon={<AssessmentIcon />} iconPosition="start" label={`Reports (${reports.length})`} />
|
||||
<Tab icon={<BubbleChartIcon />} iconPosition="start" label={`Anomalies (${anomalies.filter(a => a.is_outlier).length})`} />
|
||||
<Tab icon={<QuestionAnswerIcon />} iconPosition="start" label="Ask Data" />
|
||||
<Tab icon={<WorkIcon />} iconPosition="start" label={`Jobs${jobStats ? ` (${jobStats.active_workers})` : ''}`} />
|
||||
</Tabs>
|
||||
|
||||
{/* Tab 0: Triage */}
|
||||
<TabPanel value={tab} index={0}>
|
||||
<Stack direction="row" spacing={1} sx={{ mb: 2 }}>
|
||||
<Button variant="contained" startIcon={<PlayArrowIcon />} onClick={doTriggerTriage}
|
||||
disabled={!dsId || triggering} size="small">Run Triage</Button>
|
||||
<Button variant="outlined" startIcon={<RefreshIcon />} onClick={fetchTriage}
|
||||
disabled={!dsId || loadingTriage} size="small">Refresh</Button>
|
||||
</Stack>
|
||||
{loadingTriage && <LinearProgress sx={{ mb: 1 }} />}
|
||||
{triageResults.length === 0 && !loadingTriage ? (
|
||||
<Alert severity="info">No triage results yet. Select a dataset and click "Run Triage".</Alert>
|
||||
) : (
|
||||
<TableContainer component={Paper} sx={{ maxHeight: 500 }}>
|
||||
<Table size="small" stickyHeader>
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableCell>Rows</TableCell><TableCell>Risk</TableCell><TableCell>Verdict</TableCell>
|
||||
<TableCell>Findings</TableCell><TableCell>MITRE</TableCell><TableCell>Model</TableCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{triageResults.map(tr => (
|
||||
<TableRow key={tr.id} hover>
|
||||
<TableCell>{tr.row_start}-{tr.row_end}</TableCell>
|
||||
<TableCell><Chip label={tr.risk_score.toFixed(1)} size="small" color={riskColor(tr.risk_score)} /></TableCell>
|
||||
<TableCell><Chip label={tr.verdict} size="small" variant="outlined" /></TableCell>
|
||||
<TableCell sx={{ maxWidth: 300, overflow: 'hidden', textOverflow: 'ellipsis' }}>
|
||||
{tr.findings?.join('; ') || ''}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Stack direction="row" spacing={0.5} flexWrap="wrap">
|
||||
{tr.mitre_techniques?.map((t: string, i: number) => (
|
||||
<Chip key={i} label={t} size="small" variant="outlined" color="warning" />
|
||||
))}
|
||||
</Stack>
|
||||
</TableCell>
|
||||
<TableCell><Typography variant="caption">{tr.model_used || ''}</Typography></TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
)}
|
||||
</TabPanel>
|
||||
|
||||
{/* Tab 1: Host Profiles */}
|
||||
<TabPanel value={tab} index={1}>
|
||||
<Stack direction="row" spacing={1} sx={{ mb: 2 }}>
|
||||
<Button variant="contained" startIcon={<PlayArrowIcon />} onClick={doTriggerProfiles}
|
||||
disabled={!huntId || triggering} size="small">Profile All Hosts</Button>
|
||||
<Button variant="outlined" startIcon={<RefreshIcon />} onClick={fetchProfiles}
|
||||
disabled={!huntId || loadingProfiles} size="small">Refresh</Button>
|
||||
</Stack>
|
||||
{loadingProfiles && <LinearProgress sx={{ mb: 1 }} />}
|
||||
{profiles.length === 0 && !loadingProfiles ? (
|
||||
<Alert severity="info">No host profiles yet. Select a hunt and click "Profile All Hosts".</Alert>
|
||||
) : (
|
||||
<Grid container spacing={2}>
|
||||
{profiles.map(hp => (
|
||||
<Grid size={{ xs: 12, md: 6, lg: 4 }} key={hp.id}>
|
||||
<Card variant="outlined" sx={{
|
||||
borderLeft: 4,
|
||||
borderLeftColor: hp.risk_level === 'critical' ? 'error.main'
|
||||
: hp.risk_level === 'high' ? 'error.light'
|
||||
: hp.risk_level === 'medium' ? 'warning.main' : 'success.main',
|
||||
}}>
|
||||
<CardContent>
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="center">
|
||||
<Typography variant="h6">{hp.hostname}</Typography>
|
||||
<Chip label={`${hp.risk_score.toFixed(1)} ${hp.risk_level}`}
|
||||
size="small" color={riskLabel(hp.risk_level)} />
|
||||
</Stack>
|
||||
{hp.fqdn && <Typography variant="caption" color="text.secondary">{hp.fqdn}</Typography>}
|
||||
<Divider sx={{ my: 1 }} />
|
||||
{hp.timeline_summary && (
|
||||
<Typography variant="body2" sx={{ mb: 1, whiteSpace: 'pre-wrap' }}>
|
||||
{hp.timeline_summary.slice(0, 300)}{hp.timeline_summary.length > 300 ? '...' : ''}
|
||||
</Typography>
|
||||
)}
|
||||
{hp.suspicious_findings && hp.suspicious_findings.length > 0 && (
|
||||
<Box sx={{ mb: 1 }}>
|
||||
<Typography variant="caption" color="warning.main">
|
||||
<WarningAmberIcon sx={{ fontSize: 14, mr: 0.5, verticalAlign: 'middle' }} />
|
||||
{hp.suspicious_findings.length} suspicious finding(s)
|
||||
</Typography>
|
||||
</Box>
|
||||
)}
|
||||
{hp.mitre_techniques && hp.mitre_techniques.length > 0 && (
|
||||
<Stack direction="row" spacing={0.5} flexWrap="wrap" sx={{ mb: 1 }}>
|
||||
{hp.mitre_techniques.map((t: string, i: number) => (
|
||||
<Chip key={i} label={t} size="small" variant="outlined" color="warning" />
|
||||
))}
|
||||
</Stack>
|
||||
)}
|
||||
</CardContent>
|
||||
<CardActions>
|
||||
<Typography variant="caption" color="text.secondary">Model: {hp.model_used || 'N/A'}</Typography>
|
||||
</CardActions>
|
||||
</Card>
|
||||
</Grid>
|
||||
))}
|
||||
</Grid>
|
||||
)}
|
||||
</TabPanel>
|
||||
|
||||
{/* Tab 2: Reports */}
|
||||
<TabPanel value={tab} index={2}>
|
||||
<Stack direction="row" spacing={1} sx={{ mb: 2 }}>
|
||||
<Button variant="contained" startIcon={<PlayArrowIcon />} onClick={doGenerateReport}
|
||||
disabled={!huntId || triggering} size="small">Generate Report</Button>
|
||||
<Button variant="outlined" startIcon={<RefreshIcon />} onClick={fetchReports}
|
||||
disabled={!huntId || loadingReports} size="small">Refresh</Button>
|
||||
</Stack>
|
||||
{loadingReports && <LinearProgress sx={{ mb: 1 }} />}
|
||||
{reports.length === 0 && !loadingReports ? (
|
||||
<Alert severity="info">No reports yet. Select a hunt and click "Generate Report".</Alert>
|
||||
) : (
|
||||
reports.map(rpt => (
|
||||
<Accordion key={rpt.id} defaultExpanded={reports.length === 1}>
|
||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||
<Stack direction="row" spacing={1} alignItems="center">
|
||||
<ShieldIcon color="primary" />
|
||||
<Typography>Report - {rpt.status}</Typography>
|
||||
{rpt.generation_time_ms && (
|
||||
<Chip label={fmtMs(rpt.generation_time_ms)} size="small" variant="outlined" />
|
||||
)}
|
||||
</Stack>
|
||||
</AccordionSummary>
|
||||
<AccordionDetails>
|
||||
{rpt.exec_summary && (
|
||||
<Box sx={{ mb: 2 }}>
|
||||
<Typography variant="subtitle2" color="primary">Executive Summary</Typography>
|
||||
<Typography variant="body2" sx={{ whiteSpace: 'pre-wrap' }}>{rpt.exec_summary}</Typography>
|
||||
</Box>
|
||||
)}
|
||||
{rpt.findings && rpt.findings.length > 0 && (
|
||||
<Box sx={{ mb: 2 }}>
|
||||
<Typography variant="subtitle2" color="warning.main">Findings</Typography>
|
||||
<ul style={{ margin: 0, paddingLeft: 20 }}>
|
||||
{rpt.findings.map((f: any, i: number) => (
|
||||
<li key={i}><Typography variant="body2">
|
||||
{typeof f === 'string' ? f : JSON.stringify(f)}
|
||||
</Typography></li>
|
||||
))}
|
||||
</ul>
|
||||
</Box>
|
||||
)}
|
||||
{rpt.recommendations && rpt.recommendations.length > 0 && (
|
||||
<Box sx={{ mb: 2 }}>
|
||||
<Typography variant="subtitle2" color="success.main">Recommendations</Typography>
|
||||
<ul style={{ margin: 0, paddingLeft: 20 }}>
|
||||
{rpt.recommendations.map((r: any, i: number) => (
|
||||
<li key={i}><Typography variant="body2">
|
||||
{typeof r === 'string' ? r : JSON.stringify(r)}
|
||||
</Typography></li>
|
||||
))}
|
||||
</ul>
|
||||
</Box>
|
||||
)}
|
||||
{rpt.ioc_table && rpt.ioc_table.length > 0 && (
|
||||
<Box sx={{ mb: 2 }}>
|
||||
<Typography variant="subtitle2">IOC Table</Typography>
|
||||
<TableContainer component={Paper} variant="outlined" sx={{ maxHeight: 300 }}>
|
||||
<Table size="small">
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
{Object.keys(rpt.ioc_table[0]).map(k => (
|
||||
<TableCell key={k}>{k}</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{rpt.ioc_table.map((row: any, i: number) => (
|
||||
<TableRow key={i}>
|
||||
{Object.values(row).map((v: any, j: number) => (
|
||||
<TableCell key={j}><Typography variant="caption">{String(v)}</Typography></TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
</Box>
|
||||
)}
|
||||
{rpt.full_report && (
|
||||
<Accordion>
|
||||
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||
<Typography variant="body2">Full Report</Typography>
|
||||
</AccordionSummary>
|
||||
<AccordionDetails>
|
||||
<Typography variant="body2" sx={{ whiteSpace: 'pre-wrap', fontFamily: 'monospace', fontSize: 12 }}>
|
||||
{rpt.full_report}
|
||||
</Typography>
|
||||
</AccordionDetails>
|
||||
</Accordion>
|
||||
)}
|
||||
<Stack direction="row" spacing={1} sx={{ mt: 1 }}>
|
||||
{rpt.models_used?.map((m: string, i: number) => (
|
||||
<Chip key={i} label={m} size="small" variant="outlined" />
|
||||
))}
|
||||
</Stack>
|
||||
</AccordionDetails>
|
||||
</Accordion>
|
||||
))
|
||||
)}
|
||||
</TabPanel>
|
||||
|
||||
{/* Tab 3: Anomalies */}
|
||||
<TabPanel value={tab} index={3}>
|
||||
<Stack direction="row" spacing={1} sx={{ mb: 2 }}>
|
||||
<Button variant="contained" startIcon={<PlayArrowIcon />} onClick={doTriggerAnomalies}
|
||||
disabled={!dsId || triggering} size="small">Detect Anomalies</Button>
|
||||
<Button variant="outlined" startIcon={<RefreshIcon />} onClick={fetchAnomalies}
|
||||
disabled={!dsId || loadingAnomalies} size="small">Refresh</Button>
|
||||
</Stack>
|
||||
{loadingAnomalies && <LinearProgress sx={{ mb: 1 }} />}
|
||||
{anomalies.length === 0 && !loadingAnomalies ? (
|
||||
<Alert severity="info">No anomaly results yet. Select a dataset and click "Detect Anomalies".</Alert>
|
||||
) : (
|
||||
<>
|
||||
<Alert severity="warning" sx={{ mb: 1 }}>
|
||||
{anomalies.filter(a => a.is_outlier).length} outlier(s) detected out of {anomalies.length} rows
|
||||
</Alert>
|
||||
<TableContainer component={Paper} sx={{ maxHeight: 500 }}>
|
||||
<Table size="small" stickyHeader>
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableCell>Row</TableCell><TableCell>Score</TableCell>
|
||||
<TableCell>Distance</TableCell><TableCell>Cluster</TableCell><TableCell>Outlier</TableCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{anomalies.filter(a => a.is_outlier).concat(anomalies.filter(a => !a.is_outlier).slice(0, 20)).map((a, i) => (
|
||||
<TableRow key={a.id || i} hover sx={a.is_outlier ? { bgcolor: 'rgba(244,63,94,0.08)' } : {}}>
|
||||
<TableCell>{a.row_id ?? ''}</TableCell>
|
||||
<TableCell>
|
||||
<Chip label={a.anomaly_score.toFixed(4)} size="small"
|
||||
color={a.anomaly_score > 0.5 ? 'error' : a.anomaly_score > 0.35 ? 'warning' : 'success'} />
|
||||
</TableCell>
|
||||
<TableCell>{a.distance_from_centroid?.toFixed(4) ?? ''}</TableCell>
|
||||
<TableCell><Chip label={`C${a.cluster_id}`} size="small" variant="outlined" /></TableCell>
|
||||
<TableCell>
|
||||
{a.is_outlier
|
||||
? <Chip label="OUTLIER" size="small" color="error" />
|
||||
: <Chip label="Normal" size="small" color="success" variant="outlined" />}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
</>
|
||||
)}
|
||||
</TabPanel>
|
||||
|
||||
{/* Tab 4: Ask Data (Phase 9) */}
|
||||
<TabPanel value={tab} index={4}>
|
||||
<Paper sx={{ p: 2, mb: 2 }}>
|
||||
<Typography variant="subtitle2" sx={{ mb: 1 }}>
|
||||
Ask a question about the selected dataset in plain English
|
||||
</Typography>
|
||||
<Stack direction="row" spacing={1} alignItems="flex-end">
|
||||
<TextField
|
||||
fullWidth size="small" multiline maxRows={3}
|
||||
placeholder="e.g., Are there any suspicious processes running at unusual hours?"
|
||||
value={queryText}
|
||||
onChange={e => setQueryText(e.target.value)}
|
||||
onKeyDown={e => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); doQuery(); } }}
|
||||
disabled={queryStreaming}
|
||||
/>
|
||||
<ToggleButtonGroup
|
||||
value={queryMode} exclusive size="small"
|
||||
onChange={(_, v) => { if (v) setQueryMode(v); }}
|
||||
>
|
||||
<ToggleButton value="quick">
|
||||
<Tooltip title="Fast (Roadrunner)"><Typography variant="caption">Quick</Typography></Tooltip>
|
||||
</ToggleButton>
|
||||
<ToggleButton value="deep">
|
||||
<Tooltip title="Deep (Wile 70B)"><Typography variant="caption">Deep</Typography></Tooltip>
|
||||
</ToggleButton>
|
||||
</ToggleButtonGroup>
|
||||
{queryStreaming ? (
|
||||
<IconButton color="error" onClick={stopQuery}><StopIcon /></IconButton>
|
||||
) : (
|
||||
<IconButton color="primary" onClick={doQuery} disabled={!dsId || !queryText.trim()}>
|
||||
<SendIcon />
|
||||
</IconButton>
|
||||
)}
|
||||
</Stack>
|
||||
</Paper>
|
||||
|
||||
{queryMeta && (
|
||||
<Alert severity="info" sx={{ mb: 1 }}>
|
||||
Querying <strong>{queryMeta.name}</strong> ({queryMeta.row_count} rows,{' '}
|
||||
{queryMeta.sample_rows_shown} sampled) | Mode: {queryMode}
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{queryStreaming && <LinearProgress sx={{ mb: 1 }} />}
|
||||
|
||||
{queryAnswer && (
|
||||
<Paper
|
||||
ref={answerRef}
|
||||
sx={{
|
||||
p: 2, maxHeight: 500, overflow: 'auto',
|
||||
bgcolor: 'grey.900', color: 'grey.100',
|
||||
fontFamily: 'monospace', fontSize: 13, whiteSpace: 'pre-wrap',
|
||||
borderRadius: 2,
|
||||
}}
|
||||
>
|
||||
{queryAnswer}
|
||||
</Paper>
|
||||
)}
|
||||
|
||||
{queryDone && (
|
||||
<Stack direction="row" spacing={1} sx={{ mt: 1 }}>
|
||||
<Chip label={`${queryDone.tokens} tokens`} size="small" variant="outlined" />
|
||||
<Chip label={fmtMs(queryDone.elapsed_ms)} size="small" variant="outlined" />
|
||||
<Chip label={queryDone.model} size="small" />
|
||||
<Chip label={queryDone.node} size="small" color={queryDone.node === 'wile' ? 'secondary' : 'primary'} />
|
||||
</Stack>
|
||||
)}
|
||||
</TabPanel>
|
||||
|
||||
{/* Tab 5: Jobs & Load Balancer (Phase 10) */}
|
||||
<TabPanel value={tab} index={5}>
|
||||
{/* LB Status Cards */}
|
||||
{lbStatus && (
|
||||
<Grid container spacing={2} sx={{ mb: 2 }}>
|
||||
{Object.entries(lbStatus).map(([name, st]) => (
|
||||
<Grid size={{ xs: 12, sm: 6 }} key={name}>
|
||||
<Card variant="outlined" sx={{
|
||||
borderLeft: 4,
|
||||
borderLeftColor: st.healthy ? 'success.main' : 'error.main',
|
||||
}}>
|
||||
<CardContent sx={{ py: 1.5, '&:last-child': { pb: 1.5 } }}>
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="center">
|
||||
<Typography variant="h6" sx={{ textTransform: 'capitalize' }}>{name}</Typography>
|
||||
<Chip label={st.healthy ? 'HEALTHY' : 'DOWN'} size="small"
|
||||
color={st.healthy ? 'success' : 'error'} />
|
||||
</Stack>
|
||||
<Stack direction="row" spacing={2} sx={{ mt: 1 }}>
|
||||
<Typography variant="body2">Active: <strong>{st.active_jobs}</strong></Typography>
|
||||
<Typography variant="body2">Done: <strong>{st.total_completed}</strong></Typography>
|
||||
<Typography variant="body2">Errors: <strong>{st.total_errors}</strong></Typography>
|
||||
<Typography variant="body2">Avg: <strong>{st.avg_latency_ms.toFixed(0)}ms</strong></Typography>
|
||||
</Stack>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</Grid>
|
||||
))}
|
||||
</Grid>
|
||||
)}
|
||||
|
||||
{/* Job queue stats */}
|
||||
{jobStats && (
|
||||
<Stack direction="row" spacing={1} sx={{ mb: 2 }}>
|
||||
<Chip label={`Workers: ${jobStats.active_workers}/${jobStats.workers}`} size="small" />
|
||||
<Chip label={`Queued: ${jobStats.queued}`} size="small" color="info" />
|
||||
{Object.entries(jobStats.by_status).map(([s, c]) => (
|
||||
<Chip key={s} label={`${s}: ${c}`} size="small" variant="outlined" />
|
||||
))}
|
||||
</Stack>
|
||||
)}
|
||||
|
||||
<Stack direction="row" spacing={1} sx={{ mb: 2 }}>
|
||||
<Button variant="outlined" startIcon={<RefreshIcon />} onClick={fetchJobs}
|
||||
disabled={loadingJobs} size="small">Refresh</Button>
|
||||
</Stack>
|
||||
|
||||
{loadingJobs && <LinearProgress sx={{ mb: 1 }} />}
|
||||
|
||||
{jobs.length === 0 && !loadingJobs ? (
|
||||
<Alert severity="info">No jobs yet. Jobs appear here when you trigger triage, profiling, reports, anomaly detection, or data queries.</Alert>
|
||||
) : (
|
||||
<TableContainer component={Paper} sx={{ maxHeight: 500 }}>
|
||||
<Table size="small" stickyHeader>
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableCell>Status</TableCell>
|
||||
<TableCell>Type</TableCell>
|
||||
<TableCell>Progress</TableCell>
|
||||
<TableCell>Message</TableCell>
|
||||
<TableCell>Time</TableCell>
|
||||
<TableCell>Created</TableCell>
|
||||
<TableCell>Actions</TableCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{jobs.map(j => (
|
||||
<TableRow key={j.id} hover
|
||||
sx={j.status === 'failed' ? { bgcolor: 'rgba(244,63,94,0.06)' }
|
||||
: j.status === 'running' ? { bgcolor: 'rgba(59,130,246,0.06)' } : {}}>
|
||||
<TableCell>
|
||||
<Stack direction="row" spacing={0.5} alignItems="center">
|
||||
{statusIcon(j.status)}
|
||||
<Typography variant="caption">{j.status}</Typography>
|
||||
</Stack>
|
||||
</TableCell>
|
||||
<TableCell><Chip label={j.job_type} size="small" variant="outlined" /></TableCell>
|
||||
<TableCell>
|
||||
{j.status === 'running' ? (
|
||||
<LinearProgress variant="determinate" value={j.progress}
|
||||
sx={{ width: 80, height: 6, borderRadius: 3 }} />
|
||||
) : j.status === 'completed' ? (
|
||||
<Typography variant="caption" color="success.main">100%</Typography>
|
||||
) : null}
|
||||
</TableCell>
|
||||
<TableCell sx={{ maxWidth: 200, overflow: 'hidden', textOverflow: 'ellipsis' }}>
|
||||
<Typography variant="caption">{j.error || j.message}</Typography>
|
||||
</TableCell>
|
||||
<TableCell><Typography variant="caption">{fmtMs(j.elapsed_ms)}</Typography></TableCell>
|
||||
<TableCell><Typography variant="caption">{fmtTime(j.created_at)}</Typography></TableCell>
|
||||
<TableCell>
|
||||
{(j.status === 'queued' || j.status === 'running') && (
|
||||
<IconButton size="small" color="error" onClick={() => doCancelJob(j.id)}>
|
||||
<CancelIcon fontSize="small" />
|
||||
</IconButton>
|
||||
)}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
)}
|
||||
</TabPanel>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
@@ -144,6 +144,12 @@ export default function DatasetViewer() {
|
||||
<Chip label={`${selected.row_count} rows`} size="small" />
|
||||
<Chip label={selected.encoding || 'utf-8'} size="small" variant="outlined" />
|
||||
{selected.source_tool && <Chip label={selected.source_tool} size="small" color="info" variant="outlined" />}
|
||||
{selected.artifact_type && <Chip label={selected.artifact_type} size="small" color="secondary" />}
|
||||
{selected.processing_status && selected.processing_status !== 'ready' && (
|
||||
<Chip label={selected.processing_status} size="small"
|
||||
color={selected.processing_status === 'done' ? 'success' : selected.processing_status === 'error' ? 'error' : 'warning'}
|
||||
variant="outlined" />
|
||||
)}
|
||||
{selected.ioc_columns && Object.keys(selected.ioc_columns).length > 0 && (
|
||||
<Chip label={`${Object.keys(selected.ioc_columns).length} IOC columns`} size="small" color="warning" variant="outlined" />
|
||||
)}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
41
update.md
Normal file
41
update.md
Normal file
@@ -0,0 +1,41 @@
|
||||
# ThreatHunt Update Log
|
||||
|
||||
## 2026-02-20: Host-Centric Network Map & Analysis Platform
|
||||
|
||||
### Network Map Overhaul
|
||||
- **Problem**: Network Map showed 409 misclassified "domain" nodes (mostly process names like svchost.exe) and 0 hosts. There was no deduplication — the same host was counted once per dataset.
|
||||
- **Root Cause**: IOC column detection misclassified `Fqdn` as "domain" instead of "hostname"; `Name` column (process names) wrongly tagged as "domain" IOC; `ClientId` was in `normalized_columns` as "hostname" but not in `ioc_columns`.
|
||||
- **Solution**: Created a new host-centric inventory system that scans all datasets, groups by `Fqdn`/`ClientId`, and extracts IPs, users, OS, and network connections.
|
||||
|
||||
#### New Backend Files
|
||||
- `backend/app/services/host_inventory.py` — Deduplicated host inventory builder. Scans all datasets in a hunt, identifies unique hosts via regex-based column detection (`ClientId`, `Fqdn`, `User`/`Username`, `Laddr.IP`/`Raddr.IP`), groups rows, extracts metadata. Filters system accounts (DWM-*, UMFD-*, LOCAL SERVICE, NETWORK SERVICE). Infers OS from hostname patterns (W10-* → Windows 10). Builds network connection graph from netstat remote IPs.
|
||||
- `backend/app/api/routes/network.py` `GET /api/network/host-inventory?hunt_id=X` endpoint returning `{hosts, connections, stats}`.
|
||||
- `backend/app/services/ioc_extractor.py` IOC extraction service (IP, domain, hash, email, URL patterns).
|
||||
- `backend/app/services/anomaly_detector.py` Statistical anomaly detection across datasets.
|
||||
- `backend/app/services/data_query.py` Natural language to structured query translation.
|
||||
- `backend/app/services/load_balancer.py` Round-robin load balancer for Ollama LLM nodes.
|
||||
- `backend/app/services/job_queue.py` Async job queue for long-running analysis tasks.
|
||||
- `backend/app/api/routes/analysis.py` 16 analysis endpoints (IOC extraction, anomaly detection, host profiling, triage, reports, job management).
|
||||
|
||||
#### Modified Backend Files
|
||||
- `backend/app/main.py` Added `network_router` and `analysis_router` includes.
|
||||
- `backend/app/db/models.py` Added 4 AI/analysis ORM models (`ProcessingJob`, `AnalysisResult`, `HostProfile`, `IOCEntry`).
|
||||
- `backend/app/db/engine.py` Connection pool tuning for SQLite async.
|
||||
|
||||
#### Frontend Changes
|
||||
- `frontend/src/components/NetworkMap.tsx` Complete rewrite: host-centric force-directed graph using Canvas 2D. Two node types (Host / External IP). Shows hostname, IP, OS in labels. Click popover shows FQDN, IPs, OS, logged-in users, datasets, connections. Search across hostname/IP/user/OS. Stats cards showing host counts.
|
||||
- `frontend/src/components/AnalysisDashboard.tsx` New 6-tab analysis dashboard (IOC Extraction, Anomaly Detection, Host Profiling, Query, Triage, Reports).
|
||||
- `frontend/src/api/client.ts` Added `network.hostInventory()` method + `InventoryHost`, `InventoryConnection`, `InventoryStats` types. Added analysis API namespace with 16 endpoint methods.
|
||||
- `frontend/src/App.tsx` Added Analysis Dashboard route and navigation.
|
||||
|
||||
### Results (Radio Hunt 20 Velociraptor datasets, 394K rows)
|
||||
|
||||
| Metric | Before | After |
|
||||
|--------|--------|-------|
|
||||
| Nodes shown | 409 misclassified "domains" | **163 unique hosts** |
|
||||
| Hosts identified | 0 | **163** |
|
||||
| With IP addresses | N/A | **48** (172.17.x.x LAN) |
|
||||
| With logged-in users | N/A | **43** (real names only) |
|
||||
| OS detected | None | **Windows 10** (inferred from hostnames) |
|
||||
| Deduplication | None (same host repeated across 20 datasets) | **Full** (by FQDN/ClientId) |
|
||||
| System account filtering | None | **DWM-*, UMFD-*, LOCAL/NETWORK SERVICE removed** |
|
||||
Reference in New Issue
Block a user