Files
ThreatHunt/backend/app/services/data_query.py
mblanke 04a9946891 feat: host-centric network map, analysis dashboard, deduped inventory
- Rewrote NetworkMap to use deduplicated host inventory (163 hosts from 394K rows)
- New host_inventory.py service: scans datasets, groups by FQDN/ClientId, extracts IPs/users/OS
- New /api/network/host-inventory endpoint
- Added AnalysisDashboard with 6 tabs (IOC, anomaly, host profile, query, triage, reports)
- Added 16 analysis API endpoints with job queue and load balancer
- Added 4 AI/analysis ORM models (ProcessingJob, AnalysisResult, HostProfile, IOCEntry)
- Filters system accounts (DWM-*, UMFD-*, LOCAL/NETWORK SERVICE)
- Infers OS from hostname patterns (W10-* -> Windows 10)
- Canvas 2D force-directed graph with host/external-IP node types
- Click popover shows hostname, FQDN, IPs, OS, users, datasets, connections
2026-02-20 07:16:17 -05:00

238 lines
7.6 KiB
Python

"""Natural-language data query service with SSE streaming.
Lets analysts ask questions about dataset rows in plain English.
Routes to fast model (Roadrunner) for quick queries, heavy model (Wile)
for deep analysis. Supports streaming via OllamaProvider.generate_stream().
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from typing import AsyncIterator
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.db import async_session_factory
from app.db.models import Dataset, DatasetRow
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Maximum rows to include in context window. Used as the row cap when rendering
# the text table and as the default sample size when loading dataset rows.
MAX_CONTEXT_ROWS = 60
# Maximum length of one rendered table row; longer rows are cut and suffixed
# with "..." by _rows_to_text.
MAX_ROW_TEXT_CHARS = 300
def _rows_to_text(rows: list[dict], columns: list[str]) -> str:
"""Convert dataset rows to a compact text table for the LLM context."""
if not rows:
return "(no rows)"
# Header
header = " | ".join(columns[:20]) # cap columns to avoid overflow
lines = [header, "-" * min(len(header), 120)]
for row in rows[:MAX_CONTEXT_ROWS]:
vals = []
for c in columns[:20]:
v = str(row.get(c, ""))
if len(v) > 80:
v = v[:77] + "..."
vals.append(v)
line = " | ".join(vals)
if len(line) > MAX_ROW_TEXT_CHARS:
line = line[:MAX_ROW_TEXT_CHARS] + "..."
lines.append(line)
return "\n".join(lines)
# System prompt passed as `system=` to the LLM provider by both the
# non-streaming and the streaming query functions in this module.
QUERY_SYSTEM_PROMPT = """You are a cybersecurity data analyst assistant for ThreatHunt.
You have been given a sample of rows from a forensic artifact dataset (Velociraptor, etc.).
Your job:
- Answer the analyst's question about this data accurately and concisely
- Point out suspicious patterns, anomalies, or indicators of compromise
- Reference MITRE ATT&CK techniques when relevant
- Suggest follow-up queries or pivots
- If you cannot answer from the data provided, say so clearly
Rules:
- Be factual - only reference data you can see
- Use forensic terminology appropriate for SOC/DFIR analysts
- Format your answer with clear sections using markdown
- If the data seems benign, say so - do not fabricate threats"""
async def _load_dataset_context(
    dataset_id: str,
    db: AsyncSession,
    sample_size: int = MAX_CONTEXT_ROWS,
) -> tuple[dict, str, int]:
    """Fetch dataset metadata and a two-part row sample for the LLM context.

    The sample is the first ``sample_size // 2`` rows (ordered by
    ``row_index``) plus a second slice of ``sample_size - sample_size // 2``
    rows. The second slice starts at the midpoint of the dataset when the
    dataset has more than ``sample_size`` rows, and directly after the first
    slice otherwise.

    Returns:
        ``(metadata_dict, rows_text, total_row_count)``.

    Raises:
        ValueError: if no dataset with ``dataset_id`` exists.
    """
    dataset = await db.get(Dataset, dataset_id)
    if not dataset:
        raise ValueError(f"Dataset {dataset_id} not found")

    in_dataset = DatasetRow.dataset_id == dataset_id

    # Total number of rows stored for this dataset.
    count_result = await db.execute(select(func.count()).where(in_dataset))
    total = count_result.scalar() or 0

    # Both sample queries share the same filter and ordering.
    ordered_rows = (
        select(DatasetRow).where(in_dataset).order_by(DatasetRow.row_index)
    )
    head_size = sample_size // 2
    tail_size = sample_size - head_size

    head_result = await db.execute(ordered_rows.limit(head_size))
    head_rows = head_result.scalars().all()

    # Large datasets: jump to the middle. Small ones: continue after the head.
    tail_offset = total // 2 if total > sample_size else head_size
    tail_result = await db.execute(
        ordered_rows.offset(tail_offset).limit(tail_size)
    )
    tail_rows = tail_result.scalars().all()

    sampled = head_rows + tail_rows
    row_dicts = []
    for row in sampled:
        # Rows whose payload is not a dict contribute an empty record.
        row_dicts.append(row.data if isinstance(row.data, dict) else {})

    if dataset.column_schema:
        columns = list(dataset.column_schema.keys())
    else:
        columns = []
    if not columns and row_dicts:
        # No stored schema: fall back to the keys of the first sampled row.
        columns = list(row_dicts[0].keys())

    rows_text = _rows_to_text(row_dicts, columns)
    metadata = {
        "name": dataset.name,
        "filename": dataset.filename,
        "source_tool": dataset.source_tool,
        "artifact_type": getattr(dataset, "artifact_type", None),
        "row_count": total,
        "columns": columns[:30],
        "sample_rows_shown": len(sampled),
    }
    return metadata, rows_text, total
async def query_dataset(
    dataset_id: str,
    question: str,
    mode: str = "quick",
) -> str:
    """Answer *question* about a dataset and return the complete response text.

    ``mode == "deep"`` selects the heavy model on the Wile node with a
    4096-token budget; any other value selects the fast model on the
    Roadrunner node with a 2048-token budget.

    Returns the provider result's ``"response"`` entry, or
    ``"No response generated."`` when that key is absent.
    """
    from app.agents.providers_v2 import OllamaProvider, Node

    async with async_session_factory() as db:
        meta, rows_text, total = await _load_dataset_context(dataset_id, db)
    prompt = _build_prompt(question, meta, rows_text, total)

    deep = mode == "deep"
    provider = OllamaProvider(
        settings.DEFAULT_HEAVY_MODEL if deep else settings.DEFAULT_FAST_MODEL,
        Node.WILE if deep else Node.ROADRUNNER,
    )
    reply = await provider.generate(
        prompt,
        system=QUERY_SYSTEM_PROMPT,
        max_tokens=4096 if deep else 2048,
        temperature=0.3,
    )
    return reply.get("response", "No response generated.")
async def query_dataset_stream(
    dataset_id: str,
    question: str,
    mode: str = "quick",
) -> AsyncIterator[str]:
    """Streaming query: yields SSE-formatted events.

    Every yielded string is a single ``data: <json>`` frame terminated by a
    blank line. The JSON payload carries a ``type`` field:

    - ``status``   - progress message (dataset loading, LLM querying)
    - ``metadata`` - the dataset metadata dict
    - ``token``    - one streamed chunk from the provider
    - ``error``    - a failure while loading the dataset or while streaming
    - ``done``     - final event with chunk count, elapsed ms, model and node

    ``mode == "deep"`` selects the heavy model on the Wile node (4096 tokens);
    any other value selects the fast model on the Roadrunner node (2048).

    Failures while loading the dataset (for example an unknown
    ``dataset_id``) are reported as an ``error`` event followed by ``done``,
    the same way streaming failures are, instead of escaping the generator
    and dropping the connection without any terminal event.
    """
    from app.agents.providers_v2 import OllamaProvider, Node

    def _event(payload: dict) -> str:
        # Serialize one payload as a server-sent-event data frame.
        return f"data: {json.dumps(payload)}\n\n"

    start = time.monotonic()

    # Resolve model/node names up front so the final `done` event can be
    # emitted even when the dataset fails to load.
    if mode == "deep":
        model_name = settings.DEFAULT_HEAVY_MODEL
        node = Node.WILE
        node_name = "wile"
        max_tokens = 4096
    else:
        model_name = settings.DEFAULT_FAST_MODEL
        node = Node.ROADRUNNER
        node_name = "roadrunner"
        max_tokens = 2048

    def _done_event(tokens: int) -> str:
        elapsed_ms = int((time.monotonic() - start) * 1000)
        return _event({
            'type': 'done',
            'tokens': tokens,
            'elapsed_ms': elapsed_ms,
            'model': model_name,
            'node': node_name,
        })

    # Send initial metadata event
    yield _event({'type': 'status', 'message': 'Loading dataset...'})
    try:
        async with async_session_factory() as db:
            meta, rows_text, total = await _load_dataset_context(dataset_id, db)
    except Exception as e:
        # Boundary handler: surface the failure to the client as an event.
        logger.exception("Failed to load dataset %s for streaming query", dataset_id)
        yield _event({'type': 'error', 'message': str(e)})
        yield _done_event(0)
        return

    yield _event({'type': 'metadata', 'dataset': meta})
    yield _event({'type': 'status', 'message': f'Querying LLM ({mode} mode)...'})
    prompt = _build_prompt(question, meta, rows_text, total)
    provider = OllamaProvider(model_name, node)

    # Stream tokens
    token_count = 0
    try:
        async for token in provider.generate_stream(
            prompt,
            system=QUERY_SYSTEM_PROMPT,
            max_tokens=max_tokens,
            temperature=0.3,
        ):
            token_count += 1
            yield _event({'type': 'token', 'content': token})
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(f"...") lost.
        logger.exception("Streaming error: %s", e)
        yield _event({'type': 'error', 'message': str(e)})

    yield _done_event(token_count)
def _build_prompt(question: str, meta: dict, rows_text: str, total: int) -> str:
"""Construct the full prompt with data context."""
parts = [
f"## Dataset: {meta['name']}",
f"- Source: {meta.get('source_tool', 'unknown')}",
f"- Artifact type: {meta.get('artifact_type', 'unknown')}",
f"- Total rows: {total}",
f"- Columns: {', '.join(meta.get('columns', []))}",
f"- Showing {meta['sample_rows_shown']} sample rows below",
"",
"## Sample Data",
"```",
rows_text,
"```",
"",
f"## Analyst Question",
question,
]
return "\n".join(parts)