feat: interactive network map, IOC highlighting, AUP hunt selector, type filters

- NetworkMap: hunt-scoped force-directed graph with click-to-inspect popover
- NetworkMap: zoom/pan (wheel, drag, buttons), viewport transform
- NetworkMap: clickable IP/Host/Domain/URL legend chips to filter node types
- NetworkMap: brighter colors, 20% smaller nodes
- DatasetViewer: IOC columns highlighted with colored headers + cell tinting
- AUPScanner: hunt dropdown replacing dataset checkboxes, auto-select all
- Rename 'Social Media (Personal)' theme to 'Social Media' with DB migration
- Fix /api/hunts timeout: Dataset.rows lazy='noload' (was selectin cascade)
- Add OS column mapping to normalizer
- Full backend services, DB models, alembic migrations, new routes
- New components: Dashboard, HuntManager, FileUpload, NetworkMap, etc.
- Docker Compose deployment with nginx reverse proxy
This commit is contained in:
2026-02-19 15:41:15 -05:00
parent d0c9f88268
commit 9b98ab9614
92 changed files with 13042 additions and 1089 deletions

View File

@@ -0,0 +1 @@
"""Services package."""

View File

@@ -0,0 +1,201 @@
"""Authentication & security — JWT tokens, password hashing, role-based access.
Provides:
- Password hashing (bcrypt via passlib)
- JWT access/refresh token creation and verification
- FastAPI dependency for protecting routes
- Role-based enforcement (analyst, admin, viewer)
"""
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.db import get_db
from app.db.models import User
logger = logging.getLogger(__name__)
# ── Password hashing ─────────────────────────────────────────────────
# Shared passlib context; bcrypt only. "deprecated='auto'" marks any
# non-default scheme as deprecated per passlib convention.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def hash_password(password: str) -> str:
    """Return a bcrypt hash of *password* using the shared passlib context."""
    digest = pwd_context.hash(password)
    return digest
def verify_password(plain: str, hashed: str) -> bool:
    """Check *plain* against a stored bcrypt *hashed* value."""
    matches = pwd_context.verify(plain, hashed)
    return matches
# ── JWT tokens ────────────────────────────────────────────────────────
# Symmetric HS256 signing; the key is settings.JWT_SECRET.
ALGORITHM = "HS256"
# auto_error=False: a missing Authorization header yields None instead of
# an immediate error, so get_current_user/get_optional_user decide the outcome.
security = HTTPBearer(auto_error=False)
class TokenPair(BaseModel):
    """Access/refresh token pair returned by create_token_pair."""
    access_token: str   # short-lived JWT (type="access")
    refresh_token: str  # long-lived JWT (type="refresh")
    token_type: str = "bearer"
    expires_in: int  # seconds until the access token expires
class TokenPayload(BaseModel):
    """Validated JWT claim set produced by decode_token."""
    sub: str  # user_id
    role: str
    exp: datetime
    type: str  # "access" or "refresh"
def create_access_token(user_id: str, role: str) -> str:
    """Mint a short-lived access JWT for *user_id* carrying *role*.

    Expiry is settings.JWT_ACCESS_TOKEN_MINUTES from now (UTC).
    """
    now = datetime.now(timezone.utc)
    claims = {
        "sub": user_id,
        "role": role,
        "exp": now + timedelta(minutes=settings.JWT_ACCESS_TOKEN_MINUTES),
        "type": "access",
    }
    return jwt.encode(claims, settings.JWT_SECRET, algorithm=ALGORITHM)
def create_refresh_token(user_id: str, role: str) -> str:
    """Mint a long-lived refresh JWT for *user_id* carrying *role*.

    Expiry is settings.JWT_REFRESH_TOKEN_DAYS from now (UTC).
    """
    now = datetime.now(timezone.utc)
    claims = {
        "sub": user_id,
        "role": role,
        "exp": now + timedelta(days=settings.JWT_REFRESH_TOKEN_DAYS),
        "type": "refresh",
    }
    return jwt.encode(claims, settings.JWT_SECRET, algorithm=ALGORITHM)
def create_token_pair(user_id: str, role: str) -> TokenPair:
    """Create matching access + refresh tokens for one login session."""
    access = create_access_token(user_id, role)
    refresh = create_refresh_token(user_id, role)
    return TokenPair(
        access_token=access,
        refresh_token=refresh,
        expires_in=60 * settings.JWT_ACCESS_TOKEN_MINUTES,
    )
def decode_token(token: str) -> TokenPayload:
    """Decode a JWT and return its validated claim set.

    Raises:
        HTTPException 401: the token is expired, malformed, or signed
            with a different secret.
    """
    try:
        claims = jwt.decode(token, settings.JWT_SECRET, algorithms=[ALGORITHM])
    except JWTError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid token: {e}",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return TokenPayload(**claims)
