Mirror of https://github.com/mblanke/ThreatHunt.git (synced 2026-03-01 14:00:20 -05:00)
feat: interactive network map, IOC highlighting, AUP hunt selector, type filters
- NetworkMap: hunt-scoped force-directed graph with click-to-inspect popover
- NetworkMap: zoom/pan (wheel, drag, buttons), viewport transform
- NetworkMap: clickable IP/Host/Domain/URL legend chips to filter node types
- NetworkMap: brighter colors, 20% smaller nodes
- DatasetViewer: IOC columns highlighted with colored headers + cell tinting
- AUPScanner: hunt dropdown replacing dataset checkboxes, auto-select all
- Rename 'Social Media (Personal)' theme to 'Social Media' with DB migration
- Fix /api/hunts timeout: Dataset.rows lazy='noload' (was selectin cascade)
- Add OS column mapping to normalizer
- Full backend services, DB models, alembic migrations, new routes
- New components: Dashboard, HuntManager, FileUpload, NetworkMap, etc.
- Docker Compose deployment with nginx reverse proxy
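For context on the /api/hunts fix listed above: an eager selectin load pulled every DatasetRow along with each Dataset whenever hunts were listed. A minimal sketch of the lazy='noload' change follows (table names, integer key types, and the standalone framing are assumptions; only Dataset.rows and the loader strategy come from the commit message):

from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

class Base(DeclarativeBase):
    pass

class Dataset(Base):
    __tablename__ = "datasets"
    id: Mapped[int] = mapped_column(primary_key=True)
    # lazy="noload": Dataset.rows is never populated implicitly, so listing
    # hunts stays a single cheap query; row data is fetched with an explicit
    # select(DatasetRow) only when a viewer actually opens a dataset.
    rows: Mapped[list["DatasetRow"]] = relationship(lazy="noload")

class DatasetRow(Base):
    __tablename__ = "dataset_rows"
    id: Mapped[int] = mapped_column(primary_key=True)
    dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"))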
backend/app/services/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Services package."""
backend/app/services/auth.py (new file, 201 lines)
@@ -0,0 +1,201 @@
"""Authentication & security — JWT tokens, password hashing, role-based access.

Provides:
- Password hashing (bcrypt via passlib)
- JWT access/refresh token creation and verification
- FastAPI dependency for protecting routes
- Role-based enforcement (analyst, admin, viewer)
"""

import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.db import get_db
from app.db.models import User

logger = logging.getLogger(__name__)

# ── Password hashing ─────────────────────────────────────────────────

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def hash_password(password: str) -> str:
    return pwd_context.hash(password)


def verify_password(plain: str, hashed: str) -> bool:
    return pwd_context.verify(plain, hashed)


# ── JWT tokens ────────────────────────────────────────────────────────

ALGORITHM = "HS256"

security = HTTPBearer(auto_error=False)


class TokenPair(BaseModel):
    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    expires_in: int  # seconds


class TokenPayload(BaseModel):
    sub: str  # user_id
    role: str
    exp: datetime
    type: str  # "access" or "refresh"


def create_access_token(user_id: str, role: str) -> str:
    expires = datetime.now(timezone.utc) + timedelta(
        minutes=settings.JWT_ACCESS_TOKEN_MINUTES
    )
    payload = {
        "sub": user_id,
        "role": role,
        "exp": expires,
        "type": "access",
    }
    return jwt.encode(payload, settings.JWT_SECRET, algorithm=ALGORITHM)


def create_refresh_token(user_id: str, role: str) -> str:
    expires = datetime.now(timezone.utc) + timedelta(
        days=settings.JWT_REFRESH_TOKEN_DAYS
    )
    payload = {
        "sub": user_id,
        "role": role,
        "exp": expires,
        "type": "refresh",
    }
    return jwt.encode(payload, settings.JWT_SECRET, algorithm=ALGORITHM)


def create_token_pair(user_id: str, role: str) -> TokenPair:
    return TokenPair(
        access_token=create_access_token(user_id, role),
        refresh_token=create_refresh_token(user_id, role),
        expires_in=settings.JWT_ACCESS_TOKEN_MINUTES * 60,
    )


def decode_token(token: str) -> TokenPayload:
    """Decode and validate a JWT token."""
    try:
        payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[ALGORITHM])
        return TokenPayload(**payload)
    except JWTError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid token: {e}",
            headers={"WWW-Authenticate": "Bearer"},
        )


# ── FastAPI dependencies ──────────────────────────────────────────────


async def get_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Extract and validate the current user from JWT.

    When auth is disabled (no JWT secret configured), returns a default analyst user.
    """
    # If auth is disabled (dev mode), return a default user
    if settings.JWT_SECRET == "CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET":
        return User(
            id="dev-user",
            username="analyst",
            email="analyst@local",
            role="analyst",
            display_name="Dev Analyst",
        )

    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token_data = decode_token(credentials.credentials)

    if token_data.type != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type — use access token",
        )

    result = await db.execute(select(User).where(User.id == token_data.sub))
    user = result.scalar_one_or_none()

    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
        )

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User account is disabled",
        )

    return user


async def get_optional_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> Optional[User]:
    """Like get_current_user, but returns None instead of raising if no token."""
    if not credentials:
        if settings.JWT_SECRET == "CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET":
            return User(
                id="dev-user",
                username="analyst",
                email="analyst@local",
                role="analyst",
                display_name="Dev Analyst",
            )
        return None

    try:
        return await get_current_user(credentials, db)
    except HTTPException:
        return None


def require_role(*roles: str):
    """Dependency factory that requires the current user to have one of the specified roles."""

    async def _check(user: User = Depends(get_current_user)) -> User:
        if user.role not in roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Requires one of roles: {', '.join(roles)}. You have: {user.role}",
            )
        return user

    return _check


# Convenience dependencies
require_analyst = require_role("analyst", "admin")
require_admin = require_role("admin")
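A usage sketch for the module above (the router and route paths are assumptions, not part of this diff): protecting endpoints with the exported dependencies. A login route would pair verify_password with create_token_pair(user.id, user.role) to issue tokens.

from fastapi import APIRouter, Depends

from app.db.models import User
from app.services.auth import get_current_user, require_admin

router = APIRouter()

@router.get("/me")
async def read_me(user: User = Depends(get_current_user)):
    # Any authenticated user (or the dev-mode fallback) reaches this point.
    return {"username": user.username, "role": user.role}

@router.delete("/admin/users/{user_id}")
async def delete_user(user_id: str, admin: User = Depends(require_admin)):
    # require_admin raises 403 for anyone whose role is not "admin".
    return {"deleted": user_id, "by": admin.username}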
backend/app/services/correlation.py (new file, 400 lines)
@@ -0,0 +1,400 @@
"""Cross-hunt correlation engine — find IOC overlaps, timeline patterns, and shared TTPs.

Identifies connections between hunts by analyzing:
1. Shared IOC values across datasets
2. Overlapping time ranges and temporal proximity
3. Common MITRE ATT&CK techniques across hypotheses
4. Host-to-host lateral movement patterns
"""

import logging
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import Dataset, DatasetRow, Hunt, Hypothesis, EnrichmentResult

logger = logging.getLogger(__name__)


@dataclass
class IOCOverlap:
    """Shared IOC between two or more hunts/datasets."""
    ioc_value: str
    ioc_type: str
    datasets: list[dict] = field(default_factory=list)  # [{dataset_id, hunt_id, name}]
    hunt_ids: list[str] = field(default_factory=list)
    count: int = 0
    enrichment_verdict: str = ""


@dataclass
class TimeOverlap:
    """Overlapping time window between datasets."""
    dataset_a: dict = field(default_factory=dict)
    dataset_b: dict = field(default_factory=dict)
    overlap_start: str = ""
    overlap_end: str = ""
    overlap_hours: float = 0.0


@dataclass
class TechniqueOverlap:
    """Shared MITRE ATT&CK technique across hunts."""
    technique_id: str
    technique_name: str = ""
    hypotheses: list[dict] = field(default_factory=list)
    hunt_ids: list[str] = field(default_factory=list)


@dataclass
class CorrelationResult:
    """Complete correlation analysis result."""
    hunt_ids: list[str]
    ioc_overlaps: list[IOCOverlap] = field(default_factory=list)
    time_overlaps: list[TimeOverlap] = field(default_factory=list)
    technique_overlaps: list[TechniqueOverlap] = field(default_factory=list)
    host_overlaps: list[dict] = field(default_factory=list)
    summary: str = ""
    total_correlations: int = 0


class CorrelationEngine:
    """Engine for finding correlations across hunts and datasets."""

    async def correlate_hunts(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> CorrelationResult:
        """Run full correlation analysis across specified hunts."""
        result = CorrelationResult(hunt_ids=hunt_ids)

        # Run all correlation types
        result.ioc_overlaps = await self._find_ioc_overlaps(hunt_ids, db)
        result.time_overlaps = await self._find_time_overlaps(hunt_ids, db)
        result.technique_overlaps = await self._find_technique_overlaps(hunt_ids, db)
        result.host_overlaps = await self._find_host_overlaps(hunt_ids, db)

        result.total_correlations = (
            len(result.ioc_overlaps)
            + len(result.time_overlaps)
            + len(result.technique_overlaps)
            + len(result.host_overlaps)
        )

        result.summary = self._build_summary(result)
        return result

    async def correlate_all(self, db: AsyncSession) -> CorrelationResult:
        """Correlate across ALL hunts in the system."""
        stmt = select(Hunt.id)
        result = await db.execute(stmt)
        hunt_ids = [row[0] for row in result.fetchall()]

        if len(hunt_ids) < 2:
            return CorrelationResult(
                hunt_ids=hunt_ids,
                summary="Need at least 2 hunts for correlation analysis.",
            )

        return await self.correlate_hunts(hunt_ids, db)

    async def find_ioc_across_hunts(
        self,
        ioc_value: str,
        db: AsyncSession,
    ) -> list[dict]:
        """Find all occurrences of a specific IOC across all datasets/hunts."""
        # Scan a bounded sample of rows in Python, comparing values as strings
        # against both raw and normalized data
        stmt = select(DatasetRow, Dataset).join(
            Dataset, DatasetRow.dataset_id == Dataset.id
        )
        result = await db.execute(stmt.limit(5000))
        rows = result.all()

        occurrences = []
        for row, dataset in rows:
            data = row.data or {}
            normalized = row.normalized_data or {}

            # Search both raw and normalized data
            for col, val in {**data, **normalized}.items():
                if str(val) == ioc_value:
                    occurrences.append({
                        "dataset_id": dataset.id,
                        "dataset_name": dataset.name,
                        "hunt_id": dataset.hunt_id,
                        "row_index": row.row_index,
                        "column": col,
                    })
                    break  # one hit per row is enough

        return occurrences

    # ── IOC overlap detection ─────────────────────────────────────────

    async def _find_ioc_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[IOCOverlap]:
        """Find IOC values that appear in datasets from different hunts."""
        # Get all datasets for the specified hunts
        stmt = select(Dataset).where(Dataset.hunt_id.in_(hunt_ids))
        result = await db.execute(stmt)
        datasets = result.scalars().all()

        if len(datasets) < 2:
            return []

        # Build IOC → dataset mapping
        ioc_map: dict[str, list[dict]] = defaultdict(list)

        for dataset in datasets:
            if not dataset.ioc_columns:
                continue

            ioc_cols = list(dataset.ioc_columns.keys())
            rows_stmt = select(DatasetRow).where(
                DatasetRow.dataset_id == dataset.id
            ).limit(2000)
            rows_result = await db.execute(rows_stmt)
            rows = rows_result.scalars().all()

            for row in rows:
                data = row.data or {}
                for col in ioc_cols:
                    val = data.get(col, "")
                    if val and str(val).strip():
                        ioc_map[str(val).strip()].append({
                            "dataset_id": dataset.id,
                            "dataset_name": dataset.name,
                            "hunt_id": dataset.hunt_id,
                            "column": col,
                            "ioc_type": dataset.ioc_columns.get(col, "unknown"),
                        })

        # Filter to IOCs appearing in multiple hunts
        overlaps = []
        for ioc_value, appearances in ioc_map.items():
            hunt_set = set(a["hunt_id"] for a in appearances if a["hunt_id"])
            if len(hunt_set) >= 2:
                # Check for enrichment data
                enrich_stmt = select(EnrichmentResult).where(
                    EnrichmentResult.ioc_value == ioc_value
                ).limit(1)
                enrich_result = await db.execute(enrich_stmt)
                enrichment = enrich_result.scalar_one_or_none()

                overlaps.append(IOCOverlap(
                    ioc_value=ioc_value,
                    ioc_type=appearances[0].get("ioc_type", "unknown"),
                    datasets=appearances,
                    hunt_ids=sorted(hunt_set),
                    count=len(appearances),
                    enrichment_verdict=enrichment.verdict if enrichment else "",
                ))

        # Sort by count descending
        overlaps.sort(key=lambda x: x.count, reverse=True)
        return overlaps[:100]  # Limit results

    # ── Time window overlap ───────────────────────────────────────────

    async def _find_time_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[TimeOverlap]:
        """Find datasets across hunts with overlapping time ranges."""
        stmt = select(Dataset).where(
            Dataset.hunt_id.in_(hunt_ids),
            Dataset.time_range_start.isnot(None),
            Dataset.time_range_end.isnot(None),
        )
        result = await db.execute(stmt)
        datasets = result.scalars().all()

        overlaps = []
        for i, ds_a in enumerate(datasets):
            for ds_b in datasets[i + 1:]:
                if ds_a.hunt_id == ds_b.hunt_id:
                    continue  # Same hunt, skip

                try:
                    a_start = datetime.fromisoformat(ds_a.time_range_start)
                    a_end = datetime.fromisoformat(ds_a.time_range_end)
                    b_start = datetime.fromisoformat(ds_b.time_range_start)
                    b_end = datetime.fromisoformat(ds_b.time_range_end)
                except (ValueError, TypeError):
                    continue

                # Check overlap
                overlap_start = max(a_start, b_start)
                overlap_end = min(a_end, b_end)

                if overlap_start < overlap_end:
                    hours = (overlap_end - overlap_start).total_seconds() / 3600
                    overlaps.append(TimeOverlap(
                        dataset_a={
                            "id": ds_a.id,
                            "name": ds_a.name,
                            "hunt_id": ds_a.hunt_id,
                            "start": ds_a.time_range_start,
                            "end": ds_a.time_range_end,
                        },
                        dataset_b={
                            "id": ds_b.id,
                            "name": ds_b.name,
                            "hunt_id": ds_b.hunt_id,
                            "start": ds_b.time_range_start,
                            "end": ds_b.time_range_end,
                        },
                        overlap_start=overlap_start.isoformat(),
                        overlap_end=overlap_end.isoformat(),
                        overlap_hours=round(hours, 2),
                    ))

        overlaps.sort(key=lambda x: x.overlap_hours, reverse=True)
        return overlaps[:50]

    # ── MITRE technique overlap ───────────────────────────────────────

    async def _find_technique_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[TechniqueOverlap]:
        """Find MITRE ATT&CK techniques shared across hunts."""
        stmt = select(Hypothesis).where(
            Hypothesis.hunt_id.in_(hunt_ids),
            Hypothesis.mitre_technique.isnot(None),
        )
        result = await db.execute(stmt)
        hypotheses = result.scalars().all()

        technique_map: dict[str, list[dict]] = defaultdict(list)
        for hyp in hypotheses:
            technique = hyp.mitre_technique.strip()
            if technique:
                technique_map[technique].append({
                    "hypothesis_id": hyp.id,
                    "hypothesis_title": hyp.title,
                    "hunt_id": hyp.hunt_id,
                    "status": hyp.status,
                })

        overlaps = []
        for technique, hyps in technique_map.items():
            hunt_set = set(h["hunt_id"] for h in hyps if h["hunt_id"])
            if len(hunt_set) >= 2:
                overlaps.append(TechniqueOverlap(
                    technique_id=technique,
                    hypotheses=hyps,
                    hunt_ids=sorted(hunt_set),
                ))

        return overlaps

    # ── Host overlap ──────────────────────────────────────────────────

    async def _find_host_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[dict]:
        """Find hostnames that appear in datasets from different hunts.

        Useful for detecting lateral movement patterns.
        """
        stmt = select(Dataset).where(Dataset.hunt_id.in_(hunt_ids))
        result = await db.execute(stmt)
        datasets = result.scalars().all()

        host_map: dict[str, list[dict]] = defaultdict(list)

        for dataset in datasets:
            norm_cols = dataset.normalized_columns or {}
            # Look for hostname columns
            hostname_cols = [
                orig for orig, canon in norm_cols.items()
                if canon in ("hostname", "host", "computer_name", "src_host", "dst_host")
            ]
            if not hostname_cols:
                continue

            rows_stmt = select(DatasetRow).where(
                DatasetRow.dataset_id == dataset.id
            ).limit(2000)
            rows_result = await db.execute(rows_stmt)
            rows = rows_result.scalars().all()

            for row in rows:
                data = row.data or {}
                for col in hostname_cols:
                    val = data.get(col, "")
                    if val and str(val).strip():
                        host_name = str(val).strip().upper()
                        host_map[host_name].append({
                            "dataset_id": dataset.id,
                            "dataset_name": dataset.name,
                            "hunt_id": dataset.hunt_id,
                        })

        # Filter to hosts appearing in multiple hunts
        overlaps = []
        for host, appearances in host_map.items():
            hunt_set = set(a["hunt_id"] for a in appearances if a["hunt_id"])
            if len(hunt_set) >= 2:
                overlaps.append({
                    "hostname": host,
                    "hunt_ids": sorted(hunt_set),
                    "dataset_count": len(appearances),
                    "datasets": appearances[:10],
                })

        overlaps.sort(key=lambda x: x["dataset_count"], reverse=True)
        return overlaps[:50]

    # ── Summary builder ───────────────────────────────────────────────

    def _build_summary(self, result: CorrelationResult) -> str:
        """Build a human-readable summary of correlations."""
        parts = [f"Correlation analysis across {len(result.hunt_ids)} hunts:"]

        if result.ioc_overlaps:
            malicious = [o for o in result.ioc_overlaps if o.enrichment_verdict == "malicious"]
            parts.append(
                f" - {len(result.ioc_overlaps)} shared IOCs "
                f"({len(malicious)} flagged malicious)"
            )
        else:
            parts.append(" - No shared IOCs found")

        if result.time_overlaps:
            parts.append(f" - {len(result.time_overlaps)} overlapping time windows")

        if result.technique_overlaps:
            parts.append(
                f" - {len(result.technique_overlaps)} shared MITRE techniques"
            )

        if result.host_overlaps:
            parts.append(
                f" - {len(result.host_overlaps)} hosts appearing in multiple hunts "
                "(potential lateral movement)"
            )

        if result.total_correlations == 0:
            parts.append(" No significant correlations detected.")

        return "\n".join(parts)


# Singleton
correlation_engine = CorrelationEngine()
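A usage sketch for the engine above (the async_session factory name and the hunt IDs are assumptions): run a correlation pass over two hunts and print the strongest IOC signals, which _find_ioc_overlaps already sorts by occurrence count.

import asyncio

from app.db import async_session  # assumed session factory; not shown in this diff
from app.services.correlation import correlation_engine

async def main() -> None:
    async with async_session() as db:
        result = await correlation_engine.correlate_hunts(["hunt-001", "hunt-002"], db)
        print(result.summary)
        for overlap in result.ioc_overlaps[:5]:
            print(overlap.ioc_value, overlap.hunt_ids, overlap.enrichment_verdict)

asyncio.run(main())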
backend/app/services/csv_parser.py (new file, 165 lines)
@@ -0,0 +1,165 @@
"""CSV parsing engine with encoding detection, delimiter sniffing, and streaming.

Handles large Velociraptor CSV exports with resilience to encoding issues,
varied delimiters, and malformed rows.
"""

import csv
import io
import logging
import re
from pathlib import Path
from typing import AsyncIterator

import chardet

logger = logging.getLogger(__name__)

# Reasonable defaults
MAX_FIELD_SIZE = 1024 * 1024  # 1 MB per field
csv.field_size_limit(MAX_FIELD_SIZE)


def detect_encoding(file_bytes: bytes, sample_size: int = 65536) -> str:
    """Detect file encoding from a sample of bytes."""
    result = chardet.detect(file_bytes[:sample_size])
    encoding = result.get("encoding", "utf-8") or "utf-8"
    confidence = result.get("confidence") or 0.0
    logger.info(f"Detected encoding: {encoding} (confidence: {confidence:.2f})")
    # Fall back to utf-8 if confidence is very low
    if confidence < 0.5:
        encoding = "utf-8"
    return encoding


def detect_delimiter(text_sample: str) -> str:
    """Sniff the CSV delimiter from a text sample."""
    try:
        dialect = csv.Sniffer().sniff(text_sample, delimiters=",\t;|")
        return dialect.delimiter
    except csv.Error:
        return ","


def infer_column_types(rows: list[dict], sample_size: int = 100) -> dict[str, str]:
    """Infer column types from a sample of rows.

    Returns a mapping of column_name → type_hint where type_hint is one of:
    timestamp, integer, float, ip, hash_md5, hash_sha1, hash_sha256, domain, path, string
    """
    type_map: dict[str, dict[str, int]] = {}
    sample = rows[:sample_size]

    patterns = {
        "ip": re.compile(
            r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$"
        ),
        "hash_md5": re.compile(r"^[a-fA-F0-9]{32}$"),
        "hash_sha1": re.compile(r"^[a-fA-F0-9]{40}$"),
        "hash_sha256": re.compile(r"^[a-fA-F0-9]{64}$"),
        "integer": re.compile(r"^-?\d+$"),
        "float": re.compile(r"^-?\d+\.\d+$"),
        "timestamp": re.compile(
            r"^\d{4}[-/]\d{2}[-/]\d{2}[T ]\d{2}:\d{2}"
        ),
        "domain": re.compile(
            r"^[a-zA-Z0-9]([a-zA-Z0-9-]*[a-zA-Z0-9])?(\.[a-zA-Z]{2,})+$"
        ),
        "path": re.compile(r"^([A-Z]:\\|/)", re.IGNORECASE),
    }

    for row in sample:
        for col, val in row.items():
            if col not in type_map:
                type_map[col] = {}
            val_str = str(val).strip()
            if not val_str:
                continue
            matched = False
            for type_name, pattern in patterns.items():
                if pattern.match(val_str):
                    type_map[col][type_name] = type_map[col].get(type_name, 0) + 1
                    matched = True
                    break
            if not matched:
                type_map[col]["string"] = type_map[col].get("string", 0) + 1

    result: dict[str, str] = {}
    for col, counts in type_map.items():
        if counts:
            result[col] = max(counts, key=counts.get)  # type: ignore[arg-type]
        else:
            result[col] = "string"

    return result


def parse_csv_bytes(
    raw_bytes: bytes,
    max_rows: int | None = None,
) -> tuple[list[dict], dict]:
    """Parse a CSV file from raw bytes.

    Returns:
        (rows, metadata) where metadata contains encoding, delimiter, columns, etc.
    """
    encoding = detect_encoding(raw_bytes)

    try:
        text = raw_bytes.decode(encoding, errors="replace")
    except (UnicodeDecodeError, LookupError):
        text = raw_bytes.decode("utf-8", errors="replace")
        encoding = "utf-8"

    # Detect delimiter from first few KB
    delimiter = detect_delimiter(text[:8192])

    reader = csv.DictReader(io.StringIO(text), delimiter=delimiter)
    columns = reader.fieldnames or []

    rows: list[dict] = []
    for i, row in enumerate(reader):
        if max_rows is not None and i >= max_rows:
            break
        rows.append(dict(row))

    column_types = infer_column_types(rows) if rows else {}

    metadata = {
        "encoding": encoding,
        "delimiter": delimiter,
        "columns": columns,
        "column_types": column_types,
        "row_count": len(rows),
        "total_rows_in_file": len(rows),  # equals row_count; true total unknown if max_rows truncated
    }

    return rows, metadata


async def parse_csv_streaming(
    file_path: Path,
    chunk_size: int = 8192,  # reserved for a future incremental reader
) -> AsyncIterator[tuple[int, dict]]:
    """Stream-parse a CSV file yielding (row_index, row_dict) tuples.

    Rows are yielded lazily to callers, though the decoded text is currently
    buffered once in memory (csv.DictReader needs a complete text source).
    """
    import aiofiles  # type: ignore[import-untyped]

    # Read a sample for encoding/delimiter detection
    with open(file_path, "rb") as f:
        sample_bytes = f.read(65536)

    encoding = detect_encoding(sample_bytes)
    text_sample = sample_bytes.decode(encoding, errors="replace")
    delimiter = detect_delimiter(text_sample[:8192])

    # Now read the full file asynchronously
    async with aiofiles.open(file_path, mode="r", encoding=encoding, errors="replace") as f:
        content = await f.read()  # buffered once; see docstring note

    reader = csv.DictReader(io.StringIO(content), delimiter=delimiter)
    for i, row in enumerate(reader):
        yield i, dict(row)
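A quick self-contained check of the parser above on an in-memory CSV; the metadata keys follow the dict built in parse_csv_bytes, and the comma fallback in detect_delimiter makes the delimiter assertion safe even if the sniffer balks at this tiny ASCII sample.

from app.services.csv_parser import parse_csv_bytes

raw = (
    b"Timestamp,SourceIP,MD5\n"
    b"2024-05-01T12:00:00Z,10.0.0.5,d41d8cd98f00b204e9800998ecf8427e\n"
)
rows, meta = parse_csv_bytes(raw)
assert rows[0]["SourceIP"] == "10.0.0.5"
assert meta["delimiter"] == ","
assert meta["column_types"] == {
    "Timestamp": "timestamp",
    "SourceIP": "ip",
    "MD5": "hash_md5",
}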
backend/app/services/enrichment.py (new file, 655 lines)
@@ -0,0 +1,655 @@
"""IOC Enrichment Engine — VirusTotal, AbuseIPDB, Shodan integrations.

Provides automated IOC enrichment with caching and rate limiting.
Enriches IPs, hashes, domains with threat intelligence verdicts.
"""

import asyncio
import hashlib
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum

import httpx
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.db.models import EnrichmentResult as EnrichmentDB

logger = logging.getLogger(__name__)


class IOCType(str, Enum):
    IP = "ip"
    DOMAIN = "domain"
    HASH_MD5 = "hash_md5"
    HASH_SHA1 = "hash_sha1"
    HASH_SHA256 = "hash_sha256"
    URL = "url"


class Verdict(str, Enum):
    CLEAN = "clean"
    SUSPICIOUS = "suspicious"
    MALICIOUS = "malicious"
    UNKNOWN = "unknown"
    ERROR = "error"


@dataclass
class EnrichmentResultData:
    """Enrichment result from a provider."""
    ioc_value: str
    ioc_type: IOCType
    source: str
    verdict: Verdict
    score: float = 0.0  # 0-100 normalized threat score
    raw_data: dict = field(default_factory=dict)
    tags: list[str] = field(default_factory=list)
    country: str = ""
    asn: str = ""
    org: str = ""
    last_seen: str = ""
    error: str = ""
    latency_ms: int = 0


# ── Rate limiter ──────────────────────────────────────────────────────


class RateLimiter:
    """Simple token bucket rate limiter for API calls."""

    def __init__(self, calls_per_minute: int = 4):
        self.calls_per_minute = calls_per_minute
        self.interval = 60.0 / calls_per_minute
        self._last_call: float = 0.0
        self._lock = asyncio.Lock()

    async def acquire(self):
        async with self._lock:
            now = time.monotonic()
            elapsed = now - self._last_call
            if elapsed < self.interval:
                await asyncio.sleep(self.interval - elapsed)
            self._last_call = time.monotonic()


# ── Provider base ─────────────────────────────────────────────────────


class EnrichmentProvider:
    """Base class for enrichment providers."""

    name: str = "base"

    def __init__(self, api_key: str = "", rate_limit: int = 4):
        self.api_key = api_key
        self.rate_limiter = RateLimiter(rate_limit)
        self._client: httpx.AsyncClient | None = None

    def _get_client(self) -> httpx.AsyncClient:
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(connect=10, read=30, write=10, pool=5),
            )
        return self._client

    async def cleanup(self):
        if self._client and not self._client.is_closed:
            await self._client.aclose()

    @property
    def is_configured(self) -> bool:
        return bool(self.api_key)

    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        raise NotImplementedError


# ── VirusTotal ────────────────────────────────────────────────────────


class VirusTotalProvider(EnrichmentProvider):
    """VirusTotal v3 API provider."""

    name = "virustotal"
    BASE_URL = "https://www.virustotal.com/api/v3"

    def __init__(self):
        super().__init__(api_key=settings.VIRUSTOTAL_API_KEY, rate_limit=4)

    def _headers(self) -> dict:
        return {"x-apikey": self.api_key}

    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        if not self.is_configured:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="VirusTotal API key not configured",
            )

        await self.rate_limiter.acquire()
        start = time.monotonic()

        try:
            endpoint = self._get_endpoint(ioc_value, ioc_type)
            if not endpoint:
                return EnrichmentResultData(
                    ioc_value=ioc_value, ioc_type=ioc_type,
                    source=self.name, verdict=Verdict.ERROR,
                    error=f"Unsupported IOC type: {ioc_type}",
                )

            client = self._get_client()
            resp = await client.get(endpoint, headers=self._headers())
            latency_ms = int((time.monotonic() - start) * 1000)

            if resp.status_code == 404:
                return EnrichmentResultData(
                    ioc_value=ioc_value, ioc_type=ioc_type,
                    source=self.name, verdict=Verdict.UNKNOWN,
                    latency_ms=latency_ms,
                )

            resp.raise_for_status()
            data = resp.json()
            attrs = data.get("data", {}).get("attributes", {})
            stats = attrs.get("last_analysis_stats", {})

            malicious = stats.get("malicious", 0)
            suspicious = stats.get("suspicious", 0)
            total = sum(stats.values()) if stats else 0

            # Determine verdict
            if malicious > 3:
                verdict = Verdict.MALICIOUS
            elif malicious > 0 or suspicious > 2:
                verdict = Verdict.SUSPICIOUS
            elif total > 0:
                verdict = Verdict.CLEAN
            else:
                verdict = Verdict.UNKNOWN

            score = (malicious / total * 100) if total > 0 else 0

            tags = attrs.get("tags", [])
            if attrs.get("type_description"):
                tags.append(attrs["type_description"])

            return EnrichmentResultData(
                ioc_value=ioc_value,
                ioc_type=ioc_type,
                source=self.name,
                verdict=verdict,
                score=round(score, 1),
                raw_data={
                    "stats": stats,
                    "reputation": attrs.get("reputation", 0),
                    "type_description": attrs.get("type_description", ""),
                    "names": attrs.get("names", [])[:5],
                },
                tags=tags[:10],
                country=attrs.get("country", ""),
                asn=str(attrs.get("asn", "")),
                org=attrs.get("as_owner", ""),
                last_seen=attrs.get("last_analysis_date", ""),
                latency_ms=latency_ms,
            )

        except httpx.HTTPStatusError as e:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=f"HTTP {e.response.status_code}",
                latency_ms=int((time.monotonic() - start) * 1000),
            )
        except Exception as e:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=str(e),
                latency_ms=int((time.monotonic() - start) * 1000),
            )

    def _get_endpoint(self, ioc_value: str, ioc_type: IOCType) -> str | None:
        if ioc_type == IOCType.IP:
            return f"{self.BASE_URL}/ip_addresses/{ioc_value}"
        elif ioc_type == IOCType.DOMAIN:
            return f"{self.BASE_URL}/domains/{ioc_value}"
        elif ioc_type in (IOCType.HASH_MD5, IOCType.HASH_SHA1, IOCType.HASH_SHA256):
            return f"{self.BASE_URL}/files/{ioc_value}"
        elif ioc_type == IOCType.URL:
            url_id = hashlib.sha256(ioc_value.encode()).hexdigest()
            return f"{self.BASE_URL}/urls/{url_id}"
        return None


# ── AbuseIPDB ─────────────────────────────────────────────────────────


class AbuseIPDBProvider(EnrichmentProvider):
    """AbuseIPDB API provider — IP reputation."""

    name = "abuseipdb"
    BASE_URL = "https://api.abuseipdb.com/api/v2"

    def __init__(self):
        super().__init__(api_key=settings.ABUSEIPDB_API_KEY, rate_limit=10)

    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        if ioc_type != IOCType.IP:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="AbuseIPDB only supports IP lookups",
            )

        if not self.is_configured:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="AbuseIPDB API key not configured",
            )

        await self.rate_limiter.acquire()
        start = time.monotonic()

        try:
            client = self._get_client()
            resp = await client.get(
                f"{self.BASE_URL}/check",
                params={"ipAddress": ioc_value, "maxAgeInDays": 90, "verbose": "true"},
                headers={"Key": self.api_key, "Accept": "application/json"},
            )
            latency_ms = int((time.monotonic() - start) * 1000)
            resp.raise_for_status()
            data = resp.json().get("data", {})

            abuse_score = data.get("abuseConfidenceScore", 0)
            total_reports = data.get("totalReports", 0)

            if abuse_score >= 75:
                verdict = Verdict.MALICIOUS
            elif abuse_score >= 25 or total_reports > 5:
                verdict = Verdict.SUSPICIOUS
            elif total_reports == 0:
                verdict = Verdict.UNKNOWN
            else:
                verdict = Verdict.CLEAN

            categories = data.get("reports", [])
            tags = []
            for report in categories[:10]:
                for cat_id in report.get("categories", []):
                    tag = self._category_name(cat_id)
                    if tag and tag not in tags:
                        tags.append(tag)

            return EnrichmentResultData(
                ioc_value=ioc_value,
                ioc_type=ioc_type,
                source=self.name,
                verdict=verdict,
                score=float(abuse_score),
                raw_data={
                    "abuse_confidence_score": abuse_score,
                    "total_reports": total_reports,
                    "is_whitelisted": data.get("isWhitelisted"),
                    "is_tor": data.get("isTor", False),
                    "usage_type": data.get("usageType", ""),
                    "isp": data.get("isp", ""),
                },
                tags=tags[:10],
                country=data.get("countryCode", ""),
                org=data.get("isp", ""),
                last_seen=data.get("lastReportedAt", ""),
                latency_ms=latency_ms,
            )

        except Exception as e:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=str(e),
                latency_ms=int((time.monotonic() - start) * 1000),
            )

    @staticmethod
    def _category_name(cat_id: int) -> str:
        categories = {
            1: "DNS Compromise", 2: "DNS Poisoning", 3: "Fraud Orders",
            4: "DDoS Attack", 5: "FTP Brute-Force", 6: "Ping of Death",
            7: "Phishing", 8: "Fraud VoIP", 9: "Open Proxy",
            10: "Web Spam", 11: "Email Spam", 12: "Blog Spam",
            13: "VPN IP", 14: "Port Scan", 15: "Hacking",
            16: "SQL Injection", 17: "Spoofing", 18: "Brute-Force",
            19: "Bad Web Bot", 20: "Exploited Host", 21: "Web App Attack",
            22: "SSH", 23: "IoT Targeted",
        }
        return categories.get(cat_id, "")


# ── Shodan ────────────────────────────────────────────────────────────


class ShodanProvider(EnrichmentProvider):
    """Shodan API provider — infrastructure intelligence."""

    name = "shodan"
    BASE_URL = "https://api.shodan.io"

    def __init__(self):
        super().__init__(api_key=settings.SHODAN_API_KEY, rate_limit=1)

    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        if ioc_type != IOCType.IP:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="Shodan only supports IP lookups",
            )

        if not self.is_configured:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="Shodan API key not configured",
            )

        await self.rate_limiter.acquire()
        start = time.monotonic()

        try:
            client = self._get_client()
            resp = await client.get(
                f"{self.BASE_URL}/shodan/host/{ioc_value}",
                params={"key": self.api_key, "minify": "true"},
            )
            latency_ms = int((time.monotonic() - start) * 1000)

            if resp.status_code == 404:
                return EnrichmentResultData(
                    ioc_value=ioc_value, ioc_type=ioc_type,
                    source=self.name, verdict=Verdict.UNKNOWN,
                    latency_ms=latency_ms,
                )

            resp.raise_for_status()
            data = resp.json()

            ports = data.get("ports", [])
            vulns = data.get("vulns", [])
            tags_raw = data.get("tags", [])

            # Determine verdict based on open ports and vulns
            if vulns:
                verdict = Verdict.SUSPICIOUS
                score = min(len(vulns) * 15, 100.0)
            elif len(ports) > 20:
                verdict = Verdict.SUSPICIOUS
                score = 40.0
            else:
                verdict = Verdict.CLEAN
                score = 0.0

            tags = tags_raw[:10]
            if vulns:
                tags.extend([f"CVE: {v}" for v in vulns[:5]])

            return EnrichmentResultData(
                ioc_value=ioc_value,
                ioc_type=ioc_type,
                source=self.name,
                verdict=verdict,
                score=score,
                raw_data={
                    "ports": ports[:20],
                    "vulns": vulns[:10],
                    "os": data.get("os"),
                    "hostnames": data.get("hostnames", [])[:5],
                    "domains": data.get("domains", [])[:5],
                    "last_update": data.get("last_update", ""),
                },
                tags=tags[:15],
                country=data.get("country_code", ""),
                asn=data.get("asn", ""),
                org=data.get("org", ""),
                last_seen=data.get("last_update", ""),
                latency_ms=latency_ms,
            )

        except Exception as e:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=str(e),
                latency_ms=int((time.monotonic() - start) * 1000),
            )


# ── Enrichment Engine (orchestrator) ──────────────────────────────────


class EnrichmentEngine:
    """Orchestrates IOC enrichment across all providers with caching."""

    CACHE_TTL_HOURS = 24

    def __init__(self):
        self.providers: list[EnrichmentProvider] = [
            VirusTotalProvider(),
            AbuseIPDBProvider(),
            ShodanProvider(),
        ]

    @property
    def configured_providers(self) -> list[EnrichmentProvider]:
        return [p for p in self.providers if p.is_configured]

    async def enrich_ioc(
        self,
        ioc_value: str,
        ioc_type: IOCType,
        db: AsyncSession | None = None,
        skip_cache: bool = False,
    ) -> list[EnrichmentResultData]:
        """Enrich a single IOC across all configured providers.

        Uses cached results from DB when available.
        """
        results: list[EnrichmentResultData] = []

        # Check cache first
        if db and not skip_cache:
            cached = await self._get_cached(db, ioc_value, ioc_type)
            if cached:
                logger.info(f"Cache hit for {ioc_type.value}:{ioc_value} ({len(cached)} results)")
                return cached

        # Query all applicable providers in parallel. AbuseIPDB and Shodan
        # are IP-only; VirusTotal covers IPs, domains, file hashes, and URLs.
        tasks = []
        for provider in self.configured_providers:
            if provider.name in ("abuseipdb", "shodan") and ioc_type != IOCType.IP:
                continue
            tasks.append(provider.enrich(ioc_value, ioc_type))

        if tasks:
            results = list(await asyncio.gather(*tasks, return_exceptions=False))

        # Cache results
        if db and results:
            await self._cache_results(db, results)

        return results

    async def enrich_batch(
        self,
        iocs: list[tuple[str, IOCType]],
        db: AsyncSession | None = None,
        concurrency: int = 3,
    ) -> dict[str, list[EnrichmentResultData]]:
        """Enrich a batch of IOCs with controlled concurrency."""
        sem = asyncio.Semaphore(concurrency)
        all_results: dict[str, list[EnrichmentResultData]] = {}

        async def _enrich_one(value: str, ioc_type: IOCType):
            async with sem:
                result = await self.enrich_ioc(value, ioc_type, db=db)
                all_results[value] = result

        tasks = [_enrich_one(v, t) for v, t in iocs]
        await asyncio.gather(*tasks, return_exceptions=True)
        return all_results

    async def enrich_dataset_iocs(
        self,
        rows: list[dict],
        ioc_columns: dict,
        db: AsyncSession | None = None,
        max_iocs: int = 50,
    ) -> dict[str, list[EnrichmentResultData]]:
        """Auto-enrich IOCs found in a dataset.

        Extracts unique IOC values from the identified columns and enriches them.
        """
        iocs_to_enrich: list[tuple[str, IOCType]] = []
        seen = set()

        for col_name, col_type in ioc_columns.items():
            ioc_type = self._map_column_type(col_type)
            if not ioc_type:
                continue

            for row in rows:
                value = row.get(col_name, "")
                if value and value not in seen:
                    seen.add(value)
                    iocs_to_enrich.append((str(value), ioc_type))

                if len(iocs_to_enrich) >= max_iocs:
                    break

            if len(iocs_to_enrich) >= max_iocs:
                break

        if iocs_to_enrich:
            return await self.enrich_batch(iocs_to_enrich, db=db)
        return {}

    async def _get_cached(
        self,
        db: AsyncSession,
        ioc_value: str,
        ioc_type: IOCType,
    ) -> list[EnrichmentResultData] | None:
        """Check for cached enrichment results."""
        cutoff = datetime.now(timezone.utc) - timedelta(hours=self.CACHE_TTL_HOURS)
        stmt = (
            select(EnrichmentDB)
            .where(
                EnrichmentDB.ioc_value == ioc_value,
                EnrichmentDB.ioc_type == ioc_type.value,
                EnrichmentDB.cached_at >= cutoff,
            )
        )
        result = await db.execute(stmt)
        cached = result.scalars().all()

        if not cached:
            return None

        return [
            EnrichmentResultData(
                ioc_value=c.ioc_value,
                ioc_type=IOCType(c.ioc_type),
                source=c.source,
                verdict=Verdict(c.verdict),
                score=c.score or 0.0,
                raw_data=c.raw_data or {},
                tags=c.tags or [],
                country=c.country or "",
                asn=c.asn or "",
                org=c.org or "",
            )
            for c in cached
        ]

    async def _cache_results(
        self,
        db: AsyncSession,
        results: list[EnrichmentResultData],
    ):
        """Cache enrichment results in the database."""
        for r in results:
            if r.verdict == Verdict.ERROR:
                continue  # Don't cache errors
            entry = EnrichmentDB(
                ioc_value=r.ioc_value,
                ioc_type=r.ioc_type.value,
                source=r.source,
                verdict=r.verdict.value,
                score=r.score,
                raw_data=r.raw_data,
                tags=r.tags,
                country=r.country,
                asn=r.asn,
                org=r.org,
            )
            db.add(entry)
        try:
            await db.flush()
        except Exception as e:
            logger.warning(f"Failed to cache enrichment: {e}")

    @staticmethod
    def _map_column_type(col_type: str) -> IOCType | None:
        """Map column type from normalizer to IOCType."""
        mapping = {
            "ip": IOCType.IP,
            "ip_address": IOCType.IP,
            "src_ip": IOCType.IP,
            "dst_ip": IOCType.IP,
            "domain": IOCType.DOMAIN,
            "hash_md5": IOCType.HASH_MD5,
            "hash_sha1": IOCType.HASH_SHA1,
            "hash_sha256": IOCType.HASH_SHA256,
            "url": IOCType.URL,
        }
        return mapping.get(col_type)

    async def cleanup(self):
        for provider in self.providers:
            await provider.cleanup()

    def status(self) -> dict:
        """Return enrichment engine status."""
        return {
            "providers": {
                p.name: {"configured": p.is_configured}
                for p in self.providers
            },
            "cache_ttl_hours": self.CACHE_TTL_HOURS,
        }


# Singleton
enrichment_engine = EnrichmentEngine()
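A usage sketch for the engine above (the documentation IP is arbitrary): enrich one IP across whatever providers have API keys configured. Without a db session, caching is skipped; with no keys configured, the result list is simply empty.

import asyncio

from app.services.enrichment import IOCType, enrichment_engine

async def main() -> None:
    print(enrichment_engine.status())  # shows which providers hold keys
    results = await enrichment_engine.enrich_ioc("198.51.100.7", IOCType.IP)
    for r in results:
        print(f"{r.source}: {r.verdict.value} (score={r.score}, {r.latency_ms} ms)")
    await enrichment_engine.cleanup()

asyncio.run(main())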
backend/app/services/keyword_defaults.py (new file, 145 lines)
@@ -0,0 +1,145 @@
"""Default AUP keyword themes and their seed keywords.

Called once on startup — only inserts themes that don't already exist,
so user edits are never overwritten.
"""

import logging

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import KeywordTheme, Keyword

logger = logging.getLogger(__name__)

# ── Default themes + keywords ─────────────────────────────────────────

DEFAULTS: dict[str, dict] = {
    "Gambling": {
        "color": "#f44336",
        "keywords": [
            "poker", "casino", "blackjack", "roulette", "sportsbook",
            "sports betting", "bet365", "draftkings", "fanduel", "bovada",
            "betonline", "mybookie", "slots", "slot machine", "parlay",
            "wager", "bookie", "betway", "888casino", "pokerstars",
            "william hill", "ladbrokes", "betfair", "unibet", "pinnacle",
        ],
    },
    "Gaming": {
        "color": "#9c27b0",
        "keywords": [
            "steam", "steamcommunity", "steampowered", "epic games",
            "epicgames", "origin.com", "battle.net", "blizzard",
            "roblox", "minecraft", "fortnite", "valorant", "league of legends",
            "twitch", "twitch.tv", "discord", "discord.gg", "xbox live",
            "playstation network", "gog.com", "itch.io", "gamepass",
            "riot games", "ubisoft", "ea.com",
        ],
    },
    "Streaming": {
        "color": "#ff9800",
        "keywords": [
            "netflix", "hulu", "disney+", "disneyplus", "hbomax",
            "amazon prime video", "peacock", "paramount+", "crunchyroll",
            "funimation", "spotify", "pandora", "soundcloud", "deezer",
            "tidal", "apple music", "youtube music", "pluto tv",
            "tubi", "vudu", "plex",
        ],
    },
    "Downloads / Piracy": {
        "color": "#ff5722",
        "keywords": [
            "torrent", "bittorrent", "utorrent", "qbittorrent", "piratebay",
            "thepiratebay", "1337x", "rarbg", "yts", "kickass",
            "limewire", "frostwire", "mega.nz", "rapidshare", "mediafire",
            "zippyshare", "uploadhaven", "fitgirl", "repack", "crack",
            "keygen", "warez", "nulled", "pirate", "magnet:",
        ],
    },
    "Adult Content": {
        "color": "#e91e63",
        "keywords": [
            "pornhub", "xvideos", "xhamster", "onlyfans", "chaturbate",
            "livejasmin", "brazzers", "redtube", "youporn", "xnxx",
            "porn", "xxx", "nsfw", "adult content", "cam site",
            "stripchat", "bongacams",
        ],
    },
    "Social Media": {
        "color": "#2196f3",
        "keywords": [
            "facebook", "instagram", "tiktok", "snapchat", "pinterest",
            "reddit", "tumblr", "myspace", "whatsapp web", "telegram web",
            "signal web", "wechat web", "twitter.com", "x.com",
            "threads.net", "mastodon", "bluesky",
        ],
    },
    "Job Search": {
        "color": "#4caf50",
        "keywords": [
            "indeed", "linkedin jobs", "glassdoor", "monster.com",
            "ziprecruiter", "careerbuilder", "dice.com", "hired.com",
            "angel.co", "wellfound", "levels.fyi", "salary.com",
            "payscale", "resume", "cover letter", "job application",
        ],
    },
    "Shopping": {
        "color": "#00bcd4",
        "keywords": [
            "amazon.com", "ebay", "etsy", "walmart.com", "target.com",
            "bestbuy", "aliexpress", "wish.com", "shein", "temu",
            "wayfair", "overstock", "newegg", "zappos", "coupon",
            "promo code", "add to cart",
        ],
    },
}


async def seed_defaults(db: AsyncSession) -> int:
    """Insert default themes + keywords for any theme name not already in DB.

    Returns the number of themes inserted (0 if all already exist).
    """
    # Rename legacy theme names
    _renames = [("Social Media (Personal)", "Social Media")]
    for old_name, new_name in _renames:
        old = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == old_name))
        if old:
            await db.execute(
                KeywordTheme.__table__.update()
                .where(KeywordTheme.name == old_name)
                .values(name=new_name)
            )
            await db.commit()
            logger.info("Renamed AUP theme '%s' → '%s'", old_name, new_name)

    inserted = 0
    for theme_name, meta in DEFAULTS.items():
        exists = await db.scalar(
            select(KeywordTheme.id).where(KeywordTheme.name == theme_name)
        )
        if exists:
            continue

        theme = KeywordTheme(
            name=theme_name,
            color=meta["color"],
            enabled=True,
            is_builtin=True,
        )
        db.add(theme)
        await db.flush()  # get theme.id

        for kw in meta["keywords"]:
            db.add(Keyword(theme_id=theme.id, value=kw))

        inserted += 1
        logger.info("Seeded AUP theme '%s' with %d keywords", theme_name, len(meta["keywords"]))

    if inserted:
        await db.commit()
        logger.info("Seeded %d AUP keyword themes", inserted)
    else:
        logger.debug("All default AUP themes already present")

    return inserted
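A startup-hook sketch for the seeder above (the lifespan wiring and async_session factory are assumptions; only seed_defaults comes from this file):

from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.db import async_session  # assumed session factory; not shown in this diff
from app.services.keyword_defaults import seed_defaults

@asynccontextmanager
async def lifespan(app: FastAPI):
    # seed_defaults is idempotent: it only inserts themes whose names are
    # missing, so user-edited keyword lists survive restarts.
    async with async_session() as db:
        await seed_defaults(db)
    yield

app = FastAPI(lifespan=lifespan)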
backend/app/services/normalizer.py (new file, 196 lines)
@@ -0,0 +1,196 @@
|
||||
"""Artifact normalizer — maps Velociraptor and common tool columns to canonical schema.
|
||||
|
||||
The canonical schema provides consistent field names regardless of which tool
|
||||
exported the CSV (Velociraptor, OSQuery, Sysmon, etc.).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Column mapping: source_column_pattern → canonical_name ─────────────
|
||||
# Patterns are case-insensitive regexes matched against column names.
|
||||
|
||||
COLUMN_MAPPINGS: list[tuple[str, str]] = [
|
||||
# Timestamps
|
||||
(r"^(timestamp|time|event_?time|date_?time|created?_?(at|time|date)|modified_?(at|time|date)|mtime|ctime|atime|start_?time|end_?time)$", "timestamp"),
|
||||
(r"^(eventtime|system\.timecreated)$", "timestamp"),
|
||||
# Host identifiers
|
||||
(r"^(hostname|host|fqdn|computer_?name|system_?name|machinename|clientid)$", "hostname"),
|
||||
# Operating system
|
||||
(r"^(os|operating_?system|os_?version|os_?name|platform|os_?type)$", "os"),
|
||||
# Source / destination IPs
|
||||
(r"^(source_?ip|src_?ip|srcaddr|local_?address|sourceaddress)$", "src_ip"),
|
||||
(r"^(dest_?ip|dst_?ip|dstaddr|remote_?address|destinationaddress|destaddress)$", "dst_ip"),
|
||||
(r"^(ip_?address|ipaddress|ip)$", "ip_address"),
|
||||
# Ports
|
||||
(r"^(source_?port|src_?port|localport)$", "src_port"),
|
||||
(r"^(dest_?port|dst_?port|remoteport|destinationport)$", "dst_port"),
|
||||
# Process info
|
||||
(r"^(process_?name|name|image|exe|executable|binary)$", "process_name"),
|
||||
(r"^(pid|process_?id)$", "pid"),
|
||||
(r"^(ppid|parent_?pid|parentprocessid)$", "ppid"),
|
||||
(r"^(command_?line|cmdline|commandline|cmd)$", "command_line"),
|
||||
(r"^(parent_?command_?line|parentcommandline)$", "parent_command_line"),
|
||||
# User info
|
||||
(r"^(user|username|user_?name|account_?name|subjectusername)$", "username"),
|
||||
(r"^(user_?id|uid|sid|subjectusersid)$", "user_id"),
|
||||
# File info
|
||||
(r"^(file_?path|fullpath|full_?name|path|filepath)$", "file_path"),
|
||||
(r"^(file_?name|filename|name)$", "file_name"),
|
||||
(r"^(file_?size|size|bytes|length)$", "file_size"),
|
||||
(r"^(extension|file_?ext)$", "file_extension"),
|
||||
# Hashes
|
||||
(r"^(md5|md5hash|hash_?md5)$", "hash_md5"),
|
||||
(r"^(sha1|sha1hash|hash_?sha1)$", "hash_sha1"),
|
||||
(r"^(sha256|sha256hash|hash_?sha256|hash|filehash)$", "hash_sha256"),
|
||||
# Network
|
||||
(r"^(protocol|proto)$", "protocol"),
|
||||
(r"^(domain|dns_?name|query_?name|queriedname)$", "domain"),
|
||||
(r"^(url|uri|request_?url)$", "url"),
|
||||
# Event info
|
||||
(r"^(event_?id|eventid|eid)$", "event_id"),
|
||||
(r"^(event_?type|eventtype|category|action)$", "event_type"),
|
||||
(r"^(description|message|msg|detail)$", "description"),
|
||||
(r"^(severity|level|priority)$", "severity"),
|
||||
# Registry
|
||||
(r"^(reg_?key|registry_?key|targetobject)$", "registry_key"),
|
||||
(r"^(reg_?value|registry_?value)$", "registry_value"),
|
||||
]


def normalize_columns(columns: list[str]) -> dict[str, str]:
    """Map raw column names to canonical names.

    Returns:
        Dict of {raw_column_name: canonical_column_name}.
        Columns with no match (or whose canonical name is already taken by
        an earlier column) map to themselves, lowercased and underscored.
    """
    mapping: dict[str, str] = {}
    used_canonical: set[str] = set()

    for col in columns:
        col_lower = col.strip().lower()
        matched = False
        for pattern, canonical in COLUMN_MAPPINGS:
            if re.match(pattern, col_lower, re.IGNORECASE):
                # Avoid duplicate canonical names: only the first column to
                # match claims the canonical name; later matches fall through
                # to the snake_case fallback below.
                if canonical not in used_canonical:
                    mapping[col] = canonical
                    used_canonical.add(canonical)
                    matched = True
                break
        if not matched:
            # Produce a clean snake_case version
            clean = re.sub(r"[^a-z0-9]+", "_", col_lower).strip("_")
            mapping[col] = clean or col

    return mapping
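A minimal sketch of how the mapping behaves, using hypothetical Velociraptor-style headers (the column names here are illustrative, not from the repo's fixtures):

    >>> normalize_columns(["ClientId", "EventTime", "RemoteAddress", "Weird Column!"])
    {'ClientId': 'hostname', 'EventTime': 'timestamp', 'RemoteAddress': 'dst_ip', 'Weird Column!': 'weird_column'}

Matched columns take their canonical names; the unmatched one falls back to the snake_case form.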


def normalize_row(row: dict[str, Any], column_mapping: dict[str, str]) -> dict[str, Any]:
    """Apply column mapping to a single row."""
    return {column_mapping.get(k, k): v for k, v in row.items()}


def normalize_rows(rows: list[dict], column_mapping: dict[str, str]) -> list[dict]:
    """Apply column mapping to all rows."""
    return [normalize_row(row, column_mapping) for row in rows]


def detect_ioc_columns(
    columns: list[str],
    column_types: dict[str, str],
    column_mapping: dict[str, str],
) -> dict[str, str]:
    """Detect which columns contain IOCs (IPs, hashes, domains, URLs).

    Returns:
        Dict of {column_name: ioc_type}.
    """
    ioc_columns: dict[str, str] = {}
    ioc_type_map = {
        "ip": "ip",
        "hash_md5": "hash_md5",
        "hash_sha1": "hash_sha1",
        "hash_sha256": "hash_sha256",
        "domain": "domain",
    }

    for col in columns:
        col_type = column_types.get(col)
        if col_type in ioc_type_map:
            ioc_columns[col] = ioc_type_map[col_type]

        # Also check the canonical name (overrides type-based detection)
        canonical = column_mapping.get(col, "")
        if canonical in ("src_ip", "dst_ip", "ip_address"):
            ioc_columns[col] = "ip"
        elif canonical == "hash_md5":
            ioc_columns[col] = "hash_md5"
        elif canonical == "hash_sha1":
            ioc_columns[col] = "hash_sha1"
        elif canonical == "hash_sha256":
            ioc_columns[col] = "hash_sha256"
        elif canonical == "domain":
            ioc_columns[col] = "domain"
        elif canonical == "url":
            ioc_columns[col] = "url"

    return ioc_columns
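For example, with the mapping produced above, canonical names alone are enough to tag IOC columns even when no inferred column types are available (the column names are hypothetical):

    >>> cols = ["RemoteAddress", "Sha256", "QueryName"]
    >>> mapping = normalize_columns(cols)
    >>> detect_ioc_columns(cols, column_types={}, column_mapping=mapping)
    {'RemoteAddress': 'ip', 'Sha256': 'hash_sha256', 'QueryName': 'domain'}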


def detect_time_range(
    rows: list[dict],
    column_mapping: dict[str, str],
) -> tuple[datetime | None, datetime | None]:
    """Find the earliest and latest timestamps in the dataset."""
    ts_col = None
    for raw_col, canonical in column_mapping.items():
        if canonical == "timestamp":
            ts_col = raw_col
            break

    if not ts_col:
        return None, None

    timestamps: list[datetime] = []
    for row in rows:
        val = row.get(ts_col)
        if not val:
            continue
        try:
            dt = _parse_timestamp(str(val))
            if dt:
                timestamps.append(dt)
        except (ValueError, TypeError):
            continue

    if not timestamps:
        return None, None

    return min(timestamps), max(timestamps)


def _parse_timestamp(value: str) -> datetime | None:
    """Try multiple timestamp formats."""
    formats = [
        "%Y-%m-%dT%H:%M:%S.%fZ",
        "%Y-%m-%dT%H:%M:%SZ",
        "%Y-%m-%dT%H:%M:%S.%f",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%d %H:%M:%S.%f",
        "%Y-%m-%d %H:%M:%S",
        "%Y/%m/%d %H:%M:%S",
        "%m/%d/%Y %H:%M:%S",
        "%d/%m/%Y %H:%M:%S",
    ]
    for fmt in formats:
        try:
            return datetime.strptime(value.strip(), fmt)
        except ValueError:
            continue
    return None
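A quick sanity check of the parser as written: all results are naive datetimes, and strings carrying a UTC offset (not in the format list) fall through to None:

    >>> _parse_timestamp("2024-05-01T12:30:00Z")
    datetime.datetime(2024, 5, 1, 12, 30)
    >>> _parse_timestamp("2024-05-01 12:30:00.123456")
    datetime.datetime(2024, 5, 1, 12, 30, 0, 123456)
    >>> _parse_timestamp("2024-05-01T12:30:00+00:00") is None
    True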
425
backend/app/services/reports.py
Normal file
@@ -0,0 +1,425 @@
"""Report generation — JSON, HTML, and CSV export for hunt investigations.

Generates comprehensive investigation reports including:
- Hunt metadata and status
- Dataset summaries with IOC counts
- Hypotheses and their evidence
- Annotations timeline
- Enrichment verdicts
- Agent conversation history
- Cross-hunt correlations
"""

import csv
import io
import json
import logging
from dataclasses import asdict
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import (
    Hunt, Dataset, DatasetRow, Hypothesis,
    Annotation, Conversation, Message, EnrichmentResult,
)

logger = logging.getLogger(__name__)


class ReportGenerator:
    """Generates exportable investigation reports."""

    async def generate_hunt_report(
        self,
        hunt_id: str,
        db: AsyncSession,
        format: str = "json",
        include_rows: bool = False,
        max_rows: int = 500,
    ) -> dict | str:
        """Generate a comprehensive report for a hunt investigation."""

        # Gather all hunt data
        report_data = await self._gather_hunt_data(
            hunt_id, db, include_rows=include_rows, max_rows=max_rows,
        )

        if not report_data:
            return {"error": "Hunt not found"}

        if format == "json":
            return report_data
        elif format == "html":
            return self._render_html(report_data)
        elif format == "csv":
            return self._render_csv(report_data)
        else:
            return report_data

    async def _gather_hunt_data(
        self,
        hunt_id: str,
        db: AsyncSession,
        include_rows: bool = False,
        max_rows: int = 500,
    ) -> dict | None:
        """Gather all data for a hunt report."""

        # Hunt metadata
        result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
        hunt = result.scalar_one_or_none()
        if not hunt:
            return None

        # Datasets
        ds_result = await db.execute(
            select(Dataset).where(Dataset.hunt_id == hunt_id)
        )
        datasets = ds_result.scalars().all()

        dataset_summaries = []
        for ds in datasets:
            summary = {
                "id": ds.id,
                "name": ds.name,
                "filename": ds.filename,
                "source_tool": ds.source_tool,
                "row_count": ds.row_count,
                "columns": list((ds.column_schema or {}).keys()),
                "ioc_columns": ds.ioc_columns or {},
                "time_range": {
                    "start": ds.time_range_start,
                    "end": ds.time_range_end,
                },
                "created_at": ds.created_at.isoformat(),
            }

            if include_rows:
                rows_result = await db.execute(
                    select(DatasetRow)
                    .where(DatasetRow.dataset_id == ds.id)
                    .order_by(DatasetRow.row_index)
                    .limit(max_rows)
                )
                rows = rows_result.scalars().all()
                summary["rows"] = [r.data for r in rows]

            dataset_summaries.append(summary)

        # Hypotheses
        hyp_result = await db.execute(
            select(Hypothesis).where(Hypothesis.hunt_id == hunt_id)
        )
        hypotheses = hyp_result.scalars().all()

        hypotheses_data = [
            {
                "id": h.id,
                "title": h.title,
                "description": h.description,
                "mitre_technique": h.mitre_technique,
                "status": h.status,
                "evidence_row_ids": h.evidence_row_ids,
                "evidence_notes": h.evidence_notes,
                "created_at": h.created_at.isoformat(),
                "updated_at": h.updated_at.isoformat(),
            }
            for h in hypotheses
        ]

        # Annotations (across all datasets in this hunt)
        dataset_ids = [ds.id for ds in datasets]
        annotations_data = []
        if dataset_ids:
            ann_result = await db.execute(
                select(Annotation)
                .where(Annotation.dataset_id.in_(dataset_ids))
                .order_by(Annotation.created_at)
            )
            annotations = ann_result.scalars().all()
            annotations_data = [
                {
                    "id": a.id,
                    "dataset_id": a.dataset_id,
                    "row_id": a.row_id,
                    "text": a.text,
                    "severity": a.severity,
                    "tag": a.tag,
                    "created_at": a.created_at.isoformat(),
                }
                for a in annotations
            ]

        # Conversations
        conv_result = await db.execute(
            select(Conversation).where(Conversation.hunt_id == hunt_id)
        )
        conversations = conv_result.scalars().all()

        conversations_data = []
        for conv in conversations:
            msg_result = await db.execute(
                select(Message)
                .where(Message.conversation_id == conv.id)
                .order_by(Message.created_at)
            )
            messages = msg_result.scalars().all()
            conversations_data.append({
                "id": conv.id,
                "title": conv.title,
                "messages": [
                    {
                        "role": m.role,
                        "content": m.content,
                        "model_used": m.model_used,
                        "node_used": m.node_used,
                        "latency_ms": m.latency_ms,
                        "created_at": m.created_at.isoformat(),
                    }
                    for m in messages
                ],
            })

        # Enrichment results — queried once per report, not per dataset.
        # Note: results are capped at 100 and not filtered to this hunt's IOCs.
        enrichment_data = []
        if any(ds.ioc_columns for ds in datasets):
            enrich_result = await db.execute(
                select(EnrichmentResult)
                .where(EnrichmentResult.source.isnot(None))
                .limit(100)
            )
            for e in enrich_result.scalars().all():
                enrichment_data.append({
                    "ioc_value": e.ioc_value,
                    "ioc_type": e.ioc_type,
                    "source": e.source,
                    "verdict": e.verdict,
                    "score": e.score,
                    "tags": e.tags,
                    "country": e.country,
                })

        # Build report
        now = datetime.now(timezone.utc)
        return {
            "report_metadata": {
                "generated_at": now.isoformat(),
                "format_version": "1.0",
                "generator": "ThreatHunt Report Engine",
            },
            "hunt": {
                "id": hunt.id,
                "name": hunt.name,
                "description": hunt.description,
                "status": hunt.status,
                "created_at": hunt.created_at.isoformat(),
                "updated_at": hunt.updated_at.isoformat(),
            },
            "summary": {
                "dataset_count": len(datasets),
                "total_rows": sum(ds.row_count for ds in datasets),
                "hypothesis_count": len(hypotheses),
                "confirmed_hypotheses": len([h for h in hypotheses if h.status == "confirmed"]),
                "annotation_count": len(annotations_data),
                "critical_annotations": len([a for a in annotations_data if a["severity"] == "critical"]),
                "conversation_count": len(conversations_data),
                "enrichment_count": len(enrichment_data),
                "malicious_iocs": len([e for e in enrichment_data if e["verdict"] == "malicious"]),
            },
            "datasets": dataset_summaries,
            "hypotheses": hypotheses_data,
            "annotations": annotations_data,
            "conversations": conversations_data,
            "enrichments": enrichment_data[:100],
        }

    def _render_html(self, data: dict) -> str:
        """Render report as self-contained HTML.

        Note: field values are interpolated without HTML escaping, so the
        output assumes trusted, analyst-entered data.
        """
        hunt = data.get("hunt", {})
        summary = data.get("summary", {})
        hypotheses = data.get("hypotheses", [])
        annotations = data.get("annotations", [])
        datasets = data.get("datasets", [])
        enrichments = data.get("enrichments", [])
        meta = data.get("report_metadata", {})

        html = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>ThreatHunt Report: {hunt.get('name', 'Unknown')}</title>
<style>
:root {{ --bg: #0d1117; --surface: #161b22; --border: #30363d; --text: #c9d1d9; --accent: #58a6ff; --red: #f85149; --orange: #d29922; --green: #3fb950; }}
* {{ box-sizing: border-box; margin: 0; padding: 0; }}
body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Helvetica, Arial, sans-serif; background: var(--bg); color: var(--text); line-height: 1.6; padding: 2rem; }}
.container {{ max-width: 1200px; margin: 0 auto; }}
h1 {{ color: var(--accent); border-bottom: 2px solid var(--border); padding-bottom: 0.5rem; margin-bottom: 1rem; }}
h2 {{ color: var(--accent); margin: 1.5rem 0 0.75rem; }}
h3 {{ color: var(--text); margin: 1rem 0 0.5rem; }}
.card {{ background: var(--surface); border: 1px solid var(--border); border-radius: 8px; padding: 1rem; margin: 0.75rem 0; }}
.stat-grid {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); gap: 0.75rem; }}
.stat {{ background: var(--surface); border: 1px solid var(--border); border-radius: 8px; padding: 1rem; text-align: center; }}
.stat .value {{ font-size: 2rem; font-weight: 700; color: var(--accent); }}
.stat .label {{ font-size: 0.85rem; color: #8b949e; }}
table {{ width: 100%; border-collapse: collapse; margin: 0.5rem 0; }}
th, td {{ padding: 0.5rem 0.75rem; border: 1px solid var(--border); text-align: left; }}
th {{ background: var(--surface); color: var(--accent); }}
.badge {{ display: inline-block; padding: 0.15rem 0.5rem; border-radius: 999px; font-size: 0.8rem; font-weight: 600; }}
.badge-malicious {{ background: var(--red); color: white; }}
.badge-suspicious {{ background: var(--orange); color: #000; }}
.badge-clean {{ background: var(--green); color: #000; }}
.badge-critical {{ background: var(--red); color: white; }}
.badge-high {{ background: #da3633; color: white; }}
.badge-medium {{ background: var(--orange); color: #000; }}
.badge-confirmed {{ background: var(--green); color: #000; }}
.badge-active {{ background: var(--accent); color: #000; }}
.footer {{ margin-top: 2rem; padding-top: 1rem; border-top: 1px solid var(--border); color: #8b949e; font-size: 0.85rem; }}
</style>
</head>
<body>
<div class="container">
<h1>🔍 ThreatHunt Report: {hunt.get('name', 'Untitled')}</h1>
<p><strong>Hunt ID:</strong> {hunt.get('id', '')}<br>
<strong>Status:</strong> {hunt.get('status', 'unknown')}<br>
<strong>Description:</strong> {hunt.get('description', 'N/A')}<br>
<strong>Created:</strong> {hunt.get('created_at', '')}</p>

<h2>Summary</h2>
<div class="stat-grid">
<div class="stat"><div class="value">{summary.get('dataset_count', 0)}</div><div class="label">Datasets</div></div>
<div class="stat"><div class="value">{summary.get('total_rows', 0):,}</div><div class="label">Total Rows</div></div>
<div class="stat"><div class="value">{summary.get('hypothesis_count', 0)}</div><div class="label">Hypotheses</div></div>
<div class="stat"><div class="value">{summary.get('confirmed_hypotheses', 0)}</div><div class="label">Confirmed</div></div>
<div class="stat"><div class="value">{summary.get('annotation_count', 0)}</div><div class="label">Annotations</div></div>
<div class="stat"><div class="value">{summary.get('malicious_iocs', 0)}</div><div class="label">Malicious IOCs</div></div>
</div>
"""

        # Hypotheses section
        if hypotheses:
            html += "<h2>Hypotheses</h2>\n"
            html += "<table><tr><th>Title</th><th>MITRE</th><th>Status</th><th>Description</th></tr>\n"
            for h in hypotheses:
                status_class = f"badge-{h['status']}" if h['status'] in ('confirmed', 'active') else ""
                html += (
                    f"<tr><td>{h['title']}</td>"
                    f"<td>{h.get('mitre_technique', 'N/A')}</td>"
                    f"<td><span class='badge {status_class}'>{h['status']}</span></td>"
                    f"<td>{h.get('description', '') or ''}</td></tr>\n"
                )
            html += "</table>\n"

        # Datasets section
        if datasets:
            html += "<h2>Datasets</h2>\n"
            for ds in datasets:
                html += f"""<div class="card">
<h3>{ds['name']} ({ds.get('filename', '')})</h3>
<p><strong>Source:</strong> {ds.get('source_tool', 'N/A')} |
<strong>Rows:</strong> {ds['row_count']:,} |
<strong>IOC Columns:</strong> {len(ds.get('ioc_columns', {}))} |
<strong>Time Range:</strong> {ds.get('time_range', {}).get('start', 'N/A')} to {ds.get('time_range', {}).get('end', 'N/A')}</p>
</div>\n"""

        # Annotations
        if annotations:
            critical = [a for a in annotations if a['severity'] in ('critical', 'high')]
            html += f"<h2>Annotations ({len(annotations)} total, {len(critical)} critical/high)</h2>\n"
            html += "<table><tr><th>Severity</th><th>Tag</th><th>Text</th><th>Created</th></tr>\n"
            for a in annotations[:50]:
                sev_class = f"badge-{a['severity']}" if a['severity'] in ('critical', 'high', 'medium') else ""
                html += (
                    f"<tr><td><span class='badge {sev_class}'>{a['severity']}</span></td>"
                    f"<td>{a.get('tag', 'N/A')}</td>"
                    f"<td>{a['text'][:200]}</td>"
                    f"<td>{a['created_at'][:19]}</td></tr>\n"
                )
            html += "</table>\n"

        # Enrichments
        if enrichments:
            malicious = [e for e in enrichments if e['verdict'] == 'malicious']
            html += f"<h2>IOC Enrichment ({len(enrichments)} results, {len(malicious)} malicious)</h2>\n"
            html += "<table><tr><th>IOC</th><th>Type</th><th>Source</th><th>Verdict</th><th>Score</th></tr>\n"
            for e in enrichments[:50]:
                verdict_class = f"badge-{e['verdict']}"
                html += (
                    f"<tr><td><code>{e['ioc_value']}</code></td>"
                    f"<td>{e['ioc_type']}</td>"
                    f"<td>{e['source']}</td>"
                    f"<td><span class='badge {verdict_class}'>{e['verdict']}</span></td>"
                    f"<td>{e.get('score', 0)}</td></tr>\n"
                )
            html += "</table>\n"

        html += f"""
<div class="footer">
<p>Generated by ThreatHunt Report Engine | {meta.get('generated_at', '')[:19]}</p>
</div>
</div>
</body>
</html>"""

        return html

    def _render_csv(self, data: dict) -> str:
        """Render key report data as CSV, one section per entity type."""
        output = io.StringIO()

        # Hypotheses section
        output.write("=== HYPOTHESES ===\n")
        writer = csv.writer(output)
        writer.writerow(["Title", "MITRE Technique", "Status", "Description", "Evidence Notes"])
        for h in data.get("hypotheses", []):
            writer.writerow([
                h.get("title", ""),
                h.get("mitre_technique", ""),
                h.get("status", ""),
                h.get("description", ""),
                h.get("evidence_notes", ""),
            ])

        output.write("\n=== ANNOTATIONS ===\n")
        writer.writerow(["Severity", "Tag", "Text", "Dataset ID", "Row ID", "Created"])
        for a in data.get("annotations", []):
            writer.writerow([
                a.get("severity", ""),
                a.get("tag", ""),
                a.get("text", ""),
                a.get("dataset_id", ""),
                a.get("row_id", ""),
                a.get("created_at", ""),
            ])

        output.write("\n=== ENRICHMENTS ===\n")
        writer.writerow(["IOC Value", "IOC Type", "Source", "Verdict", "Score", "Country"])
        for e in data.get("enrichments", []):
            writer.writerow([
                e.get("ioc_value", ""),
                e.get("ioc_type", ""),
                e.get("source", ""),
                e.get("verdict", ""),
                e.get("score", ""),
                e.get("country", ""),
            ])

        return output.getvalue()


# Singleton
report_generator = ReportGenerator()
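A minimal call sketch, assuming an active AsyncSession `db` and an existing hunt id (both hypothetical here). With `format="html"` the result is a self-contained string that can be written straight to disk:

    html = await report_generator.generate_hunt_report(
        "hunt-1234",        # hypothetical hunt id
        db,
        format="html",
        include_rows=False,
    )
    with open("hunt-1234-report.html", "w", encoding="utf-8") as f:
        f.write(html)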
346
backend/app/services/sans_rag.py
Normal file
@@ -0,0 +1,346 @@
"""SANS RAG service — queries the 300GB SANS courseware indexed in Open WebUI.

Provides contextual SANS references for threat hunting guidance.
Uses two approaches:
1. Open WebUI RAG pipeline (if configured with a knowledge collection)
2. Embedding-based semantic search against locally indexed SANS content
"""

import asyncio
import logging
import re
import time
from dataclasses import dataclass, field
from typing import Optional

import httpx

from app.config import settings
from app.agents.providers_v2 import _get_client
from app.agents.registry import Node

logger = logging.getLogger(__name__)


# ── SANS course catalog for reference matching ────────────────────────

SANS_COURSES = {
    "SEC401": "Security Essentials",
    "SEC504": "Hacker Tools, Techniques, and Incident Handling",
    "SEC503": "Network Monitoring and Threat Detection In-Depth",
    "SEC505": "Securing Windows and PowerShell Automation",
    "SEC506": "Securing Linux/Unix",
    "SEC510": "Public Cloud Security: AWS, Azure, and GCP",
    "SEC511": "Continuous Monitoring and Security Operations",
    "SEC530": "Defensible Security Architecture and Engineering",
    "SEC540": "Cloud Security and DevSecOps Automation",
    "SEC555": "SIEM with Tactical Analytics",
    "SEC560": "Enterprise Penetration Testing",
    "SEC565": "Red Team Operations and Adversary Emulation",
    "SEC573": "Automating Information Security with Python",
    "SEC575": "Mobile Device Security and Ethical Hacking",
    "SEC588": "Cloud Penetration Testing",
    "SEC599": "Defeating Advanced Adversaries - Purple Team Tactics",
    "FOR408": "Windows Forensic Analysis",
    "FOR498": "Digital Acquisition and Rapid Triage",
    "FOR500": "Windows Forensic Analysis",
    "FOR508": "Advanced Incident Response, Threat Hunting, and Digital Forensics",
    "FOR509": "Enterprise Cloud Forensics and Incident Response",
    "FOR518": "Mac and iOS Forensic Analysis and Incident Response",
    "FOR572": "Advanced Network Forensics: Threat Hunting, Analysis, and Incident Response",
    "FOR578": "Cyber Threat Intelligence",
    "FOR585": "Smartphone Forensic Analysis In-Depth",
    "FOR610": "Reverse-Engineering Malware: Malware Analysis Tools and Techniques",
    "FOR710": "Reverse-Engineering Malware: Advanced Code Analysis",
    "ICS410": "ICS/SCADA Security Essentials",
    "ICS515": "ICS Visibility, Detection, and Response",
}

# Topic-to-course mapping for fallback recommendations
TOPIC_COURSE_MAP = {
    "malware": ["FOR610", "FOR710", "SEC504"],
    "reverse engineer": ["FOR610", "FOR710"],
    "incident response": ["FOR508", "SEC504"],
    "forensic": ["FOR508", "FOR500", "FOR408"],
    "windows forensic": ["FOR500", "FOR408"],
    "network forensic": ["FOR572"],
    "threat hunting": ["FOR508", "SEC504", "FOR578"],
    "threat intelligence": ["FOR578"],
    "powershell": ["SEC505", "FOR508"],
    "lateral movement": ["SEC504", "FOR508"],
    "persistence": ["FOR508", "SEC504"],
    "privilege escalation": ["SEC504", "SEC560"],
    "credential": ["SEC504", "SEC560"],
    "memory forensic": ["FOR508"],
    "disk forensic": ["FOR500", "FOR408"],
    "registry": ["FOR500", "FOR408"],
    "event log": ["FOR508", "SEC555"],
    "siem": ["SEC555"],
    "log analysis": ["SEC555", "SEC503"],
    "network monitor": ["SEC503"],
    "pcap": ["SEC503", "FOR572"],
    "cloud": ["SEC510", "SEC540", "FOR509"],
    "aws": ["SEC510", "SEC540", "FOR509"],
    "azure": ["SEC510", "FOR509"],
    "linux": ["SEC506"],
    "mobile": ["SEC575", "FOR585"],
    "penetration test": ["SEC560", "SEC565"],
    "red team": ["SEC565", "SEC599"],
    "purple team": ["SEC599"],
    "python": ["SEC573"],
    "automation": ["SEC573", "SEC540"],
    "deobfusc": ["FOR610", "SEC504"],
    "base64": ["FOR610", "SEC504"],
    "shellcode": ["FOR610", "FOR710"],
    "ransomware": ["FOR508", "FOR610"],
    "phishing": ["SEC504", "FOR578"],
    "c2": ["FOR508", "SEC504", "FOR572"],
    "command and control": ["FOR508", "SEC504"],
    "exfiltration": ["FOR508", "FOR572", "SEC503"],
    "dns": ["FOR572", "SEC503"],
    "ioc": ["FOR508", "FOR578"],
    "mitre": ["FOR508", "SEC504", "SEC599"],
    "att&ck": ["FOR508", "SEC504"],
    "velociraptor": ["FOR508"],
    "volatility": ["FOR508"],
    "scheduled task": ["FOR508", "SEC504"],
    "service": ["FOR508", "SEC504"],
    "wmi": ["FOR508", "SEC504"],
    "process": ["FOR508"],
    "dll": ["FOR610", "FOR508"],
}


@dataclass
class RAGResult:
    """Result from a RAG query."""
    query: str
    context: str  # Retrieved relevant text
    sources: list[str] = field(default_factory=list)  # Source document names
    course_references: list[str] = field(default_factory=list)  # SANS course IDs
    confidence: float = 0.0
    latency_ms: int = 0


class SANSRAGService:
    """Service for querying SANS courseware via the Open WebUI RAG pipeline."""

    def __init__(self):
        self.openwebui_url = settings.OPENWEBUI_URL.rstrip("/")
        self.api_key = settings.OPENWEBUI_API_KEY
        self.rag_model = settings.DEFAULT_FAST_MODEL
        self._available: bool | None = None

    def _headers(self) -> dict:
        h = {"Content-Type": "application/json"}
        if self.api_key:
            h["Authorization"] = f"Bearer {self.api_key}"
        return h

    async def query(
        self,
        question: str,
        context: str = "",
        max_tokens: int = 1024,
    ) -> RAGResult:
        """Query SANS courseware for relevant context.

        Uses Open WebUI's RAG-enabled chat to retrieve from indexed SANS content.
        Falls back to topic-based course recommendations if RAG is unavailable.
        """
        start = time.monotonic()

        # Try the Open WebUI RAG pipeline first
        try:
            result = await self._query_openwebui_rag(question, context, max_tokens)
            result.latency_ms = int((time.monotonic() - start) * 1000)

            # Enrich with course references if not already present
            if not result.course_references:
                result.course_references = self._match_courses(question)

            return result

        except Exception as e:
            logger.warning(f"RAG query failed, using fallback: {e}")
            # Fall back to topic-based matching
            courses = self._match_courses(question)
            return RAGResult(
                query=question,
                context="",
                sources=[],
                course_references=courses,
                confidence=0.3 if courses else 0.0,
                latency_ms=int((time.monotonic() - start) * 1000),
            )

    async def _query_openwebui_rag(
        self,
        question: str,
        context: str,
        max_tokens: int,
    ) -> RAGResult:
        """Query Open WebUI with RAG context retrieval.

        Open WebUI automatically retrieves from its indexed knowledge base
        when the model is configured with a knowledge collection.
        """
        client = _get_client()

        system_msg = (
            "You are a SANS cybersecurity knowledge assistant. "
            "Use your indexed SANS courseware to answer the question. "
            "Always cite the specific SANS course (e.g., FOR508, SEC504) "
            "and relevant section when referencing material. "
            "If the question relates to threat hunting procedures, "
            "reference the specific SANS methodology or framework."
        )

        messages = [
            {"role": "system", "content": system_msg},
        ]

        if context:
            messages.append({
                "role": "user",
                "content": f"Investigation context:\n{context}\n\nQuestion: {question}",
            })
        else:
            messages.append({"role": "user", "content": question})

        payload = {
            "model": self.rag_model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": 0.2,
            "stream": False,
        }

        resp = await client.post(
            f"{self.openwebui_url}/v1/chat/completions",
            json=payload,
            headers=self._headers(),
        )
        resp.raise_for_status()
        data = resp.json()

        content = ""
        if data.get("choices"):
            content = data["choices"][0].get("message", {}).get("content", "")

        # Extract course references from the response
        course_refs = self._extract_course_refs(content)
        sources = self._extract_sources(data)

        return RAGResult(
            query=question,
            context=content,
            sources=sources,
            course_references=course_refs,
            confidence=0.8 if content else 0.0,
        )

    def _extract_course_refs(self, text: str) -> list[str]:
        """Extract SANS course references (e.g., SEC504, FOR508) from response text."""
        refs = set()
        # Match course IDs like SEC504, FOR508, ICS410
        pattern = r'\b(?:SEC|FOR|ICS|MGT|AUD|DEV|LEG)\d{3}\b'
        for m in re.findall(pattern, text, re.IGNORECASE):
            course_id = m.upper()
            if course_id in SANS_COURSES:
                refs.add(f"{course_id}: {SANS_COURSES[course_id]}")
            else:
                refs.add(course_id)
        return sorted(refs)

    def _extract_sources(self, api_response: dict) -> list[str]:
        """Extract source document references from Open WebUI response metadata."""
        sources = []
        # Open WebUI may include source metadata in various formats
        if "sources" in api_response:
            for src in api_response["sources"]:
                if isinstance(src, dict):
                    sources.append(src.get("name", src.get("title", str(src))))
                else:
                    sources.append(str(src))
        # Check in metadata
        for choice in api_response.get("choices", []):
            meta = choice.get("metadata", {})
            if "sources" in meta:
                for src in meta["sources"]:
                    if isinstance(src, dict):
                        sources.append(src.get("name", str(src)))
                    else:
                        sources.append(str(src))
        return sources[:10]  # Limit

    def _match_courses(self, query: str) -> list[str]:
        """Match query keywords to SANS courses using the topic map."""
        q = query.lower()
        matched = set()
        for topic, courses in TOPIC_COURSE_MAP.items():
            if topic in q:
                for course_id in courses:
                    if course_id in SANS_COURSES:
                        matched.add(f"{course_id}: {SANS_COURSES[course_id]}")
        return sorted(matched)[:5]

    async def get_course_context(self, course_id: str) -> str:
        """Get a brief course description for context injection."""
        course_id = course_id.upper().split(":")[0].strip()
        if course_id in SANS_COURSES:
            return f"{course_id}: {SANS_COURSES[course_id]}"
        return ""

    async def enrich_prompt(
        self,
        query: str,
        investigation_context: str = "",
    ) -> str:
        """Generate SANS-enriched context to inject into agent prompts.

        Returns a context string with relevant SANS references.
        """
        result = await self.query(query, context=investigation_context, max_tokens=512)

        parts = []
        if result.context:
            parts.append(f"SANS Reference Context:\n{result.context}")
        if result.course_references:
            parts.append(f"Relevant SANS Courses: {', '.join(result.course_references)}")
        if result.sources:
            parts.append(f"Sources: {', '.join(result.sources[:5])}")

        return "\n".join(parts) if parts else ""

    async def health_check(self) -> dict:
        """Check RAG service availability."""
        try:
            client = _get_client()
            resp = await client.get(
                f"{self.openwebui_url}/v1/models",
                headers=self._headers(),
                timeout=5,
            )
            available = resp.status_code == 200
            self._available = available
            return {
                "available": available,
                "url": self.openwebui_url,
                "model": self.rag_model,
            }
        except Exception as e:
            self._available = False
            return {
                "available": False,
                "url": self.openwebui_url,
                "error": str(e),
            }


# Singleton
sans_rag = SANSRAGService()
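A minimal usage sketch, assuming Open WebUI is reachable at the configured URL; if it is not, the call degrades to the keyword fallback with confidence 0.3 or lower (the question and printed output are illustrative):

    result = await sans_rag.query("How do I hunt for WMI persistence?")
    print(result.course_references)  # e.g. ['FOR508: Advanced Incident Response, ...']
    print(result.confidence, result.latency_ms)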
233
backend/app/services/scanner.py
Normal file
@@ -0,0 +1,233 @@
"""AUP Keyword Scanner — searches dataset rows, hunts, annotations, and
messages for keyword matches.

Scanning is done in Python (not SQL LIKE on JSON columns) for portability
across SQLite / PostgreSQL and to provide per-cell match context.
"""

import logging
import re
from dataclasses import dataclass, field

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import (
    KeywordTheme,
    Keyword,
    DatasetRow,
    Dataset,
    Hunt,
    Annotation,
    Message,
    Conversation,
)

logger = logging.getLogger(__name__)

BATCH_SIZE = 500


@dataclass
class ScanHit:
    theme_name: str
    theme_color: str
    keyword: str
    source_type: str  # dataset_row | hunt | annotation | message
    source_id: str | int
    field: str
    matched_value: str
    row_index: int | None = None
    dataset_name: str | None = None


@dataclass
class ScanResult:
    total_hits: int = 0
    hits: list[ScanHit] = field(default_factory=list)
    themes_scanned: int = 0
    keywords_scanned: int = 0
    rows_scanned: int = 0


class KeywordScanner:
    """Scans multiple data sources for keyword/regex matches."""

    def __init__(self, db: AsyncSession):
        self.db = db

    # ── Public API ────────────────────────────────────────────────────

    async def scan(
        self,
        dataset_ids: list[str] | None = None,
        theme_ids: list[str] | None = None,
        scan_hunts: bool = True,
        scan_annotations: bool = True,
        scan_messages: bool = True,
    ) -> dict:
        """Run a full AUP scan and return a dict matching ScanResponse."""
        # Load themes + keywords
        themes = await self._load_themes(theme_ids)
        if not themes:
            return ScanResult().__dict__

        # Pre-compile patterns per theme
        patterns = self._compile_patterns(themes)
        result = ScanResult(
            themes_scanned=len(themes),
            keywords_scanned=sum(len(kws) for kws in patterns.values()),
        )

        # Scan dataset rows
        await self._scan_datasets(patterns, result, dataset_ids)

        # Scan hunts
        if scan_hunts:
            await self._scan_hunts(patterns, result)

        # Scan annotations
        if scan_annotations:
            await self._scan_annotations(patterns, result)

        # Scan messages
        if scan_messages:
            await self._scan_messages(patterns, result)

        result.total_hits = len(result.hits)
        return {
            "total_hits": result.total_hits,
            "hits": [h.__dict__ for h in result.hits],
            "themes_scanned": result.themes_scanned,
            "keywords_scanned": result.keywords_scanned,
            "rows_scanned": result.rows_scanned,
        }

    # ── Internal ──────────────────────────────────────────────────────

    async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
        q = select(KeywordTheme).where(KeywordTheme.enabled == True)  # noqa: E712
        if theme_ids:
            q = q.where(KeywordTheme.id.in_(theme_ids))
        result = await self.db.execute(q)
        return list(result.scalars().all())

    def _compile_patterns(
        self, themes: list[KeywordTheme]
    ) -> dict[tuple[str, str, str], list[tuple[str, re.Pattern]]]:
        """Returns {(theme_id, theme_name, theme_color): [(keyword_value, compiled_pattern), ...]}."""
        patterns: dict[tuple[str, str, str], list[tuple[str, re.Pattern]]] = {}
        for theme in themes:
            key = (theme.id, theme.name, theme.color)
            compiled = []
            for kw in theme.keywords:
                try:
                    if kw.is_regex:
                        pat = re.compile(kw.value, re.IGNORECASE)
                    else:
                        pat = re.compile(re.escape(kw.value), re.IGNORECASE)
                    compiled.append((kw.value, pat))
                except re.error:
                    logger.warning("Invalid regex pattern '%s' in theme '%s', skipping",
                                   kw.value, theme.name)
            patterns[key] = compiled
        return patterns

    def _match_text(
        self,
        text: str,
        patterns: dict,
        source_type: str,
        source_id: str | int,
        field_name: str,
        hits: list[ScanHit],
        row_index: int | None = None,
        dataset_name: str | None = None,
    ) -> None:
        """Check text against all compiled patterns, append hits."""
        if not text:
            return
        for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
            for kw_value, pat in keyword_patterns:
                if pat.search(text):
                    # Truncate matched_value for display
                    matched_preview = text[:200] + ("…" if len(text) > 200 else "")
                    hits.append(ScanHit(
                        theme_name=theme_name,
                        theme_color=theme_color,
                        keyword=kw_value,
                        source_type=source_type,
                        source_id=source_id,
                        field=field_name,
                        matched_value=matched_preview,
                        row_index=row_index,
                        dataset_name=dataset_name,
                    ))

    async def _scan_datasets(
        self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
    ) -> None:
        """Scan dataset rows in batches."""
        # Build dataset name lookup
        ds_q = select(Dataset.id, Dataset.name)
        if dataset_ids:
            ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
        ds_result = await self.db.execute(ds_q)
        ds_map = {r[0]: r[1] for r in ds_result.fetchall()}

        if not ds_map:
            return

        # Iterate rows in batches
        offset = 0
        row_q_base = select(DatasetRow).where(
            DatasetRow.dataset_id.in_(list(ds_map.keys()))
        ).order_by(DatasetRow.id)

        while True:
            rows_result = await self.db.execute(
                row_q_base.offset(offset).limit(BATCH_SIZE)
            )
            rows = rows_result.scalars().all()
            if not rows:
                break

            for row in rows:
                result.rows_scanned += 1
                data = row.data or {}
                for col_name, cell_value in data.items():
                    if cell_value is None:
                        continue
                    text = str(cell_value)
                    self._match_text(
                        text, patterns, "dataset_row", row.id,
                        col_name, result.hits,
                        row_index=row.row_index,
                        dataset_name=ds_map.get(row.dataset_id),
                    )

            offset += BATCH_SIZE
            if len(rows) < BATCH_SIZE:
                break

    async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
        """Scan hunt names and descriptions."""
        hunts_result = await self.db.execute(select(Hunt))
        for hunt in hunts_result.scalars().all():
            self._match_text(hunt.name, patterns, "hunt", hunt.id, "name", result.hits)
            if hunt.description:
                self._match_text(hunt.description, patterns, "hunt", hunt.id, "description", result.hits)

    async def _scan_annotations(self, patterns: dict, result: ScanResult) -> None:
        """Scan annotation text."""
        ann_result = await self.db.execute(select(Annotation))
        for ann in ann_result.scalars().all():
            self._match_text(ann.text, patterns, "annotation", ann.id, "text", result.hits)

    async def _scan_messages(self, patterns: dict, result: ScanResult) -> None:
        """Scan conversation messages (user messages only)."""
        msg_result = await self.db.execute(
            select(Message).where(Message.role == "user")
        )
        for msg in msg_result.scalars().all():
            self._match_text(msg.content, patterns, "message", msg.id, "content", result.hits)
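A minimal end-to-end sketch, assuming an AsyncSession `db` with seeded themes and datasets (identifiers below are hypothetical):

    scanner = KeywordScanner(db)
    report = await scanner.scan(
        dataset_ids=["ds-abc"],   # hypothetical dataset id; None scans all datasets
        scan_messages=False,
    )
    print(report["total_hits"], report["rows_scanned"])
    for hit in report["hits"][:5]:
        print(hit["theme_name"], hit["keyword"], hit["field"])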