# ── FastAPI dependencies ──────────────────────────────────────────────
async def get_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the authenticated user from the bearer token.

    When auth is disabled (JWT secret still the shipped default), a
    synthetic analyst user is returned so dev setups work untouched.

    Raises:
        HTTPException 401: missing token, invalid token, wrong token
            type, or unknown user.
        HTTPException 403: the account exists but is deactivated.
    """
    # Dev mode: the default secret means auth is effectively off.
    if settings.JWT_SECRET == "CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET":
        return User(
            id="dev-user",
            username="analyst",
            email="analyst@local",
            role="analyst",
            display_name="Dev Analyst",
        )
    if credentials is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required",
            headers={"WWW-Authenticate": "Bearer"},
        )
    payload = decode_token(credentials.credentials)
    # Refresh tokens must not be used on protected endpoints.
    if payload.type != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type — use access token",
        )
    lookup = await db.execute(select(User).where(User.id == payload.sub))
    user = lookup.scalar_one_or_none()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
        )
    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User account is disabled",
        )
    return user
async def get_optional_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> Optional[User]:
    """Best-effort variant of get_current_user that never raises.

    Returns None for anonymous or invalid requests; in dev mode (default
    JWT secret) an anonymous request still receives the synthetic analyst.
    """
    if not credentials:
        dev_mode = settings.JWT_SECRET == "CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET"
        if dev_mode:
            return User(
                id="dev-user",
                username="analyst",
                email="analyst@local",
                role="analyst",
                display_name="Dev Analyst",
            )
        return None
    try:
        return await get_current_user(credentials, db)
    except HTTPException:
        # Any auth failure degrades to anonymous access.
        return None
def require_role(*roles: str):
    """Build a FastAPI dependency that rejects users outside *roles* with 403."""
    async def _check(user: User = Depends(get_current_user)) -> User:
        if user.role in roles:
            return user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Requires one of roles: {', '.join(roles)}. You have: {user.role}",
        )
    return _check
# Convenience dependencies
# Pre-built role guards for route declarations.
require_analyst = require_role("analyst", "admin")
require_admin = require_role("admin")

View File

@@ -0,0 +1,400 @@
"""Cross-hunt correlation engine — find IOC overlaps, timeline patterns, and shared TTPs.
Identifies connections between hunts by analyzing:
1. Shared IOC values across datasets
2. Overlapping time ranges and temporal proximity
3. Common MITRE ATT&CK techniques across hypotheses
4. Host-to-host lateral movement patterns
"""
import logging
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional
from sqlalchemy import select, func, text
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import Dataset, DatasetRow, Hunt, Hypothesis, EnrichmentResult
logger = logging.getLogger(__name__)
@dataclass
class IOCOverlap:
    """Shared IOC value seen in datasets belonging to two or more hunts."""
    ioc_value: str
    ioc_type: str
    datasets: list[dict] = field(default_factory=list)  # [{dataset_id, hunt_id, name}]
    hunt_ids: list[str] = field(default_factory=list)  # sorted, de-duplicated
    count: int = 0  # total appearances across all datasets
    enrichment_verdict: str = ""  # cached verdict for this IOC, "" when none
@dataclass
class TimeOverlap:
    """Overlapping time window between two datasets from different hunts."""
    dataset_a: dict = field(default_factory=dict)  # {id, name, hunt_id, start, end}
    dataset_b: dict = field(default_factory=dict)
    overlap_start: str = ""  # ISO-8601
    overlap_end: str = ""    # ISO-8601
    overlap_hours: float = 0.0  # rounded to 2 decimals
@dataclass
class TechniqueOverlap:
    """MITRE ATT&CK technique referenced by hypotheses in two or more hunts."""
    technique_id: str
    technique_name: str = ""  # NOTE(review): never populated by the engine — confirm
    hypotheses: list[dict] = field(default_factory=list)  # {hypothesis_id, title, hunt_id, status}
    hunt_ids: list[str] = field(default_factory=list)
@dataclass
class CorrelationResult:
    """Complete correlation analysis result across a set of hunts."""
    hunt_ids: list[str]
    ioc_overlaps: list[IOCOverlap] = field(default_factory=list)
    time_overlaps: list[TimeOverlap] = field(default_factory=list)
    technique_overlaps: list[TechniqueOverlap] = field(default_factory=list)
    host_overlaps: list[dict] = field(default_factory=list)
    summary: str = ""  # human-readable multi-line summary
    total_correlations: int = 0  # sum of all four overlap-list lengths
class CorrelationEngine:
    """Engine for finding correlations across hunts and datasets.

    Stateless: every method receives its AsyncSession explicitly, so a
    single module-level instance is shared (see `correlation_engine`).
    """
    async def correlate_hunts(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> CorrelationResult:
        """Run full correlation analysis across specified hunts.

        Runs all four detectors (IOC, time, technique, host) sequentially
        and aggregates their counts and a textual summary.
        """
        result = CorrelationResult(hunt_ids=hunt_ids)
        # Run all correlation types
        result.ioc_overlaps = await self._find_ioc_overlaps(hunt_ids, db)
        result.time_overlaps = await self._find_time_overlaps(hunt_ids, db)
        result.technique_overlaps = await self._find_technique_overlaps(hunt_ids, db)
        result.host_overlaps = await self._find_host_overlaps(hunt_ids, db)
        result.total_correlations = (
            len(result.ioc_overlaps)
            + len(result.time_overlaps)
            + len(result.technique_overlaps)
            + len(result.host_overlaps)
        )
        result.summary = self._build_summary(result)
        return result
    async def correlate_all(self, db: AsyncSession) -> CorrelationResult:
        """Correlate across ALL hunts in the system.

        Short-circuits with an explanatory summary when fewer than two
        hunts exist (correlation is meaningless with a single hunt).
        """
        stmt = select(Hunt.id)
        result = await db.execute(stmt)
        hunt_ids = [row[0] for row in result.fetchall()]
        if len(hunt_ids) < 2:
            return CorrelationResult(
                hunt_ids=hunt_ids,
                summary="Need at least 2 hunts for correlation analysis.",
            )
        return await self.correlate_hunts(hunt_ids, db)
    async def find_ioc_across_hunts(
        self,
        ioc_value: str,
        db: AsyncSession,
    ) -> list[dict]:
        """Find all occurrences of a specific IOC across all datasets/hunts.

        Scans up to 5000 rows (raw + normalized columns) looking for an
        exact string match; reports at most one hit per row (first
        matching column wins).
        """
        # Search in dataset rows using JSON contains
        stmt = select(DatasetRow, Dataset).join(
            Dataset, DatasetRow.dataset_id == Dataset.id
        )
        # NOTE(review): the 5000-row cap means large deployments may miss
        # occurrences — confirm this is an acceptable sampling bound.
        result = await db.execute(stmt.limit(5000))
        rows = result.all()
        occurrences = []
        for row, dataset in rows:
            data = row.data or {}
            normalized = row.normalized_data or {}
            # Search both raw and normalized data
            for col, val in {**data, **normalized}.items():
                if str(val) == ioc_value:
                    occurrences.append({
                        "dataset_id": dataset.id,
                        "dataset_name": dataset.name,
                        "hunt_id": dataset.hunt_id,
                        "row_index": row.row_index,
                        "column": col,
                    })
                    break  # one hit per row is enough
        return occurrences
    # ── IOC overlap detection ─────────────────────────────────────────
    async def _find_ioc_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[IOCOverlap]:
        """Find IOC values that appear in datasets from different hunts.

        Only columns tagged in `Dataset.ioc_columns` are scanned, sampling
        up to 2000 rows per dataset. Returns the top 100 overlaps sorted
        by appearance count.
        """
        # Get all datasets for the specified hunts
        stmt = select(Dataset).where(Dataset.hunt_id.in_(hunt_ids))
        result = await db.execute(stmt)
        datasets = result.scalars().all()
        if len(datasets) < 2:
            return []
        # Build IOC → dataset mapping
        ioc_map: dict[str, list[dict]] = defaultdict(list)
        for dataset in datasets:
            if not dataset.ioc_columns:
                continue  # no tagged IOC columns, nothing to scan
            ioc_cols = list(dataset.ioc_columns.keys())
            rows_stmt = select(DatasetRow).where(
                DatasetRow.dataset_id == dataset.id
            ).limit(2000)
            rows_result = await db.execute(rows_stmt)
            rows = rows_result.scalars().all()
            for row in rows:
                data = row.data or {}
                for col in ioc_cols:
                    val = data.get(col, "")
                    if val and str(val).strip():
                        ioc_map[str(val).strip()].append({
                            "dataset_id": dataset.id,
                            "dataset_name": dataset.name,
                            "hunt_id": dataset.hunt_id,
                            "column": col,
                            "ioc_type": dataset.ioc_columns.get(col, "unknown"),
                        })
        # Filter to IOCs appearing in multiple hunts
        overlaps = []
        for ioc_value, appearances in ioc_map.items():
            hunt_set = set(a["hunt_id"] for a in appearances if a["hunt_id"])
            if len(hunt_set) >= 2:
                # Check for enrichment data
                # NOTE(review): one DB query per cross-hunt IOC (N+1
                # pattern) — acceptable at the 100-result cap, revisit
                # if overlap volume grows.
                enrich_stmt = select(EnrichmentResult).where(
                    EnrichmentResult.ioc_value == ioc_value
                ).limit(1)
                enrich_result = await db.execute(enrich_stmt)
                enrichment = enrich_result.scalar_one_or_none()
                overlaps.append(IOCOverlap(
                    ioc_value=ioc_value,
                    ioc_type=appearances[0].get("ioc_type", "unknown"),
                    datasets=appearances,
                    hunt_ids=sorted(hunt_set),
                    count=len(appearances),
                    enrichment_verdict=enrichment.verdict if enrichment else "",
                ))
        # Sort by count descending
        overlaps.sort(key=lambda x: x.count, reverse=True)
        return overlaps[:100]  # Limit results
    # ── Time window overlap ───────────────────────────────────────────
    async def _find_time_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[TimeOverlap]:
        """Find datasets across hunts with overlapping time ranges.

        Compares every cross-hunt dataset pair (O(n²)); rows with
        unparseable ISO timestamps are skipped silently. Returns the top
        50 overlaps by duration.
        """
        stmt = select(Dataset).where(
            Dataset.hunt_id.in_(hunt_ids),
            Dataset.time_range_start.isnot(None),
            Dataset.time_range_end.isnot(None),
        )
        result = await db.execute(stmt)
        datasets = result.scalars().all()
        overlaps = []
        for i, ds_a in enumerate(datasets):
            for ds_b in datasets[i + 1:]:
                if ds_a.hunt_id == ds_b.hunt_id:
                    continue  # Same hunt, skip
                try:
                    a_start = datetime.fromisoformat(ds_a.time_range_start)
                    a_end = datetime.fromisoformat(ds_a.time_range_end)
                    b_start = datetime.fromisoformat(ds_b.time_range_start)
                    b_end = datetime.fromisoformat(ds_b.time_range_end)
                except (ValueError, TypeError):
                    continue  # malformed timestamps — ignore the pair
                # Check overlap: intersection of [a_start, a_end] and [b_start, b_end]
                overlap_start = max(a_start, b_start)
                overlap_end = min(a_end, b_end)
                if overlap_start < overlap_end:
                    hours = (overlap_end - overlap_start).total_seconds() / 3600
                    overlaps.append(TimeOverlap(
                        dataset_a={
                            "id": ds_a.id,
                            "name": ds_a.name,
                            "hunt_id": ds_a.hunt_id,
                            "start": ds_a.time_range_start,
                            "end": ds_a.time_range_end,
                        },
                        dataset_b={
                            "id": ds_b.id,
                            "name": ds_b.name,
                            "hunt_id": ds_b.hunt_id,
                            "start": ds_b.time_range_start,
                            "end": ds_b.time_range_end,
                        },
                        overlap_start=overlap_start.isoformat(),
                        overlap_end=overlap_end.isoformat(),
                        overlap_hours=round(hours, 2),
                    ))
        overlaps.sort(key=lambda x: x.overlap_hours, reverse=True)
        return overlaps[:50]
    # ── MITRE technique overlap ───────────────────────────────────────
    async def _find_technique_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[TechniqueOverlap]:
        """Find MITRE ATT&CK techniques shared across hunts.

        Groups hypotheses by their `mitre_technique` string and keeps
        techniques referenced by two or more distinct hunts.
        NOTE(review): `TechniqueOverlap.technique_name` is left at its
        default "" here — confirm whether a name lookup was intended.
        """
        stmt = select(Hypothesis).where(
            Hypothesis.hunt_id.in_(hunt_ids),
            Hypothesis.mitre_technique.isnot(None),
        )
        result = await db.execute(stmt)
        hypotheses = result.scalars().all()
        technique_map: dict[str, list[dict]] = defaultdict(list)
        for hyp in hypotheses:
            technique = hyp.mitre_technique.strip()
            if technique:
                technique_map[technique].append({
                    "hypothesis_id": hyp.id,
                    "hypothesis_title": hyp.title,
                    "hunt_id": hyp.hunt_id,
                    "status": hyp.status,
                })
        overlaps = []
        for technique, hyps in technique_map.items():
            hunt_set = set(h["hunt_id"] for h in hyps if h["hunt_id"])
            if len(hunt_set) >= 2:
                overlaps.append(TechniqueOverlap(
                    technique_id=technique,
                    hypotheses=hyps,
                    hunt_ids=sorted(hunt_set),
                ))
        return overlaps
    # ── Host overlap ──────────────────────────────────────────────────
    async def _find_host_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[dict]:
        """Find hostnames that appear in datasets from different hunts.
        Useful for detecting lateral movement patterns.

        Hostname columns are located via the dataset's normalized-column
        mapping; values are uppercased for case-insensitive matching.
        Samples up to 2000 rows per dataset; returns the top 50 hosts.
        """
        stmt = select(Dataset).where(Dataset.hunt_id.in_(hunt_ids))
        result = await db.execute(stmt)
        datasets = result.scalars().all()
        host_map: dict[str, list[dict]] = defaultdict(list)
        for dataset in datasets:
            norm_cols = dataset.normalized_columns or {}
            # Look for hostname columns
            hostname_cols = [
                orig for orig, canon in norm_cols.items()
                if canon in ("hostname", "host", "computer_name", "src_host", "dst_host")
            ]
            if not hostname_cols:
                continue
            rows_stmt = select(DatasetRow).where(
                DatasetRow.dataset_id == dataset.id
            ).limit(2000)
            rows_result = await db.execute(rows_stmt)
            rows = rows_result.scalars().all()
            for row in rows:
                data = row.data or {}
                for col in hostname_cols:
                    val = data.get(col, "")
                    if val and str(val).strip():
                        # Uppercase so "host-01" and "HOST-01" collide.
                        host_name = str(val).strip().upper()
                        host_map[host_name].append({
                            "dataset_id": dataset.id,
                            "dataset_name": dataset.name,
                            "hunt_id": dataset.hunt_id,
                        })
        # Filter to hosts appearing in multiple hunts
        overlaps = []
        for host, appearances in host_map.items():
            hunt_set = set(a["hunt_id"] for a in appearances if a["hunt_id"])
            if len(hunt_set) >= 2:
                overlaps.append({
                    "hostname": host,
                    "hunt_ids": sorted(hunt_set),
                    "dataset_count": len(appearances),
                    "datasets": appearances[:10],  # cap payload size
                })
        overlaps.sort(key=lambda x: x["dataset_count"], reverse=True)
        return overlaps[:50]
    # ── Summary builder ───────────────────────────────────────────────
    def _build_summary(self, result: CorrelationResult) -> str:
        """Build a human-readable, newline-joined summary of correlations."""
        parts = [f"Correlation analysis across {len(result.hunt_ids)} hunts:"]
        if result.ioc_overlaps:
            malicious = [o for o in result.ioc_overlaps if o.enrichment_verdict == "malicious"]
            parts.append(
                f" - {len(result.ioc_overlaps)} shared IOCs "
                f"({len(malicious)} flagged malicious)"
            )
        else:
            parts.append(" - No shared IOCs found")
        if result.time_overlaps:
            parts.append(f" - {len(result.time_overlaps)} overlapping time windows")
        if result.technique_overlaps:
            parts.append(
                f" - {len(result.technique_overlaps)} shared MITRE techniques"
            )
        if result.host_overlaps:
            parts.append(
                f" - {len(result.host_overlaps)} hosts appearing in multiple hunts "
                "(potential lateral movement)"
            )
        if result.total_correlations == 0:
            parts.append(" No significant correlations detected.")
        return "\n".join(parts)
# Singleton
# Shared module-level instance; the engine keeps no per-request state.
correlation_engine = CorrelationEngine()

View File

@@ -0,0 +1,165 @@
"""CSV parsing engine with encoding detection, delimiter sniffing, and streaming.
Handles large Velociraptor CSV exports with resilience to encoding issues,
varied delimiters, and malformed rows.
"""
import csv
import io
import logging
from pathlib import Path
from typing import AsyncIterator
import chardet
logger = logging.getLogger(__name__)
# Reasonable defaults
MAX_FIELD_SIZE = 1024 * 1024  # 1 MB per field
# Raise the csv module's global per-field cap so very wide export fields
# don't raise "field larger than field limit".
csv.field_size_limit(MAX_FIELD_SIZE)
def detect_encoding(file_bytes: bytes, sample_size: int = 65536) -> str:
    """Guess the text encoding of *file_bytes* from its first *sample_size* bytes."""
    guess = chardet.detect(file_bytes[:sample_size])
    encoding = guess.get("encoding", "utf-8") or "utf-8"
    confidence = guess.get("confidence", 0)
    logger.info(f"Detected encoding: {encoding} (confidence: {confidence:.2f})")
    # A low-confidence guess is worse than utf-8 with replacement chars.
    if confidence < 0.5:
        encoding = "utf-8"
    return encoding
def detect_delimiter(text_sample: str) -> str:
    """Sniff the CSV delimiter used in *text_sample*; comma on failure."""
    try:
        return csv.Sniffer().sniff(text_sample, delimiters=",\t;|").delimiter
    except csv.Error:
        # Sniffer gives up on e.g. single-column data — comma is the
        # safest default for our exports.
        return ","
def infer_column_types(rows: list[dict], sample_size: int = 100) -> dict[str, str]:
    """Infer a per-column type hint by pattern-matching sampled values.

    Returns a mapping of column_name → type_hint, one of: timestamp,
    integer, float, ip, hash_md5, hash_sha1, hash_sha256, domain, path,
    string. Each non-empty value votes for the first pattern it matches
    (dict order below is the match priority); the majority label wins,
    with "string" as the fallback for columns with no votes.
    """
    import re
    patterns = {
        "ip": re.compile(
            r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$"
        ),
        "hash_md5": re.compile(r"^[a-fA-F0-9]{32}$"),
        "hash_sha1": re.compile(r"^[a-fA-F0-9]{40}$"),
        "hash_sha256": re.compile(r"^[a-fA-F0-9]{64}$"),
        "integer": re.compile(r"^-?\d+$"),
        "float": re.compile(r"^-?\d+\.\d+$"),
        "timestamp": re.compile(
            r"^\d{4}[-/]\d{2}[-/]\d{2}[T ]\d{2}:\d{2}"
        ),
        "domain": re.compile(
            r"^[a-zA-Z0-9]([a-zA-Z0-9-]*[a-zA-Z0-9])?(\.[a-zA-Z]{2,})+$"
        ),
        "path": re.compile(r"^([A-Z]:\\|/)", re.IGNORECASE),
    }
    votes: dict[str, dict[str, int]] = {}
    for row in rows[:sample_size]:
        for col, raw in row.items():
            counts = votes.setdefault(col, {})
            text = str(raw).strip()
            if not text:
                continue  # empty cells carry no type information
            label = next(
                (name for name, rx in patterns.items() if rx.match(text)),
                "string",
            )
            counts[label] = counts.get(label, 0) + 1
    return {
        col: (max(counts, key=counts.get) if counts else "string")  # type: ignore[arg-type]
        for col, counts in votes.items()
    }
def parse_csv_bytes(
    raw_bytes: bytes,
    max_rows: int | None = None,
) -> tuple[list[dict], dict]:
    """Parse a CSV file from raw bytes.

    Args:
        raw_bytes: Complete file content.
        max_rows: Optional cap on materialized rows; when hit, the rest
            of the file is still scanned so the total count is accurate.

    Returns:
        (rows, metadata) where metadata contains encoding, delimiter,
        columns, column_types, row_count (rows returned) and
        total_rows_in_file (true data-row count, ignoring the cap).
    """
    encoding = detect_encoding(raw_bytes)
    try:
        text = raw_bytes.decode(encoding, errors="replace")
    except (UnicodeDecodeError, LookupError):
        # chardet can return a codec name Python doesn't recognize.
        text = raw_bytes.decode("utf-8", errors="replace")
        encoding = "utf-8"
    # Detect delimiter from first few KB
    delimiter = detect_delimiter(text[:8192])
    reader = csv.DictReader(io.StringIO(text), delimiter=delimiter)
    columns = reader.fieldnames or []
    rows: list[dict] = []
    truncated = False
    for i, row in enumerate(reader):
        if max_rows is not None and i >= max_rows:
            truncated = True
            break
        rows.append(dict(row))
    total_rows = len(rows)
    if truncated:
        # BUGFIX: previously total_rows_in_file was reported as len(rows)
        # even when max_rows truncated the result. Count the remainder
        # (+1 for the row consumed by the break) without materializing it.
        total_rows += 1 + sum(1 for _ in reader)
    column_types = infer_column_types(rows) if rows else {}
    metadata = {
        "encoding": encoding,
        "delimiter": delimiter,
        "columns": columns,
        "column_types": column_types,
        "row_count": len(rows),
        "total_rows_in_file": total_rows,
    }
    return rows, metadata
async def parse_csv_streaming(
    file_path: Path,
    chunk_size: int = 8192,
) -> AsyncIterator[tuple[int, dict]]:
    """Asynchronously parse a CSV file, yielding (row_index, row_dict) tuples.

    NOTE(review): despite the name, the current implementation reads the
    whole file into memory (`await f.read()`) because csv.DictReader needs
    a synchronous file-like object — only the *yielding* is incremental.
    `chunk_size` is currently unused. Confirm whether true chunked
    streaming was intended before relying on this for very large files.
    """
    import aiofiles  # type: ignore[import-untyped]
    # Read a sample for encoding/delimiter detection (sync read is fine
    # for a one-off 64 KB probe).
    with open(file_path, "rb") as f:
        sample_bytes = f.read(65536)
    encoding = detect_encoding(sample_bytes)
    text_sample = sample_bytes.decode(encoding, errors="replace")
    delimiter = detect_delimiter(text_sample[:8192])
    # Now stream-read
    async with aiofiles.open(file_path, mode="r", encoding=encoding, errors="replace") as f:
        content = await f.read()  # For DictReader compatibility
        reader = csv.DictReader(io.StringIO(content), delimiter=delimiter)
        for i, row in enumerate(reader):
            yield i, dict(row)

View File

@@ -0,0 +1,655 @@
"""IOC Enrichment Engine — VirusTotal, AbuseIPDB, Shodan integrations.
Provides automated IOC enrichment with caching and rate limiting.
Enriches IPs, hashes, domains with threat intelligence verdicts.
"""
import asyncio
import hashlib
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Optional
import httpx
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.db.models import EnrichmentResult as EnrichmentDB
logger = logging.getLogger(__name__)
class IOCType(str, Enum):
    """Kinds of indicators the enrichment engine can look up."""
    IP = "ip"
    DOMAIN = "domain"
    HASH_MD5 = "hash_md5"
    HASH_SHA1 = "hash_sha1"
    HASH_SHA256 = "hash_sha256"
    URL = "url"
class Verdict(str, Enum):
    """Normalized threat verdict returned by every provider."""
    CLEAN = "clean"
    SUSPICIOUS = "suspicious"
    MALICIOUS = "malicious"
    UNKNOWN = "unknown"  # provider had no data for the IOC
    ERROR = "error"      # lookup failed (missing key, HTTP error, ...)
@dataclass
class EnrichmentResultData:
    """Enrichment result from a single provider for a single IOC."""
    ioc_value: str
    ioc_type: IOCType
    source: str  # provider name, e.g. "virustotal"
    verdict: Verdict
    score: float = 0.0  # 0-100 normalized threat score
    raw_data: dict = field(default_factory=dict)  # trimmed provider payload
    tags: list[str] = field(default_factory=list)
    country: str = ""
    asn: str = ""
    org: str = ""
    last_seen: str = ""
    error: str = ""  # populated when verdict is ERROR
    latency_ms: int = 0  # wall-clock duration of the provider call
# ── Rate limiter ──────────────────────────────────────────────────────
class RateLimiter:
    """Enforce a minimum interval between calls (simple token-bucket style).

    `acquire()` sleeps just long enough that consecutive callers are
    spaced at least `60 / calls_per_minute` seconds apart; a lock
    serializes concurrent acquirers.
    """

    def __init__(self, calls_per_minute: int = 4):
        self.calls_per_minute = calls_per_minute
        self.interval = 60.0 / calls_per_minute
        self._last_call: float = 0.0
        self._lock = asyncio.Lock()

    async def acquire(self):
        """Block until the per-call interval since the last call has elapsed."""
        async with self._lock:
            remaining = self.interval - (time.monotonic() - self._last_call)
            if remaining > 0:
                await asyncio.sleep(remaining)
            self._last_call = time.monotonic()
# ── Provider base ─────────────────────────────────────────────────────
class EnrichmentProvider:
    """Base class for enrichment providers.

    Subclasses set `name` and implement `enrich`; this base supplies a
    lazily-created shared HTTP client and a per-provider rate limiter.
    """

    name: str = "base"

    def __init__(self, api_key: str = "", rate_limit: int = 4):
        self.api_key = api_key
        self.rate_limiter = RateLimiter(rate_limit)
        self._client: httpx.AsyncClient | None = None

    def _get_client(self) -> httpx.AsyncClient:
        """Return the shared async client, recreating it if it was closed."""
        if self._client is None or self._client.is_closed:
            timeouts = httpx.Timeout(connect=10, read=30, write=10, pool=5)
            self._client = httpx.AsyncClient(timeout=timeouts)
        return self._client

    async def cleanup(self):
        """Close the HTTP client if one is open."""
        client = self._client
        if client and not client.is_closed:
            await client.aclose()

    @property
    def is_configured(self) -> bool:
        """True when an API key has been supplied."""
        return bool(self.api_key)

    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        """Look up *ioc_value*; must be implemented by subclasses."""
        raise NotImplementedError
# ── VirusTotal ────────────────────────────────────────────────────────
class VirusTotalProvider(EnrichmentProvider):
    """VirusTotal v3 API provider (IPs, domains, file hashes, URLs)."""
    name = "virustotal"
    BASE_URL = "https://www.virustotal.com/api/v3"
    def __init__(self):
        # VT free tier is heavily rate-limited; default to 4 calls/minute.
        super().__init__(api_key=settings.VIRUSTOTAL_API_KEY, rate_limit=4)
    def _headers(self) -> dict:
        # VT v3 authenticates via the x-apikey header.
        return {"x-apikey": self.api_key}
    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        """Query VT for *ioc_value* and map the analysis stats to a verdict.

        Never raises: every failure path is folded into an
        EnrichmentResultData with verdict ERROR (or UNKNOWN on 404).
        """
        if not self.is_configured:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="VirusTotal API key not configured",
            )
        await self.rate_limiter.acquire()
        start = time.monotonic()
        try:
            endpoint = self._get_endpoint(ioc_value, ioc_type)
            if not endpoint:
                return EnrichmentResultData(
                    ioc_value=ioc_value, ioc_type=ioc_type,
                    source=self.name, verdict=Verdict.ERROR,
                    error=f"Unsupported IOC type: {ioc_type}",
                )
            client = self._get_client()
            resp = await client.get(endpoint, headers=self._headers())
            latency_ms = int((time.monotonic() - start) * 1000)
            # 404 means VT has never seen this IOC — not an error.
            if resp.status_code == 404:
                return EnrichmentResultData(
                    ioc_value=ioc_value, ioc_type=ioc_type,
                    source=self.name, verdict=Verdict.UNKNOWN,
                    latency_ms=latency_ms,
                )
            resp.raise_for_status()
            data = resp.json()
            attrs = data.get("data", {}).get("attributes", {})
            stats = attrs.get("last_analysis_stats", {})
            malicious = stats.get("malicious", 0)
            suspicious = stats.get("suspicious", 0)
            total = sum(stats.values()) if stats else 0
            # Determine verdict: >3 engines flagging malicious → malicious;
            # any malicious or >2 suspicious → suspicious; scanned but
            # unflagged → clean; no scans at all → unknown.
            if malicious > 3:
                verdict = Verdict.MALICIOUS
            elif malicious > 0 or suspicious > 2:
                verdict = Verdict.SUSPICIOUS
            elif total > 0:
                verdict = Verdict.CLEAN
            else:
                verdict = Verdict.UNKNOWN
            # Score = share of engines that called it malicious, as a percentage.
            score = (malicious / total * 100) if total > 0 else 0
            tags = attrs.get("tags", [])
            if attrs.get("type_description"):
                tags.append(attrs["type_description"])
            return EnrichmentResultData(
                ioc_value=ioc_value,
                ioc_type=ioc_type,
                source=self.name,
                verdict=verdict,
                score=round(score, 1),
                raw_data={
                    "stats": stats,
                    "reputation": attrs.get("reputation", 0),
                    "type_description": attrs.get("type_description", ""),
                    "names": attrs.get("names", [])[:5],
                },
                tags=tags[:10],
                country=attrs.get("country", ""),
                asn=str(attrs.get("asn", "")),
                org=attrs.get("as_owner", ""),
                # NOTE(review): last_analysis_date is typically a Unix epoch
                # int stored into a str-typed field — confirm consumers cope.
                last_seen=attrs.get("last_analysis_date", ""),
                latency_ms=latency_ms,
            )
        except httpx.HTTPStatusError as e:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=f"HTTP {e.response.status_code}",
                latency_ms=int((time.monotonic() - start) * 1000),
            )
        except Exception as e:
            # Catch-all keeps batch enrichment running past one bad lookup.
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=str(e),
                latency_ms=int((time.monotonic() - start) * 1000),
            )
    def _get_endpoint(self, ioc_value: str, ioc_type: IOCType) -> str | None:
        """Map an IOC type to its VT v3 object URL; None when unsupported."""
        if ioc_type == IOCType.IP:
            return f"{self.BASE_URL}/ip_addresses/{ioc_value}"
        elif ioc_type == IOCType.DOMAIN:
            return f"{self.BASE_URL}/domains/{ioc_value}"
        elif ioc_type in (IOCType.HASH_MD5, IOCType.HASH_SHA1, IOCType.HASH_SHA256):
            return f"{self.BASE_URL}/files/{ioc_value}"
        elif ioc_type == IOCType.URL:
            # VT accepts the SHA-256 of the URL as the URL object identifier.
            url_id = hashlib.sha256(ioc_value.encode()).hexdigest()
            return f"{self.BASE_URL}/urls/{url_id}"
        return None
# ── AbuseIPDB ─────────────────────────────────────────────────────────
class AbuseIPDBProvider(EnrichmentProvider):
    """AbuseIPDB API provider — IP reputation (IP lookups only)."""
    name = "abuseipdb"
    BASE_URL = "https://api.abuseipdb.com/api/v2"
    def __init__(self):
        super().__init__(api_key=settings.ABUSEIPDB_API_KEY, rate_limit=10)
    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        """Check an IP against AbuseIPDB and map the confidence score to a verdict.

        Never raises: failures are folded into an ERROR-verdict result.
        """
        # This provider only understands IP addresses.
        if ioc_type != IOCType.IP:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="AbuseIPDB only supports IP lookups",
            )
        if not self.is_configured:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="AbuseIPDB API key not configured",
            )
        await self.rate_limiter.acquire()
        start = time.monotonic()
        try:
            client = self._get_client()
            # verbose=true includes the per-report category data used below.
            resp = await client.get(
                f"{self.BASE_URL}/check",
                params={"ipAddress": ioc_value, "maxAgeInDays": 90, "verbose": "true"},
                headers={"Key": self.api_key, "Accept": "application/json"},
            )
            latency_ms = int((time.monotonic() - start) * 1000)
            resp.raise_for_status()
            data = resp.json().get("data", {})
            abuse_score = data.get("abuseConfidenceScore", 0)
            total_reports = data.get("totalReports", 0)
            # Verdict thresholds: confidence ≥75 → malicious; ≥25 or more
            # than 5 reports → suspicious; zero reports → unknown.
            if abuse_score >= 75:
                verdict = Verdict.MALICIOUS
            elif abuse_score >= 25 or total_reports > 5:
                verdict = Verdict.SUSPICIOUS
            elif total_reports == 0:
                verdict = Verdict.UNKNOWN
            else:
                verdict = Verdict.CLEAN
            # Translate numeric report categories into readable tags.
            categories = data.get("reports", [])
            tags = []
            for report in categories[:10]:
                for cat_id in report.get("categories", []):
                    tag = self._category_name(cat_id)
                    if tag and tag not in tags:
                        tags.append(tag)
            return EnrichmentResultData(
                ioc_value=ioc_value,
                ioc_type=ioc_type,
                source=self.name,
                verdict=verdict,
                score=float(abuse_score),
                raw_data={
                    "abuse_confidence_score": abuse_score,
                    "total_reports": total_reports,
                    "is_whitelisted": data.get("isWhitelisted"),
                    "is_tor": data.get("isTor", False),
                    "usage_type": data.get("usageType", ""),
                    "isp": data.get("isp", ""),
                },
                tags=tags[:10],
                country=data.get("countryCode", ""),
                org=data.get("isp", ""),
                last_seen=data.get("lastReportedAt", ""),
                latency_ms=latency_ms,
            )
        except Exception as e:
            # Catch-all keeps batch enrichment running past one bad lookup.
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=str(e),
                latency_ms=int((time.monotonic() - start) * 1000),
            )
    @staticmethod
    def _category_name(cat_id: int) -> str:
        """Map an AbuseIPDB numeric report category to its display name ("" if unknown)."""
        categories = {
            1: "DNS Compromise", 2: "DNS Poisoning", 3: "Fraud Orders",
            4: "DDoS Attack", 5: "FTP Brute-Force", 6: "Ping of Death",
            7: "Phishing", 8: "Fraud VoIP", 9: "Open Proxy",
            10: "Web Spam", 11: "Email Spam", 12: "Blog Spam",
            13: "VPN IP", 14: "Port Scan", 15: "Hacking",
            16: "SQL Injection", 17: "Spoofing", 18: "Brute-Force",
            19: "Bad Web Bot", 20: "Exploited Host", 21: "Web App Attack",
            22: "SSH", 23: "IoT Targeted",
        }
        return categories.get(cat_id, "")
# ── Shodan ────────────────────────────────────────────────────────────
class ShodanProvider(EnrichmentProvider):
    """Shodan API provider — infrastructure intelligence (IP lookups only)."""
    name = "shodan"
    BASE_URL = "https://api.shodan.io"
    def __init__(self):
        # Shodan is the most aggressively rate-limited provider: 1 call/minute.
        super().__init__(api_key=settings.SHODAN_API_KEY, rate_limit=1)
    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        """Look up a host on Shodan; verdict is a heuristic over vulns/ports.

        Never raises: failures are folded into an ERROR-verdict result;
        an unindexed host (404) yields UNKNOWN.
        """
        if ioc_type != IOCType.IP:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="Shodan only supports IP lookups",
            )
        if not self.is_configured:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="Shodan API key not configured",
            )
        await self.rate_limiter.acquire()
        start = time.monotonic()
        try:
            client = self._get_client()
            # minify=true trims banner data we don't use anyway.
            resp = await client.get(
                f"{self.BASE_URL}/shodan/host/{ioc_value}",
                params={"key": self.api_key, "minify": "true"},
            )
            latency_ms = int((time.monotonic() - start) * 1000)
            if resp.status_code == 404:
                # Host not indexed by Shodan.
                return EnrichmentResultData(
                    ioc_value=ioc_value, ioc_type=ioc_type,
                    source=self.name, verdict=Verdict.UNKNOWN,
                    latency_ms=latency_ms,
                )
            resp.raise_for_status()
            data = resp.json()
            ports = data.get("ports", [])
            vulns = data.get("vulns", [])
            tags_raw = data.get("tags", [])
            # Determine verdict based on open ports and vulns: any known
            # vuln (15 points each, capped at 100) or >20 open ports is
            # treated as suspicious. Shodan never yields MALICIOUS here.
            if vulns:
                verdict = Verdict.SUSPICIOUS
                score = min(len(vulns) * 15, 100.0)
            elif len(ports) > 20:
                verdict = Verdict.SUSPICIOUS
                score = 40.0
            else:
                verdict = Verdict.CLEAN
                score = 0.0
            tags = tags_raw[:10]
            if vulns:
                tags.extend([f"CVE: {v}" for v in vulns[:5]])
            return EnrichmentResultData(
                ioc_value=ioc_value,
                ioc_type=ioc_type,
                source=self.name,
                verdict=verdict,
                score=score,
                raw_data={
                    "ports": ports[:20],
                    "vulns": vulns[:10],
                    "os": data.get("os"),
                    "hostnames": data.get("hostnames", [])[:5],
                    "domains": data.get("domains", [])[:5],
                    "last_update": data.get("last_update", ""),
                },
                tags=tags[:15],
                country=data.get("country_code", ""),
                asn=data.get("asn", ""),
                org=data.get("org", ""),
                last_seen=data.get("last_update", ""),
                latency_ms=latency_ms,
            )
        except Exception as e:
            # Catch-all keeps batch enrichment running past one bad lookup.
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=str(e),
                latency_ms=int((time.monotonic() - start) * 1000),
            )
# ── Enrichment Engine (orchestrator) ──────────────────────────────────
class EnrichmentEngine:
    """Orchestrates IOC enrichment across all providers with caching.

    Routing rules:
      * IP IOCs fan out to every configured provider.
      * Hashes, domains, and URLs go to VirusTotal only (AbuseIPDB and
        Shodan understand IPs exclusively).

    Non-error results are cached in the DB and reused for CACHE_TTL_HOURS.
    """

    # How long a cached provider verdict stays fresh.
    CACHE_TTL_HOURS = 24

    # IOC types that only VirusTotal can look up.
    _VT_ONLY_TYPES = (
        IOCType.DOMAIN,
        IOCType.URL,
        IOCType.HASH_MD5,
        IOCType.HASH_SHA1,
        IOCType.HASH_SHA256,
    )

    def __init__(self):
        self.providers: list[EnrichmentProvider] = [
            VirusTotalProvider(),
            AbuseIPDBProvider(),
            ShodanProvider(),
        ]

    @property
    def configured_providers(self) -> list[EnrichmentProvider]:
        """Providers that have API credentials configured."""
        return [p for p in self.providers if p.is_configured]

    async def enrich_ioc(
        self,
        ioc_value: str,
        ioc_type: IOCType,
        db: AsyncSession | None = None,
        skip_cache: bool = False,
    ) -> list[EnrichmentResultData]:
        """Enrich a single IOC across all configured providers.

        Args:
            ioc_value: The IOC string (IP, hash, domain, or URL).
            ioc_type: Type of the IOC; drives provider routing.
            db: Optional session used for the result cache.
            skip_cache: When True, always query providers.

        Returns:
            One result per provider queried (possibly served from cache).
        """
        # Serve fresh cached results when we have a session and caching is on.
        if db and not skip_cache:
            cached = await self._get_cached(db, ioc_value, ioc_type)
            if cached:
                logger.info(f"Cache hit for {ioc_type.value}:{ioc_value} ({len(cached)} results)")
                return cached

        # Route to capable providers: IPs go to everyone, everything else
        # is VirusTotal-only.
        tasks = []
        for provider in self.configured_providers:
            if ioc_type == IOCType.IP:
                tasks.append(provider.enrich(ioc_value, ioc_type))
            elif ioc_type in self._VT_ONLY_TYPES and provider.name == "virustotal":
                tasks.append(provider.enrich(ioc_value, ioc_type))

        results: list[EnrichmentResultData] = []
        if tasks:
            # Providers trap their own exceptions and return ERROR results,
            # so gather() does not need return_exceptions here.
            results = list(await asyncio.gather(*tasks, return_exceptions=False))

        if db and results:
            await self._cache_results(db, results)
        return results

    async def enrich_batch(
        self,
        iocs: list[tuple[str, IOCType]],
        db: AsyncSession | None = None,
        concurrency: int = 3,
    ) -> dict[str, list[EnrichmentResultData]]:
        """Enrich a batch of IOCs with controlled concurrency.

        Best-effort: a lookup that raises simply leaves its IOC out of the
        returned mapping (exceptions are swallowed by gather()).
        """
        sem = asyncio.Semaphore(concurrency)
        all_results: dict[str, list[EnrichmentResultData]] = {}

        async def _enrich_one(value: str, ioc_type: IOCType):
            # Semaphore caps how many IOCs are in flight simultaneously.
            async with sem:
                result = await self.enrich_ioc(value, ioc_type, db=db)
                all_results[value] = result

        tasks = [_enrich_one(v, t) for v, t in iocs]
        await asyncio.gather(*tasks, return_exceptions=True)
        return all_results

    async def enrich_dataset_iocs(
        self,
        rows: list[dict],
        ioc_columns: dict,
        db: AsyncSession | None = None,
        max_iocs: int = 50,
    ) -> dict[str, list[EnrichmentResultData]]:
        """Auto-enrich IOCs found in a dataset.

        Extracts unique IOC values from the identified columns (capped at
        ``max_iocs``) and enriches them as a batch.
        """
        iocs_to_enrich: list[tuple[str, IOCType]] = []
        seen: set[str] = set()
        for col_name, col_type in ioc_columns.items():
            ioc_type = self._map_column_type(col_type)
            if not ioc_type:
                continue
            for row in rows:
                value = row.get(col_name, "")
                if not value:
                    continue
                # Dedupe on the string form so e.g. 80 and "80" collapse
                # to a single lookup (the queue stores strings anyway).
                text = str(value)
                if text in seen:
                    continue
                seen.add(text)
                iocs_to_enrich.append((text, ioc_type))
                if len(iocs_to_enrich) >= max_iocs:
                    break
            if len(iocs_to_enrich) >= max_iocs:
                break
        if iocs_to_enrich:
            return await self.enrich_batch(iocs_to_enrich, db=db)
        return {}

    async def _get_cached(
        self,
        db: AsyncSession,
        ioc_value: str,
        ioc_type: IOCType,
    ) -> list[EnrichmentResultData] | None:
        """Check for cached enrichment results newer than the TTL cutoff.

        NOTE(review): the cache is not keyed per provider, so a hit may
        contain results from fewer providers than are currently configured.
        """
        cutoff = datetime.now(timezone.utc) - timedelta(hours=self.CACHE_TTL_HOURS)
        stmt = (
            select(EnrichmentDB)
            .where(
                EnrichmentDB.ioc_value == ioc_value,
                EnrichmentDB.ioc_type == ioc_type.value,
                EnrichmentDB.cached_at >= cutoff,
            )
        )
        result = await db.execute(stmt)
        cached = result.scalars().all()
        if not cached:
            return None
        return [
            EnrichmentResultData(
                ioc_value=c.ioc_value,
                ioc_type=IOCType(c.ioc_type),
                source=c.source,
                verdict=Verdict(c.verdict),
                score=c.score or 0.0,
                raw_data=c.raw_data or {},
                tags=c.tags or [],
                country=c.country or "",
                asn=c.asn or "",
                org=c.org or "",
            )
            for c in cached
        ]

    async def _cache_results(
        self,
        db: AsyncSession,
        results: list[EnrichmentResultData],
    ):
        """Cache enrichment results in the database (best-effort)."""
        for r in results:
            if r.verdict == Verdict.ERROR:
                continue  # Don't cache errors
            db.add(EnrichmentDB(
                ioc_value=r.ioc_value,
                ioc_type=r.ioc_type.value,
                source=r.source,
                verdict=r.verdict.value,
                score=r.score,
                raw_data=r.raw_data,
                tags=r.tags,
                country=r.country,
                asn=r.asn,
                org=r.org,
            ))
        try:
            await db.flush()
        except Exception as e:
            # Caching is opportunistic — a failure here must never break
            # the enrichment request itself.
            logger.warning(f"Failed to cache enrichment: {e}")

    @staticmethod
    def _map_column_type(col_type: str) -> IOCType | None:
        """Map a column type from the normalizer to an IOCType (or None)."""
        mapping = {
            "ip": IOCType.IP,
            "ip_address": IOCType.IP,
            "src_ip": IOCType.IP,
            "dst_ip": IOCType.IP,
            "domain": IOCType.DOMAIN,
            "hash_md5": IOCType.HASH_MD5,
            "hash_sha1": IOCType.HASH_SHA1,
            "hash_sha256": IOCType.HASH_SHA256,
            "url": IOCType.URL,
        }
        return mapping.get(col_type)

    async def cleanup(self):
        """Close all provider HTTP clients."""
        for provider in self.providers:
            await provider.cleanup()

    def status(self) -> dict:
        """Return enrichment engine status."""
        return {
            "providers": {
                p.name: {"configured": p.is_configured}
                for p in self.providers
            },
            "cache_ttl_hours": self.CACHE_TTL_HOURS,
        }


# Singleton shared across the application.
enrichment_engine = EnrichmentEngine()

View File

@@ -0,0 +1,145 @@
"""Default AUP keyword themes and their seed keywords.
Called once on startup — only inserts themes that don't already exist,
so user edits are never overwritten.
"""
import logging
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import KeywordTheme, Keyword
logger = logging.getLogger(__name__)
# ── Default themes + keywords ─────────────────────────────────────────
# Shape: {theme_name: {"color": "#hex badge color", "keywords": [seed, ...]}}.
# Keywords are lowercase — presumably matched case-insensitively by the
# AUP scanner; confirm against the scanner implementation.
DEFAULTS: dict[str, dict] = {
    "Gambling": {
        "color": "#f44336",
        "keywords": [
            "poker", "casino", "blackjack", "roulette", "sportsbook",
            "sports betting", "bet365", "draftkings", "fanduel", "bovada",
            "betonline", "mybookie", "slots", "slot machine", "parlay",
            "wager", "bookie", "betway", "888casino", "pokerstars",
            "william hill", "ladbrokes", "betfair", "unibet", "pinnacle",
        ],
    },
    "Gaming": {
        "color": "#9c27b0",
        "keywords": [
            "steam", "steamcommunity", "steampowered", "epic games",
            "epicgames", "origin.com", "battle.net", "blizzard",
            "roblox", "minecraft", "fortnite", "valorant", "league of legends",
            "twitch", "twitch.tv", "discord", "discord.gg", "xbox live",
            "playstation network", "gog.com", "itch.io", "gamepass",
            "riot games", "ubisoft", "ea.com",
        ],
    },
    "Streaming": {
        "color": "#ff9800",
        "keywords": [
            "netflix", "hulu", "disney+", "disneyplus", "hbomax",
            "amazon prime video", "peacock", "paramount+", "crunchyroll",
            "funimation", "spotify", "pandora", "soundcloud", "deezer",
            "tidal", "apple music", "youtube music", "pluto tv",
            "tubi", "vudu", "plex",
        ],
    },
    "Downloads / Piracy": {
        "color": "#ff5722",
        "keywords": [
            "torrent", "bittorrent", "utorrent", "qbittorrent", "piratebay",
            "thepiratebay", "1337x", "rarbg", "yts", "kickass",
            "limewire", "frostwire", "mega.nz", "rapidshare", "mediafire",
            "zippyshare", "uploadhaven", "fitgirl", "repack", "crack",
            "keygen", "warez", "nulled", "pirate", "magnet:",
        ],
    },
    "Adult Content": {
        "color": "#e91e63",
        "keywords": [
            "pornhub", "xvideos", "xhamster", "onlyfans", "chaturbate",
            "livejasmin", "brazzers", "redtube", "youporn", "xnxx",
            "porn", "xxx", "nsfw", "adult content", "cam site",
            "stripchat", "bongacams",
        ],
    },
    # Renamed from "Social Media (Personal)" — see the rename migration in
    # seed_defaults().
    "Social Media": {
        "color": "#2196f3",
        "keywords": [
            "facebook", "instagram", "tiktok", "snapchat", "pinterest",
            "reddit", "tumblr", "myspace", "whatsapp web", "telegram web",
            "signal web", "wechat web", "twitter.com", "x.com",
            "threads.net", "mastodon", "bluesky",
        ],
    },
    "Job Search": {
        "color": "#4caf50",
        "keywords": [
            "indeed", "linkedin jobs", "glassdoor", "monster.com",
            "ziprecruiter", "careerbuilder", "dice.com", "hired.com",
            "angel.co", "wellfound", "levels.fyi", "salary.com",
            "payscale", "resume", "cover letter", "job application",
        ],
    },
    "Shopping": {
        "color": "#00bcd4",
        "keywords": [
            "amazon.com", "ebay", "etsy", "walmart.com", "target.com",
            "bestbuy", "aliexpress", "wish.com", "shein", "temu",
            "wayfair", "overstock", "newegg", "zappos", "coupon",
            "promo code", "add to cart",
        ],
    },
}
async def seed_defaults(db: AsyncSession) -> int:
    """Insert default themes + keywords for any theme name not already in DB.

    Also migrates legacy theme names in place. User edits are never
    overwritten: existing themes (matched by name) are left untouched.

    Args:
        db: Async session; commits are issued when changes are made.

    Returns:
        The number of themes inserted (0 if all already exist).
    """
    # Rename legacy theme names (one-time migration). If a theme with the
    # new name already exists, skip the rename — blindly updating would
    # create a duplicate name / violate a unique constraint on startup.
    _renames = [("Social Media (Personal)", "Social Media")]
    for old_name, new_name in _renames:
        old = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == old_name))
        if not old:
            continue
        clash = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == new_name))
        if clash:
            logger.warning(
                "Skipping rename of AUP theme '%s': '%s' already exists",
                old_name, new_name,
            )
            continue
        await db.execute(
            KeywordTheme.__table__.update()
            .where(KeywordTheme.name == old_name)
            .values(name=new_name)
        )
        await db.commit()
        logger.info("Renamed AUP theme '%s' -> '%s'", old_name, new_name)

    inserted = 0
    for theme_name, meta in DEFAULTS.items():
        exists = await db.scalar(
            select(KeywordTheme.id).where(KeywordTheme.name == theme_name)
        )
        if exists:
            continue
        theme = KeywordTheme(
            name=theme_name,
            color=meta["color"],
            enabled=True,
            is_builtin=True,
        )
        db.add(theme)
        await db.flush()  # populate theme.id for the keyword FK below
        for kw in meta["keywords"]:
            db.add(Keyword(theme_id=theme.id, value=kw))
        inserted += 1
        logger.info("Seeded AUP theme '%s' with %d keywords", theme_name, len(meta["keywords"]))
    if inserted:
        await db.commit()
        logger.info("Seeded %d AUP keyword themes", inserted)
    else:
        logger.debug("All default AUP themes already present")
    return inserted

View File

@@ -0,0 +1,196 @@
"""Artifact normalizer — maps Velociraptor and common tool columns to canonical schema.
The canonical schema provides consistent field names regardless of which tool
exported the CSV (Velociraptor, OSQuery, Sysmon, etc.).
"""
import logging
import re
from datetime import datetime
from typing import Any
logger = logging.getLogger(__name__)
# ── Column mapping: source_column_pattern → canonical_name ─────────────
# Patterns are case-insensitive regexes matched against column names.
COLUMN_MAPPINGS: list[tuple[str, str]] = [
    # Timestamps
    (r"^(timestamp|time|event_?time|date_?time|created?_?(at|time|date)|modified_?(at|time|date)|mtime|ctime|atime|start_?time|end_?time)$", "timestamp"),
    (r"^(eventtime|system\.timecreated)$", "timestamp"),
    # Host identifiers
    (r"^(hostname|host|fqdn|computer_?name|system_?name|machinename|clientid)$", "hostname"),
    # Operating system
    (r"^(os|operating_?system|os_?version|os_?name|platform|os_?type)$", "os"),
    # Source / destination IPs
    (r"^(source_?ip|src_?ip|srcaddr|local_?address|sourceaddress)$", "src_ip"),
    (r"^(dest_?ip|dst_?ip|dstaddr|remote_?address|destinationaddress|destaddress)$", "dst_ip"),
    (r"^(ip_?address|ipaddress|ip)$", "ip_address"),
    # Ports
    (r"^(source_?port|src_?port|localport)$", "src_port"),
    (r"^(dest_?port|dst_?port|remoteport|destinationport)$", "dst_port"),
    # Process info
    (r"^(process_?name|name|image|exe|executable|binary)$", "process_name"),
    (r"^(pid|process_?id)$", "pid"),
    (r"^(ppid|parent_?pid|parentprocessid)$", "ppid"),
    (r"^(command_?line|cmdline|commandline|cmd)$", "command_line"),
    (r"^(parent_?command_?line|parentcommandline)$", "parent_command_line"),
    # User info
    (r"^(user|username|user_?name|account_?name|subjectusername)$", "username"),
    (r"^(user_?id|uid|sid|subjectusersid)$", "user_id"),
    # File info
    (r"^(file_?path|fullpath|full_?name|path|filepath)$", "file_path"),
    (r"^(file_?name|filename|name)$", "file_name"),
    (r"^(file_?size|size|bytes|length)$", "file_size"),
    (r"^(extension|file_?ext)$", "file_extension"),
    # Hashes
    (r"^(md5|md5hash|hash_?md5)$", "hash_md5"),
    (r"^(sha1|sha1hash|hash_?sha1)$", "hash_sha1"),
    (r"^(sha256|sha256hash|hash_?sha256|hash|filehash)$", "hash_sha256"),
    # Network
    (r"^(protocol|proto)$", "protocol"),
    (r"^(domain|dns_?name|query_?name|queriedname)$", "domain"),
    (r"^(url|uri|request_?url)$", "url"),
    # Event info
    (r"^(event_?id|eventid|eid)$", "event_id"),
    (r"^(event_?type|eventtype|category|action)$", "event_type"),
    (r"^(description|message|msg|detail)$", "description"),
    (r"^(severity|level|priority)$", "severity"),
    # Registry
    (r"^(reg_?key|registry_?key|targetobject)$", "registry_key"),
    (r"^(reg_?value|registry_?value)$", "registry_value"),
]


def normalize_columns(columns: list[str]) -> dict[str, str]:
    """Map raw column names to canonical names.

    Each canonical name is claimed at most once (first raw column wins);
    later columns that would collide fall through to the slug fallback.

    Returns:
        Dict of {raw_column_name: canonical_column_name}.
        Columns with no match map to themselves (lowered + underscored).
    """
    mapping: dict[str, str] = {}
    taken: set[str] = set()
    for raw in columns:
        lowered = raw.strip().lower()
        # First pattern whose canonical target is still unclaimed.
        canonical = next(
            (
                target
                for pattern, target in COLUMN_MAPPINGS
                if target not in taken and re.match(pattern, lowered, re.IGNORECASE)
            ),
            None,
        )
        if canonical is not None:
            taken.add(canonical)
            mapping[raw] = canonical
        else:
            # Fallback: a clean snake_case slug of the original name.
            slug = re.sub(r"[^a-z0-9]+", "_", lowered).strip("_")
            mapping[raw] = slug or raw
    return mapping
def normalize_row(row: dict[str, Any], column_mapping: dict[str, str]) -> dict[str, Any]:
    """Apply column mapping to a single row (unmapped keys pass through)."""
    renamed: dict[str, Any] = {}
    for key, value in row.items():
        renamed[column_mapping.get(key, key)] = value
    return renamed


def normalize_rows(rows: list[dict], column_mapping: dict[str, str]) -> list[dict]:
    """Apply column mapping to all rows."""
    normalized: list[dict] = []
    for entry in rows:
        normalized.append(normalize_row(entry, column_mapping))
    return normalized
def detect_ioc_columns(
    columns: list[str],
    column_types: dict[str, str],
    column_mapping: dict[str, str],
) -> dict[str, str]:
    """Detect which columns contain IOCs (IPs, hashes, domains, URLs).

    A column is flagged either by its detected value type (column_types)
    or by its canonical name (column_mapping); the canonical-name signal
    takes precedence when both apply.

    Returns:
        Dict of {column_name: ioc_type}.
    """
    # Value types that are IOCs verbatim.
    direct_types = {"ip", "hash_md5", "hash_sha1", "hash_sha256", "domain"}
    # Canonical column names that imply an IOC type.
    canonical_types = {
        "src_ip": "ip",
        "dst_ip": "ip",
        "ip_address": "ip",
        "hash_md5": "hash_md5",
        "hash_sha1": "hash_sha1",
        "hash_sha256": "hash_sha256",
        "domain": "domain",
        "url": "url",
    }
    found: dict[str, str] = {}
    for col in columns:
        declared = column_types.get(col)
        if declared in direct_types:
            found[col] = declared
        # Canonical name overrides the declared value type.
        hit = canonical_types.get(column_mapping.get(col, ""))
        if hit is not None:
            found[col] = hit
    return found
def detect_time_range(
rows: list[dict],
column_mapping: dict[str, str],
) -> tuple[datetime | None, datetime | None]:
"""Find the earliest and latest timestamps in the dataset."""
ts_col = None
for raw_col, canonical in column_mapping.items():
if canonical == "timestamp":
ts_col = raw_col
break
if not ts_col:
return None, None
timestamps: list[datetime] = []
for row in rows:
val = row.get(ts_col)
if not val:
continue
try:
dt = _parse_timestamp(str(val))
if dt:
timestamps.append(dt)
except (ValueError, TypeError):
continue
if not timestamps:
return None, None
return min(timestamps), max(timestamps)
def _parse_timestamp(value: str) -> datetime | None:
"""Try multiple timestamp formats."""
formats = [
"%Y-%m-%dT%H:%M:%S.%fZ",
"%Y-%m-%dT%H:%M:%SZ",
"%Y-%m-%dT%H:%M:%S.%f",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%d %H:%M:%S.%f",
"%Y-%m-%d %H:%M:%S",
"%Y/%m/%d %H:%M:%S",
"%m/%d/%Y %H:%M:%S",
"%d/%m/%Y %H:%M:%S",
]
for fmt in formats:
try:
return datetime.strptime(value.strip(), fmt)
except ValueError:
continue
return None

View File

@@ -0,0 +1,425 @@
"""Report generation — JSON, HTML, and CSV export for hunt investigations.
Generates comprehensive investigation reports including:
- Hunt metadata and status
- Dataset summaries with IOC counts
- Hypotheses and their evidence
- Annotations timeline
- Enrichment verdicts
- Agent conversation history
- Cross-hunt correlations
"""
import csv
import io
import json
import logging
from dataclasses import asdict
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import (
Hunt, Dataset, DatasetRow, Hypothesis,
Annotation, Conversation, Message, EnrichmentResult,
)
logger = logging.getLogger(__name__)
class ReportGenerator:
"""Generates exportable investigation reports."""
async def generate_hunt_report(
self,
hunt_id: str,
db: AsyncSession,
format: str = "json",
include_rows: bool = False,
max_rows: int = 500,
) -> dict | str:
"""Generate a comprehensive report for a hunt investigation."""
# Gather all hunt data
report_data = await self._gather_hunt_data(
hunt_id, db, include_rows=include_rows, max_rows=max_rows,
)
if not report_data:
return {"error": "Hunt not found"}
if format == "json":
return report_data
elif format == "html":
return self._render_html(report_data)
elif format == "csv":
return self._render_csv(report_data)
else:
return report_data
async def _gather_hunt_data(
self,
hunt_id: str,
db: AsyncSession,
include_rows: bool = False,
max_rows: int = 500,
) -> dict | None:
"""Gather all data for a hunt report."""
# Hunt metadata
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
hunt = result.scalar_one_or_none()
if not hunt:
return None
# Datasets
ds_result = await db.execute(
select(Dataset).where(Dataset.hunt_id == hunt_id)
)
datasets = ds_result.scalars().all()
dataset_summaries = []
all_iocs = {}
for ds in datasets:
summary = {
"id": ds.id,
"name": ds.name,
"filename": ds.filename,
"source_tool": ds.source_tool,
"row_count": ds.row_count,
"columns": list((ds.column_schema or {}).keys()),
"ioc_columns": ds.ioc_columns or {},
"time_range": {
"start": ds.time_range_start,
"end": ds.time_range_end,
},
"created_at": ds.created_at.isoformat(),
}
if include_rows:
rows_result = await db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == ds.id)
.order_by(DatasetRow.row_index)
.limit(max_rows)
)
rows = rows_result.scalars().all()
summary["rows"] = [r.data for r in rows]
dataset_summaries.append(summary)
# Collect IOCs for enrichment lookup
if ds.ioc_columns:
all_iocs.update(ds.ioc_columns)
# Hypotheses
hyp_result = await db.execute(
select(Hypothesis).where(Hypothesis.hunt_id == hunt_id)
)
hypotheses = hyp_result.scalars().all()
hypotheses_data = [
{
"id": h.id,
"title": h.title,
"description": h.description,
"mitre_technique": h.mitre_technique,
"status": h.status,
"evidence_row_ids": h.evidence_row_ids,
"evidence_notes": h.evidence_notes,
"created_at": h.created_at.isoformat(),
"updated_at": h.updated_at.isoformat(),
}
for h in hypotheses
]
# Annotations (across all datasets in this hunt)
dataset_ids = [ds.id for ds in datasets]
annotations_data = []
if dataset_ids:
ann_result = await db.execute(
select(Annotation)
.where(Annotation.dataset_id.in_(dataset_ids))
.order_by(Annotation.created_at)
)
annotations = ann_result.scalars().all()
annotations_data = [
{
"id": a.id,
"dataset_id": a.dataset_id,
"row_id": a.row_id,
"text": a.text,
"severity": a.severity,
"tag": a.tag,
"created_at": a.created_at.isoformat(),
}
for a in annotations
]
# Conversations
conv_result = await db.execute(
select(Conversation).where(Conversation.hunt_id == hunt_id)
)
conversations = conv_result.scalars().all()
conversations_data = []
for conv in conversations:
msg_result = await db.execute(
select(Message)
.where(Message.conversation_id == conv.id)
.order_by(Message.created_at)
)
messages = msg_result.scalars().all()
conversations_data.append({
"id": conv.id,
"title": conv.title,
"messages": [
{
"role": m.role,
"content": m.content,
"model_used": m.model_used,
"node_used": m.node_used,
"latency_ms": m.latency_ms,
"created_at": m.created_at.isoformat(),
}
for m in messages
],
})
# Enrichment results
enrichment_data = []
for ds in datasets:
if not ds.ioc_columns:
continue
# Get unique enriched IOCs for this dataset
for col_name in ds.ioc_columns.keys():
enrich_result = await db.execute(
select(EnrichmentResult)
.where(EnrichmentResult.source.isnot(None))
.limit(100)
)
enrichments = enrich_result.scalars().all()
for e in enrichments:
enrichment_data.append({
"ioc_value": e.ioc_value,
"ioc_type": e.ioc_type,
"source": e.source,
"verdict": e.verdict,
"score": e.score,
"tags": e.tags,
"country": e.country,
})
break # Only query once
# Build report
now = datetime.now(timezone.utc)
return {
"report_metadata": {
"generated_at": now.isoformat(),
"format_version": "1.0",
"generator": "ThreatHunt Report Engine",
},
"hunt": {
"id": hunt.id,
"name": hunt.name,
"description": hunt.description,
"status": hunt.status,
"created_at": hunt.created_at.isoformat(),
"updated_at": hunt.updated_at.isoformat(),
},
"summary": {
"dataset_count": len(datasets),
"total_rows": sum(ds.row_count for ds in datasets),
"hypothesis_count": len(hypotheses),
"confirmed_hypotheses": len([h for h in hypotheses if h.status == "confirmed"]),
"annotation_count": len(annotations_data),
"critical_annotations": len([a for a in annotations_data if a["severity"] == "critical"]),
"conversation_count": len(conversations_data),
"enrichment_count": len(enrichment_data),
"malicious_iocs": len([e for e in enrichment_data if e["verdict"] == "malicious"]),
},
"datasets": dataset_summaries,
"hypotheses": hypotheses_data,
"annotations": annotations_data,
"conversations": conversations_data,
"enrichments": enrichment_data[:100],
}
    def _render_html(self, data: dict) -> str:
        """Render report as self-contained HTML.

        Produces a single standalone document (inline CSS, no external
        assets) with summary stats plus hypothesis, dataset, annotation
        (first 50), and enrichment (first 50) tables.
        """
        hunt = data.get("hunt", {})
        summary = data.get("summary", {})
        hypotheses = data.get("hypotheses", [])
        annotations = data.get("annotations", [])
        datasets = data.get("datasets", [])
        enrichments = data.get("enrichments", [])
        meta = data.get("report_metadata", {})
        # Document skeleton + inline dark-theme CSS; doubled braces ({{ }})
        # escape literal CSS braces inside the f-string.
        html = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>ThreatHunt Report: {hunt.get('name', 'Unknown')}</title>
<style>
:root {{ --bg: #0d1117; --surface: #161b22; --border: #30363d; --text: #c9d1d9; --accent: #58a6ff; --red: #f85149; --orange: #d29922; --green: #3fb950; }}
* {{ box-sizing: border-box; margin: 0; padding: 0; }}
body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Helvetica, Arial, sans-serif; background: var(--bg); color: var(--text); line-height: 1.6; padding: 2rem; }}
.container {{ max-width: 1200px; margin: 0 auto; }}
h1 {{ color: var(--accent); border-bottom: 2px solid var(--border); padding-bottom: 0.5rem; margin-bottom: 1rem; }}
h2 {{ color: var(--accent); margin: 1.5rem 0 0.75rem; }}
h3 {{ color: var(--text); margin: 1rem 0 0.5rem; }}
.card {{ background: var(--surface); border: 1px solid var(--border); border-radius: 8px; padding: 1rem; margin: 0.75rem 0; }}
.stat-grid {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); gap: 0.75rem; }}
.stat {{ background: var(--surface); border: 1px solid var(--border); border-radius: 8px; padding: 1rem; text-align: center; }}
.stat .value {{ font-size: 2rem; font-weight: 700; color: var(--accent); }}
.stat .label {{ font-size: 0.85rem; color: #8b949e; }}
table {{ width: 100%; border-collapse: collapse; margin: 0.5rem 0; }}
th, td {{ padding: 0.5rem 0.75rem; border: 1px solid var(--border); text-align: left; }}
th {{ background: var(--surface); color: var(--accent); }}
.badge {{ display: inline-block; padding: 0.15rem 0.5rem; border-radius: 999px; font-size: 0.8rem; font-weight: 600; }}
.badge-malicious {{ background: var(--red); color: white; }}
.badge-suspicious {{ background: var(--orange); color: #000; }}
.badge-clean {{ background: var(--green); color: #000; }}
.badge-critical {{ background: var(--red); color: white; }}
.badge-high {{ background: #da3633; color: white; }}
.badge-medium {{ background: var(--orange); color: #000; }}
.badge-confirmed {{ background: var(--green); color: #000; }}
.badge-active {{ background: var(--accent); color: #000; }}
.footer {{ margin-top: 2rem; padding-top: 1rem; border-top: 1px solid var(--border); color: #8b949e; font-size: 0.85rem; }}
</style>
</head>
<body>
<div class="container">
<h1>🔍 ThreatHunt Report: {hunt.get('name', 'Untitled')}</h1>
<p><strong>Hunt ID:</strong> {hunt.get('id', '')}<br>
<strong>Status:</strong> {hunt.get('status', 'unknown')}<br>
<strong>Description:</strong> {hunt.get('description', 'N/A')}<br>
<strong>Created:</strong> {hunt.get('created_at', '')}</p>
<h2>Summary</h2>
<div class="stat-grid">
<div class="stat"><div class="value">{summary.get('dataset_count', 0)}</div><div class="label">Datasets</div></div>
<div class="stat"><div class="value">{summary.get('total_rows', 0):,}</div><div class="label">Total Rows</div></div>
<div class="stat"><div class="value">{summary.get('hypothesis_count', 0)}</div><div class="label">Hypotheses</div></div>
<div class="stat"><div class="value">{summary.get('confirmed_hypotheses', 0)}</div><div class="label">Confirmed</div></div>
<div class="stat"><div class="value">{summary.get('annotation_count', 0)}</div><div class="label">Annotations</div></div>
<div class="stat"><div class="value">{summary.get('malicious_iocs', 0)}</div><div class="label">Malicious IOCs</div></div>
</div>
"""
        # Hypotheses section — badge class only for statuses that have CSS.
        if hypotheses:
            html += "<h2>Hypotheses</h2>\n"
            html += "<table><tr><th>Title</th><th>MITRE</th><th>Status</th><th>Description</th></tr>\n"
            for h in hypotheses:
                status_class = f"badge-{h['status']}" if h['status'] in ('confirmed', 'active') else ""
                html += (
                    f"<tr><td>{h['title']}</td>"
                    f"<td>{h.get('mitre_technique', 'N/A')}</td>"
                    f"<td><span class='badge {status_class}'>{h['status']}</span></td>"
                    f"<td>{h.get('description', '') or ''}</td></tr>\n"
                )
            html += "</table>\n"
        # Datasets section — one summary card per dataset.
        if datasets:
            html += "<h2>Datasets</h2>\n"
            for ds in datasets:
                html += f"""<div class="card">
<h3>{ds['name']} ({ds.get('filename', '')})</h3>
<p><strong>Source:</strong> {ds.get('source_tool', 'N/A')} |
<strong>Rows:</strong> {ds['row_count']:,} |
<strong>IOC Columns:</strong> {len(ds.get('ioc_columns', {}))} |
<strong>Time Range:</strong> {ds.get('time_range', {}).get('start', 'N/A')} to {ds.get('time_range', {}).get('end', 'N/A')}</p>
</div>\n"""
        # Annotations — capped at 50 rows; text truncated to 200 chars.
        if annotations:
            critical = [a for a in annotations if a['severity'] in ('critical', 'high')]
            html += f"<h2>Annotations ({len(annotations)} total, {len(critical)} critical/high)</h2>\n"
            html += "<table><tr><th>Severity</th><th>Tag</th><th>Text</th><th>Created</th></tr>\n"
            for a in annotations[:50]:
                sev_class = f"badge-{a['severity']}" if a['severity'] in ('critical', 'high', 'medium') else ""
                html += (
                    f"<tr><td><span class='badge {sev_class}'>{a['severity']}</span></td>"
                    f"<td>{a.get('tag', 'N/A')}</td>"
                    f"<td>{a['text'][:200]}</td>"
                    f"<td>{a['created_at'][:19]}</td></tr>\n"
                )
            html += "</table>\n"
        # Enrichments — capped at 50 rows; verdict drives the badge color.
        if enrichments:
            malicious = [e for e in enrichments if e['verdict'] == 'malicious']
            html += f"<h2>IOC Enrichment ({len(enrichments)} results, {len(malicious)} malicious)</h2>\n"
            html += "<table><tr><th>IOC</th><th>Type</th><th>Source</th><th>Verdict</th><th>Score</th></tr>\n"
            for e in enrichments[:50]:
                verdict_class = f"badge-{e['verdict']}"
                html += (
                    f"<tr><td><code>{e['ioc_value']}</code></td>"
                    f"<td>{e['ioc_type']}</td>"
                    f"<td>{e['source']}</td>"
                    f"<td><span class='badge {verdict_class}'>{e['verdict']}</span></td>"
                    f"<td>{e.get('score', 0)}</td></tr>\n"
                )
            html += "</table>\n"
        # Footer with generation timestamp (trimmed to seconds precision).
        html += f"""
<div class="footer">
<p>Generated by ThreatHunt Report Engine | {meta.get('generated_at', '')[:19]}</p>
</div>
</div>
</body>
</html>"""
        return html
def _render_csv(self, data: dict) -> str:
"""Render key report data as CSV."""
output = io.StringIO()
# Hypotheses sheet
output.write("=== HYPOTHESES ===\n")
writer = csv.writer(output)
writer.writerow(["Title", "MITRE Technique", "Status", "Description", "Evidence Notes"])
for h in data.get("hypotheses", []):
writer.writerow([
h.get("title", ""),
h.get("mitre_technique", ""),
h.get("status", ""),
h.get("description", ""),
h.get("evidence_notes", ""),
])
output.write("\n=== ANNOTATIONS ===\n")
writer.writerow(["Severity", "Tag", "Text", "Dataset ID", "Row ID", "Created"])
for a in data.get("annotations", []):
writer.writerow([
a.get("severity", ""),
a.get("tag", ""),
a.get("text", ""),
a.get("dataset_id", ""),
a.get("row_id", ""),
a.get("created_at", ""),
])
output.write("\n=== ENRICHMENTS ===\n")
writer.writerow(["IOC Value", "IOC Type", "Source", "Verdict", "Score", "Country"])
for e in data.get("enrichments", []):
writer.writerow([
e.get("ioc_value", ""),
e.get("ioc_type", ""),
e.get("source", ""),
e.get("verdict", ""),
e.get("score", ""),
e.get("country", ""),
])
return output.getvalue()
# Singleton
report_generator = ReportGenerator()

View File

@@ -0,0 +1,346 @@
"""SANS RAG service — queries the 300GB SANS courseware indexed in Open WebUI.
Provides contextual SANS references for threat hunting guidance.
Uses two approaches:
1. Open WebUI RAG pipeline (if configured with a knowledge collection)
2. Embedding-based semantic search against locally indexed SANS content
"""
import asyncio
import logging
import re
import time
from dataclasses import dataclass, field
from typing import Optional
import httpx
from app.config import settings
from app.agents.providers_v2 import _get_client
from app.agents.registry import Node
logger = logging.getLogger(__name__)
# ── SANS course catalog for reference matching ────────────────────────
# Course ID -> official course title, used to render human-readable
# references alongside matched course codes.
# NOTE(review): FOR408 and FOR500 share the same title here — FOR408 is
# presumably the legacy code for FOR500; confirm and keep both so either
# code resolves.
SANS_COURSES = {
    "SEC401": "Security Essentials",
    "SEC504": "Hacker Tools, Techniques, and Incident Handling",
    "SEC503": "Network Monitoring and Threat Detection In-Depth",
    "SEC505": "Securing Windows and PowerShell Automation",
    "SEC506": "Securing Linux/Unix",
    "SEC510": "Public Cloud Security: AWS, Azure, and GCP",
    "SEC511": "Continuous Monitoring and Security Operations",
    "SEC530": "Defensible Security Architecture and Engineering",
    "SEC540": "Cloud Security and DevSecOps Automation",
    "SEC555": "SIEM with Tactical Analytics",
    "SEC560": "Enterprise Penetration Testing",
    "SEC565": "Red Team Operations and Adversary Emulation",
    "SEC573": "Automating Information Security with Python",
    "SEC575": "Mobile Device Security and Ethical Hacking",
    "SEC588": "Cloud Penetration Testing",
    "SEC599": "Defeating Advanced Adversaries - Purple Team Tactics",
    "FOR408": "Windows Forensic Analysis",
    "FOR498": "Digital Acquisition and Rapid Triage",
    "FOR500": "Windows Forensic Analysis",
    "FOR508": "Advanced Incident Response, Threat Hunting, and Digital Forensics",
    "FOR509": "Enterprise Cloud Forensics and Incident Response",
    "FOR518": "Mac and iOS Forensic Analysis and Incident Response",
    "FOR572": "Advanced Network Forensics: Threat Hunting, Analysis, and Incident Response",
    "FOR578": "Cyber Threat Intelligence",
    "FOR585": "Smartphone Forensic Analysis In-Depth",
    "FOR610": "Reverse-Engineering Malware: Malware Analysis Tools and Techniques",
    "FOR710": "Reverse-Engineering Malware: Advanced Code Analysis",
    "ICS410": "ICS/SCADA Security Essentials",
    "ICS515": "ICS Visibility, Detection, and Response",
}
# Topic-to-course mapping for fallback recommendations
# Lowercase topic phrase -> recommended course IDs (most relevant first).
# Used as the fallback when the RAG pipeline is unavailable; keys are
# presumably matched as substrings of the question — confirm in
# _match_courses().
TOPIC_COURSE_MAP = {
    "malware": ["FOR610", "FOR710", "SEC504"],
    "reverse engineer": ["FOR610", "FOR710"],
    "incident response": ["FOR508", "SEC504"],
    "forensic": ["FOR508", "FOR500", "FOR408"],
    "windows forensic": ["FOR500", "FOR408"],
    "network forensic": ["FOR572"],
    "threat hunting": ["FOR508", "SEC504", "FOR578"],
    "threat intelligence": ["FOR578"],
    "powershell": ["SEC505", "FOR508"],
    "lateral movement": ["SEC504", "FOR508"],
    "persistence": ["FOR508", "SEC504"],
    "privilege escalation": ["SEC504", "SEC560"],
    "credential": ["SEC504", "SEC560"],
    "memory forensic": ["FOR508"],
    "disk forensic": ["FOR500", "FOR408"],
    "registry": ["FOR500", "FOR408"],
    "event log": ["FOR508", "SEC555"],
    "siem": ["SEC555"],
    "log analysis": ["SEC555", "SEC503"],
    "network monitor": ["SEC503"],
    "pcap": ["SEC503", "FOR572"],
    "cloud": ["SEC510", "SEC540", "FOR509"],
    "aws": ["SEC510", "SEC540", "FOR509"],
    "azure": ["SEC510", "FOR509"],
    "linux": ["SEC506"],
    "mobile": ["SEC575", "FOR585"],
    "penetration test": ["SEC560", "SEC565"],
    "red team": ["SEC565", "SEC599"],
    "purple team": ["SEC599"],
    "python": ["SEC573"],
    "automation": ["SEC573", "SEC540"],
    "deobfusc": ["FOR610", "SEC504"],
    "base64": ["FOR610", "SEC504"],
    "shellcode": ["FOR610", "FOR710"],
    "ransomware": ["FOR508", "FOR610"],
    "phishing": ["SEC504", "FOR578"],
    "c2": ["FOR508", "SEC504", "FOR572"],
    "command and control": ["FOR508", "SEC504"],
    "exfiltration": ["FOR508", "FOR572", "SEC503"],
    "dns": ["FOR572", "SEC503"],
    "ioc": ["FOR508", "FOR578"],
    "mitre": ["FOR508", "SEC504", "SEC599"],
    "att&ck": ["FOR508", "SEC504"],
    "velociraptor": ["FOR508"],
    "volatility": ["FOR508"],
    "scheduled task": ["FOR508", "SEC504"],
    "service": ["FOR508", "SEC504"],
    "wmi": ["FOR508", "SEC504"],
    "process": ["FOR508"],
    "dll": ["FOR610", "FOR508"],
}
@dataclass
class RAGResult:
"""Result from a RAG query."""
query: str
context: str # Retrieved relevant text
sources: list[str] = field(default_factory=list) # Source document names
course_references: list[str] = field(default_factory=list) # SANS course IDs
confidence: float = 0.0
latency_ms: int = 0
class SANSRAGService:
"""Service for querying SANS courseware via Open WebUI RAG pipeline."""
def __init__(self):
self.openwebui_url = settings.OPENWEBUI_URL.rstrip("/")
self.api_key = settings.OPENWEBUI_API_KEY
self.rag_model = settings.DEFAULT_FAST_MODEL
self._available: bool | None = None
def _headers(self) -> dict:
h = {"Content-Type": "application/json"}
if self.api_key:
h["Authorization"] = f"Bearer {self.api_key}"
return h
async def query(
self,
question: str,
context: str = "",
max_tokens: int = 1024,
) -> RAGResult:
"""Query SANS courseware for relevant context.
Uses Open WebUI's RAG-enabled chat to retrieve from indexed SANS content.
Falls back to topic-based course recommendations if RAG is unavailable.
"""
start = time.monotonic()
# Try Open WebUI RAG pipeline first
try:
result = await self._query_openwebui_rag(question, context, max_tokens)
result.latency_ms = int((time.monotonic() - start) * 1000)
# Enrich with course references if not already present
if not result.course_references:
result.course_references = self._match_courses(question)
return result
except Exception as e:
logger.warning(f"RAG query failed, using fallback: {e}")
# Fallback to topic-based matching
courses = self._match_courses(question)
return RAGResult(
query=question,
context="",
sources=[],
course_references=courses,
confidence=0.3 if courses else 0.0,
latency_ms=int((time.monotonic() - start) * 1000),
)
async def _query_openwebui_rag(
self,
question: str,
context: str,
max_tokens: int,
) -> RAGResult:
"""Query Open WebUI with RAG context retrieval.
Open WebUI automatically retrieves from its indexed knowledge base
when the model is configured with a knowledge collection.
"""
client = _get_client()
system_msg = (
"You are a SANS cybersecurity knowledge assistant. "
"Use your indexed SANS courseware to answer the question. "
"Always cite the specific SANS course (e.g., FOR508, SEC504) "
"and relevant section when referencing material. "
"If the question relates to threat hunting procedures, "
"reference the specific SANS methodology or framework."
)
messages = [
{"role": "system", "content": system_msg},
]
if context:
messages.append({
"role": "user",
"content": f"Investigation context:\n{context}\n\nQuestion: {question}",
})
else:
messages.append({"role": "user", "content": question})
payload = {
"model": self.rag_model,
"messages": messages,
"max_tokens": max_tokens,
"temperature": 0.2,
"stream": False,
}
resp = await client.post(
f"{self.openwebui_url}/v1/chat/completions",
json=payload,
headers=self._headers(),
)
resp.raise_for_status()
data = resp.json()
content = ""
if data.get("choices"):
content = data["choices"][0].get("message", {}).get("content", "")
# Extract course references from response
course_refs = self._extract_course_refs(content)
sources = self._extract_sources(data)
return RAGResult(
query=question,
context=content,
sources=sources,
course_references=course_refs,
confidence=0.8 if content else 0.0,
)
def _extract_course_refs(self, text: str) -> list[str]:
"""Extract SANS course references from response text."""
refs = set()
# Match patterns like SEC504, FOR508, ICS410
pattern = r'\b(SEC|FOR|ICS|MGT|AUD|DEV|LEG)\d{3}\b'
matches = re.findall(pattern, text, re.IGNORECASE)
# Need to get the full match
full_pattern = r'\b(?:SEC|FOR|ICS|MGT|AUD|DEV|LEG)\d{3}\b'
full_matches = re.findall(full_pattern, text, re.IGNORECASE)
for m in full_matches:
course_id = m.upper()
if course_id in SANS_COURSES:
refs.add(f"{course_id}: {SANS_COURSES[course_id]}")
else:
refs.add(course_id)
return sorted(refs)
def _extract_sources(self, api_response: dict) -> list[str]:
"""Extract source document references from Open WebUI response metadata."""
sources = []
# Open WebUI may include source metadata in various formats
if "sources" in api_response:
for src in api_response["sources"]:
if isinstance(src, dict):
sources.append(src.get("name", src.get("title", str(src))))
else:
sources.append(str(src))
# Check in metadata
for choice in api_response.get("choices", []):
meta = choice.get("metadata", {})
if "sources" in meta:
for src in meta["sources"]:
if isinstance(src, dict):
sources.append(src.get("name", str(src)))
else:
sources.append(str(src))
return sources[:10] # Limit
def _match_courses(self, query: str) -> list[str]:
"""Match query keywords to SANS courses using topic map."""
q = query.lower()
matched = set()
for topic, courses in TOPIC_COURSE_MAP.items():
if topic in q:
for course_id in courses:
if course_id in SANS_COURSES:
matched.add(f"{course_id}: {SANS_COURSES[course_id]}")
return sorted(matched)[:5]
async def get_course_context(self, course_id: str) -> str:
"""Get a brief course description for context injection."""
course_id = course_id.upper().split(":")[0].strip()
if course_id in SANS_COURSES:
return f"{course_id}: {SANS_COURSES[course_id]}"
return ""
async def enrich_prompt(
self,
query: str,
investigation_context: str = "",
) -> str:
"""Generate SANS-enriched context to inject into agent prompts.
Returns a context string with relevant SANS references.
"""
result = await self.query(query, context=investigation_context, max_tokens=512)
parts = []
if result.context:
parts.append(f"SANS Reference Context:\n{result.context}")
if result.course_references:
parts.append(f"Relevant SANS Courses: {', '.join(result.course_references)}")
if result.sources:
parts.append(f"Sources: {', '.join(result.sources[:5])}")
return "\n".join(parts) if parts else ""
async def health_check(self) -> dict:
"""Check RAG service availability."""
try:
client = _get_client()
resp = await client.get(
f"{self.openwebui_url}/v1/models",
headers=self._headers(),
timeout=5,
)
available = resp.status_code == 200
self._available = available
return {
"available": available,
"url": self.openwebui_url,
"model": self.rag_model,
}
except Exception as e:
self._available = False
return {
"available": False,
"url": self.openwebui_url,
"error": str(e),
}
# Singleton
sans_rag = SANSRAGService()

View File

@@ -0,0 +1,233 @@
"""AUP Keyword Scanner — searches dataset rows, hunts, annotations, and
messages for keyword matches.
Scanning is done in Python (not SQL LIKE on JSON columns) for portability
across SQLite / PostgreSQL and to provide per-cell match context.
"""
import logging
import re
from dataclasses import dataclass, field
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import (
KeywordTheme,
Keyword,
DatasetRow,
Dataset,
Hunt,
Annotation,
Message,
Conversation,
)
logger = logging.getLogger(__name__)
BATCH_SIZE = 500
@dataclass
class ScanHit:
theme_name: str
theme_color: str
keyword: str
source_type: str # dataset_row | hunt | annotation | message
source_id: str | int
field: str
matched_value: str
row_index: int | None = None
dataset_name: str | None = None
@dataclass
class ScanResult:
total_hits: int = 0
hits: list[ScanHit] = field(default_factory=list)
themes_scanned: int = 0
keywords_scanned: int = 0
rows_scanned: int = 0
class KeywordScanner:
"""Scans multiple data sources for keyword/regex matches."""
def __init__(self, db: AsyncSession):
self.db = db
# ── Public API ────────────────────────────────────────────────────
async def scan(
self,
dataset_ids: list[str] | None = None,
theme_ids: list[str] | None = None,
scan_hunts: bool = True,
scan_annotations: bool = True,
scan_messages: bool = True,
) -> dict:
"""Run a full AUP scan and return dict matching ScanResponse."""
# Load themes + keywords
themes = await self._load_themes(theme_ids)
if not themes:
return ScanResult().__dict__
# Pre-compile patterns per theme
patterns = self._compile_patterns(themes)
result = ScanResult(
themes_scanned=len(themes),
keywords_scanned=sum(len(kws) for kws in patterns.values()),
)
# Scan dataset rows
await self._scan_datasets(patterns, result, dataset_ids)
# Scan hunts
if scan_hunts:
await self._scan_hunts(patterns, result)
# Scan annotations
if scan_annotations:
await self._scan_annotations(patterns, result)
# Scan messages
if scan_messages:
await self._scan_messages(patterns, result)
result.total_hits = len(result.hits)
return {
"total_hits": result.total_hits,
"hits": [h.__dict__ for h in result.hits],
"themes_scanned": result.themes_scanned,
"keywords_scanned": result.keywords_scanned,
"rows_scanned": result.rows_scanned,
}
# ── Internal ──────────────────────────────────────────────────────
async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712
if theme_ids:
q = q.where(KeywordTheme.id.in_(theme_ids))
result = await self.db.execute(q)
return list(result.scalars().all())
def _compile_patterns(
self, themes: list[KeywordTheme]
) -> dict[tuple[str, str, str], list[tuple[str, re.Pattern]]]:
"""Returns {(theme_id, theme_name, theme_color): [(keyword_value, compiled_pattern), ...]}"""
patterns: dict[tuple[str, str, str], list[tuple[str, re.Pattern]]] = {}
for theme in themes:
key = (theme.id, theme.name, theme.color)
compiled = []
for kw in theme.keywords:
try:
if kw.is_regex:
pat = re.compile(kw.value, re.IGNORECASE)
else:
pat = re.compile(re.escape(kw.value), re.IGNORECASE)
compiled.append((kw.value, pat))
except re.error:
logger.warning("Invalid regex pattern '%s' in theme '%s', skipping",
kw.value, theme.name)
patterns[key] = compiled
return patterns
def _match_text(
self,
text: str,
patterns: dict,
source_type: str,
source_id: str | int,
field_name: str,
hits: list[ScanHit],
row_index: int | None = None,
dataset_name: str | None = None,
) -> None:
"""Check text against all compiled patterns, append hits."""
if not text:
return
for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
for kw_value, pat in keyword_patterns:
if pat.search(text):
# Truncate matched_value for display
matched_preview = text[:200] + ("" if len(text) > 200 else "")
hits.append(ScanHit(
theme_name=theme_name,
theme_color=theme_color,
keyword=kw_value,
source_type=source_type,
source_id=source_id,
field=field_name,
matched_value=matched_preview,
row_index=row_index,
dataset_name=dataset_name,
))
async def _scan_datasets(
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
) -> None:
"""Scan dataset rows in batches."""
# Build dataset name lookup
ds_q = select(Dataset.id, Dataset.name)
if dataset_ids:
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
ds_result = await self.db.execute(ds_q)
ds_map = {r[0]: r[1] for r in ds_result.fetchall()}
if not ds_map:
return
# Iterate rows in batches
offset = 0
row_q_base = select(DatasetRow).where(
DatasetRow.dataset_id.in_(list(ds_map.keys()))
).order_by(DatasetRow.id)
while True:
rows_result = await self.db.execute(
row_q_base.offset(offset).limit(BATCH_SIZE)
)
rows = rows_result.scalars().all()
if not rows:
break
for row in rows:
result.rows_scanned += 1
data = row.data or {}
for col_name, cell_value in data.items():
if cell_value is None:
continue
text = str(cell_value)
self._match_text(
text, patterns, "dataset_row", row.id,
col_name, result.hits,
row_index=row.row_index,
dataset_name=ds_map.get(row.dataset_id),
)
offset += BATCH_SIZE
if len(rows) < BATCH_SIZE:
break
async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
"""Scan hunt names and descriptions."""
hunts_result = await self.db.execute(select(Hunt))
for hunt in hunts_result.scalars().all():
self._match_text(hunt.name, patterns, "hunt", hunt.id, "name", result.hits)
if hunt.description:
self._match_text(hunt.description, patterns, "hunt", hunt.id, "description", result.hits)
async def _scan_annotations(self, patterns: dict, result: ScanResult) -> None:
"""Scan annotation text."""
ann_result = await self.db.execute(select(Annotation))
for ann in ann_result.scalars().all():
self._match_text(ann.text, patterns, "annotation", ann.id, "text", result.hits)
async def _scan_messages(self, patterns: dict, result: ScanResult) -> None:
"""Scan conversation messages (user messages only)."""
msg_result = await self.db.execute(
select(Message).where(Message.role == "user")
)
for msg in msg_result.scalars().all():
self._match_text(msg.content, patterns, "message", msg.id, "content", result.hits)