mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 14:00:20 -05:00
feat: Add Playbook Manager, Saved Searches, and Timeline View components
- Implemented PlaybookManager for creating and managing investigation playbooks with templates.
- Added SavedSearches component for managing bookmarked queries and recurring scans.
- Introduced TimelineView for visualizing forensic event timelines with zoomable charts.
- Enhanced backend processing with auto-queued jobs for dataset uploads and improved database concurrency.
- Updated frontend components for better user experience and performance optimizations.
- Documented changes in update log for future reference.
This commit is contained in:
@@ -13,6 +13,7 @@ from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -79,6 +80,55 @@ def _extract_username(raw: str) -> str:
|
||||
return name or ''
|
||||
|
||||
|
||||
|
||||
|
||||
# In-memory host inventory cache
|
||||
# Pre-computed results stored per hunt_id, built in background after upload.
|
||||
|
||||
import time as _time
|
||||
|
||||
class _InventoryCache:
|
||||
"""Simple in-memory cache for pre-computed host inventories."""
|
||||
|
||||
def __init__(self):
|
||||
self._data: dict[str, dict] = {} # hunt_id -> result dict
|
||||
self._timestamps: dict[str, float] = {} # hunt_id -> epoch
|
||||
self._building: set[str] = set() # hunt_ids currently being built
|
||||
|
||||
def get(self, hunt_id: str) -> dict | None:
|
||||
"""Return cached result if present. Never expires; only invalidated on new upload."""
|
||||
return self._data.get(hunt_id)
|
||||
|
||||
def put(self, hunt_id: str, result: dict):
|
||||
self._data[hunt_id] = result
|
||||
self._timestamps[hunt_id] = _time.time()
|
||||
self._building.discard(hunt_id)
|
||||
logger.info(f"Cached host inventory for hunt {hunt_id} "
|
||||
f"({result['stats']['total_hosts']} hosts)")
|
||||
|
||||
def invalidate(self, hunt_id: str):
|
||||
self._data.pop(hunt_id, None)
|
||||
self._timestamps.pop(hunt_id, None)
|
||||
|
||||
def is_building(self, hunt_id: str) -> bool:
|
||||
return hunt_id in self._building
|
||||
|
||||
def set_building(self, hunt_id: str):
|
||||
self._building.add(hunt_id)
|
||||
|
||||
def clear_building(self, hunt_id: str):
|
||||
self._building.discard(hunt_id)
|
||||
|
||||
def status(self, hunt_id: str) -> str:
|
||||
if hunt_id in self._building:
|
||||
return "building"
|
||||
if hunt_id in self._data:
|
||||
return "ready"
|
||||
return "none"
|
||||
|
||||
|
||||
inventory_cache = _InventoryCache()
|
||||
|
||||
def _infer_os(fqdn: str) -> str:
|
||||
u = fqdn.upper()
|
||||
if 'W10-' in u or 'WIN10' in u:
|
||||
@@ -151,33 +201,61 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
||||
}}
|
||||
|
||||
hosts: dict[str, dict] = {} # fqdn -> host record
|
||||
ip_to_host: dict[str, str] = {} # local-ip -> fqdn
|
||||
ip_to_host: dict[str, str] = {} # local-ip -> fqdn
|
||||
connections: dict[tuple, int] = defaultdict(int)
|
||||
total_rows = 0
|
||||
ds_with_hosts = 0
|
||||
sampled_dataset_count = 0
|
||||
total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
|
||||
max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
|
||||
global_budget_reached = False
|
||||
dropped_connections = 0
|
||||
|
||||
for ds in all_datasets:
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
global_budget_reached = True
|
||||
break
|
||||
|
||||
cols = _identify_columns(ds)
|
||||
if not cols['fqdn'] and not cols['host_id']:
|
||||
continue
|
||||
ds_with_hosts += 1
|
||||
|
||||
batch_size = 5000
|
||||
offset = 0
|
||||
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
|
||||
rows_scanned_this_dataset = 0
|
||||
sampled_dataset = False
|
||||
last_row_index = -1
|
||||
|
||||
while True:
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
sampled_dataset = True
|
||||
global_budget_reached = True
|
||||
break
|
||||
|
||||
rr = await db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == ds.id)
|
||||
.where(DatasetRow.row_index > last_row_index)
|
||||
.order_by(DatasetRow.row_index)
|
||||
.offset(offset).limit(batch_size)
|
||||
.limit(batch_size)
|
||||
)
|
||||
rows = rr.scalars().all()
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for ro in rows:
|
||||
if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
|
||||
sampled_dataset = True
|
||||
break
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
sampled_dataset = True
|
||||
global_budget_reached = True
|
||||
break
|
||||
|
||||
data = ro.data or {}
|
||||
total_rows += 1
|
||||
rows_scanned_this_dataset += 1
|
||||
|
||||
fqdn = ''
|
||||
for c in cols['fqdn']:
|
||||
@@ -239,12 +317,33 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
||||
rport = _clean(data.get(pc))
|
||||
if rport:
|
||||
break
|
||||
connections[(host_key, rip, rport)] += 1
|
||||
conn_key = (host_key, rip, rport)
|
||||
if max_connections and len(connections) >= max_connections and conn_key not in connections:
|
||||
dropped_connections += 1
|
||||
continue
|
||||
connections[conn_key] += 1
|
||||
|
||||
offset += batch_size
|
||||
if sampled_dataset:
|
||||
sampled_dataset_count += 1
|
||||
logger.info(
|
||||
"Host inventory sampling for dataset %s (%d rows scanned)",
|
||||
ds.id,
|
||||
rows_scanned_this_dataset,
|
||||
)
|
||||
break
|
||||
|
||||
last_row_index = rows[-1].row_index
|
||||
if len(rows) < batch_size:
|
||||
break
|
||||
|
||||
if global_budget_reached:
|
||||
logger.info(
|
||||
"Host inventory global row budget reached for hunt %s at %d rows",
|
||||
hunt_id,
|
||||
total_rows,
|
||||
)
|
||||
break
|
||||
|
||||
# Post-process hosts
|
||||
for h in hosts.values():
|
||||
if not h['os'] and h['fqdn']:
|
||||
@@ -286,5 +385,12 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
||||
"total_rows_scanned": total_rows,
|
||||
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||
"row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
|
||||
"connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
|
||||
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
|
||||
"sampled_datasets": sampled_dataset_count,
|
||||
"global_budget_reached": global_budget_reached,
|
||||
"dropped_connections": dropped_connections,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
|
||||
@@ -18,6 +19,9 @@ logger = logging.getLogger(__name__)
|
||||
HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
|
||||
WILE_URL = f"{settings.wile_url}/api/generate"
|
||||
|
||||
# Velociraptor client IDs (C.hex) are not real hostnames
|
||||
CLIENTID_RE = re.compile(r"^C\.[0-9a-fA-F]{8,}$")
|
||||
|
||||
|
||||
async def _get_triage_summary(db, dataset_id: str) -> str:
|
||||
result = await db.execute(
|
||||
@@ -154,7 +158,7 @@ async def profile_host(
|
||||
logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to profile host %s: %s", hostname, e)
|
||||
logger.error("Failed to profile host %s: %r", hostname, e)
|
||||
profile = HostProfile(
|
||||
hunt_id=hunt_id, hostname=hostname, fqdn=fqdn,
|
||||
risk_score=0.0, risk_level="unknown",
|
||||
@@ -185,6 +189,13 @@ async def profile_all_hosts(hunt_id: str) -> None:
|
||||
if h not in hostnames:
|
||||
hostnames[h] = data.get("fqdn") or data.get("Fqdn")
|
||||
|
||||
# Filter out Velociraptor client IDs - not real hostnames
|
||||
real_hosts = {h: f for h, f in hostnames.items() if not CLIENTID_RE.match(h)}
|
||||
skipped = len(hostnames) - len(real_hosts)
|
||||
if skipped:
|
||||
logger.info("Skipped %d Velociraptor client IDs", skipped)
|
||||
hostnames = real_hosts
|
||||
|
||||
logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id)
|
||||
|
||||
semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Async job queue for background AI tasks.
|
||||
|
||||
Manages triage, profiling, report generation, anomaly detection,
|
||||
and data queries as trackable jobs with status, progress, and
|
||||
cancellation support.
|
||||
keyword scanning, IOC extraction, and data queries as trackable
|
||||
jobs with status, progress, and cancellation support.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -15,6 +15,8 @@ from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Coroutine, Optional
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -32,6 +34,18 @@ class JobType(str, Enum):
|
||||
REPORT = "report"
|
||||
ANOMALY = "anomaly"
|
||||
QUERY = "query"
|
||||
HOST_INVENTORY = "host_inventory"
|
||||
KEYWORD_SCAN = "keyword_scan"
|
||||
IOC_EXTRACT = "ioc_extract"
|
||||
|
||||
|
||||
# Job types that form the automatic upload pipeline
|
||||
PIPELINE_JOB_TYPES = frozenset({
|
||||
JobType.TRIAGE,
|
||||
JobType.ANOMALY,
|
||||
JobType.KEYWORD_SCAN,
|
||||
JobType.IOC_EXTRACT,
|
||||
})
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -82,11 +96,7 @@ class Job:
|
||||
|
||||
|
||||
class JobQueue:
|
||||
"""In-memory async job queue with concurrency control.
|
||||
|
||||
Jobs are tracked by ID and can be listed, polled, or cancelled.
|
||||
A configurable number of workers process jobs from the queue.
|
||||
"""
|
||||
"""In-memory async job queue with concurrency control."""
|
||||
|
||||
def __init__(self, max_workers: int = 3):
|
||||
self._jobs: dict[str, Job] = {}
|
||||
@@ -95,47 +105,56 @@ class JobQueue:
|
||||
self._workers: list[asyncio.Task] = []
|
||||
self._handlers: dict[JobType, Callable] = {}
|
||||
self._started = False
|
||||
self._completion_callbacks: list[Callable[[Job], Coroutine]] = []
|
||||
self._cleanup_task: asyncio.Task | None = None
|
||||
|
||||
def register_handler(
|
||||
self,
|
||||
job_type: JobType,
|
||||
handler: Callable[[Job], Coroutine],
|
||||
):
|
||||
"""Register an async handler for a job type.
|
||||
|
||||
Handler signature: async def handler(job: Job) -> Any
|
||||
The handler can update job.progress and job.message during execution.
|
||||
It should check job.is_cancelled periodically and return early.
|
||||
"""
|
||||
def register_handler(self, job_type: JobType, handler: Callable[[Job], Coroutine]):
|
||||
self._handlers[job_type] = handler
|
||||
logger.info(f"Registered handler for {job_type.value}")
|
||||
|
||||
def on_completion(self, callback: Callable[[Job], Coroutine]):
|
||||
"""Register a callback invoked after any job completes or fails."""
|
||||
self._completion_callbacks.append(callback)
|
||||
|
||||
async def start(self):
|
||||
"""Start worker tasks."""
|
||||
if self._started:
|
||||
return
|
||||
self._started = True
|
||||
for i in range(self._max_workers):
|
||||
task = asyncio.create_task(self._worker(i))
|
||||
self._workers.append(task)
|
||||
if not self._cleanup_task or self._cleanup_task.done():
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info(f"Job queue started with {self._max_workers} workers")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop all workers."""
|
||||
self._started = False
|
||||
for w in self._workers:
|
||||
w.cancel()
|
||||
await asyncio.gather(*self._workers, return_exceptions=True)
|
||||
self._workers.clear()
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
await asyncio.gather(self._cleanup_task, return_exceptions=True)
|
||||
self._cleanup_task = None
|
||||
logger.info("Job queue stopped")
|
||||
|
||||
def submit(self, job_type: JobType, **params) -> Job:
|
||||
"""Submit a new job. Returns the Job object immediately."""
|
||||
job = Job(
|
||||
id=str(uuid.uuid4()),
|
||||
job_type=job_type,
|
||||
params=params,
|
||||
)
|
||||
# Soft backpressure: prefer dedupe over queue amplification
|
||||
dedupe_job = self._find_active_duplicate(job_type, params)
|
||||
if dedupe_job is not None:
|
||||
logger.info(
|
||||
f"Job deduped: reusing {dedupe_job.id} ({job_type.value}) params={params}"
|
||||
)
|
||||
return dedupe_job
|
||||
|
||||
if self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG:
|
||||
logger.warning(
|
||||
"Job queue backlog high (%d >= %d). Accepting job but system may be degraded.",
|
||||
self._queue.qsize(), settings.JOB_QUEUE_MAX_BACKLOG,
|
||||
)
|
||||
|
||||
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put_nowait(job.id)
|
||||
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
|
||||
@@ -144,6 +163,22 @@ class JobQueue:
|
||||
def get_job(self, job_id: str) -> Job | None:
|
||||
return self._jobs.get(job_id)
|
||||
|
||||
def _find_active_duplicate(self, job_type: JobType, params: dict) -> Job | None:
    """Return queued/running job with same key workload to prevent duplicate storms.

    Builds a signature from the identifying params (dataset/hunt/host/question/mode)
    and scans all tracked jobs for an active one of the same type with an
    identical signature. Returns None when no signature can be built (nothing
    to dedupe on) or no active duplicate exists.
    """
    key_fields = ["dataset_id", "hunt_id", "hostname", "question", "mode"]
    sig = tuple((k, params.get(k)) for k in key_fields if params.get(k) is not None)
    if not sig:
        # No identifying params: cannot safely dedupe, always submit.
        return None
    for j in self._jobs.values():
        if j.job_type != job_type:
            continue
        if j.status not in (JobStatus.QUEUED, JobStatus.RUNNING):
            continue
        other_sig = tuple((k, j.params.get(k)) for k in key_fields if j.params.get(k) is not None)
        if sig == other_sig:
            return j
    return None
|
||||
|
||||
def cancel_job(self, job_id: str) -> bool:
|
||||
job = self._jobs.get(job_id)
|
||||
if not job:
|
||||
@@ -153,13 +188,7 @@ class JobQueue:
|
||||
job.cancel()
|
||||
return True
|
||||
|
||||
def list_jobs(
|
||||
self,
|
||||
status: JobStatus | None = None,
|
||||
job_type: JobType | None = None,
|
||||
limit: int = 50,
|
||||
) -> list[dict]:
|
||||
"""List jobs, newest first."""
|
||||
def list_jobs(self, status=None, job_type=None, limit=50) -> list[dict]:
|
||||
jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
|
||||
if status:
|
||||
jobs = [j for j in jobs if j.status == status]
|
||||
@@ -168,7 +197,6 @@ class JobQueue:
|
||||
return [j.to_dict() for j in jobs[:limit]]
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get queue statistics."""
|
||||
by_status = {}
|
||||
for j in self._jobs.values():
|
||||
by_status[j.status.value] = by_status.get(j.status.value, 0) + 1
|
||||
@@ -177,26 +205,58 @@ class JobQueue:
|
||||
"queued": self._queue.qsize(),
|
||||
"by_status": by_status,
|
||||
"workers": self._max_workers,
|
||||
"active_workers": sum(
|
||||
1 for j in self._jobs.values() if j.status == JobStatus.RUNNING
|
||||
),
|
||||
"active_workers": sum(1 for j in self._jobs.values() if j.status == JobStatus.RUNNING),
|
||||
}
|
||||
|
||||
def is_backlogged(self) -> bool:
|
||||
return self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG
|
||||
|
||||
def can_accept(self, reserve: int = 0) -> bool:
|
||||
return (self._queue.qsize() + max(0, reserve)) < settings.JOB_QUEUE_MAX_BACKLOG
|
||||
|
||||
def cleanup(self, max_age_seconds: float = 3600):
    """Remove old completed/failed/cancelled jobs.

    Two eviction rules combine: (1) terminal jobs older than
    max_age_seconds, and (2) terminal jobs beyond the
    settings.JOB_QUEUE_RETAIN_COMPLETED newest, to cap memory growth.
    """
    now = time.time()
    terminal_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
    to_remove = [
        jid for jid, j in self._jobs.items()
        if j.status in terminal_states and (now - j.created_at) > max_age_seconds
    ]

    # Also cap retained terminal jobs to avoid unbounded memory growth
    terminal_jobs = sorted(
        [j for j in self._jobs.values() if j.status in terminal_states],
        key=lambda j: j.created_at,
        reverse=True,
    )
    overflow = terminal_jobs[settings.JOB_QUEUE_RETAIN_COMPLETED:]
    to_remove.extend(j.id for j in overflow)

    removed = 0
    # set() dedupes ids matched by both rules.
    for jid in set(to_remove):
        if jid in self._jobs:
            del self._jobs[jid]
            removed += 1
    if removed:
        logger.info(f"Cleaned up {removed} old jobs")
|
||||
|
||||
async def _cleanup_loop(self):
    """Background loop: periodically evict old terminal jobs while started.

    Interval is configurable but floored at 10s; cleanup errors are logged
    and never kill the loop.
    """
    interval = max(10, settings.JOB_QUEUE_CLEANUP_INTERVAL_SECONDS)
    while self._started:
        try:
            self.cleanup(max_age_seconds=settings.JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS)
        except Exception as e:
            logger.warning(f"Job queue cleanup loop error: {e}")
        await asyncio.sleep(interval)
|
||||
|
||||
def find_pipeline_jobs(self, dataset_id: str) -> list[Job]:
    """Find all pipeline jobs for a given dataset_id.

    A pipeline job is any job whose type is in PIPELINE_JOB_TYPES (the
    automatic upload pipeline) and whose params reference this dataset.
    Includes jobs in every status, not just active ones.
    """
    return [
        j for j in self._jobs.values()
        if j.job_type in PIPELINE_JOB_TYPES
        and j.params.get("dataset_id") == dataset_id
    ]
|
||||
|
||||
async def _worker(self, worker_id: int):
|
||||
"""Worker loop: pull jobs from queue and execute handlers."""
|
||||
logger.info(f"Worker {worker_id} started")
|
||||
while self._started:
|
||||
try:
|
||||
@@ -220,7 +280,10 @@ class JobQueue:
|
||||
|
||||
job.status = JobStatus.RUNNING
|
||||
job.started_at = time.time()
|
||||
if job.progress <= 0:
|
||||
job.progress = 5.0
|
||||
job.message = "Running..."
|
||||
await _sync_processing_task(job)
|
||||
logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
|
||||
|
||||
try:
|
||||
@@ -231,38 +294,111 @@ class JobQueue:
|
||||
job.result = result
|
||||
job.message = "Completed"
|
||||
job.completed_at = time.time()
|
||||
logger.info(
|
||||
f"Worker {worker_id}: completed {job.id} "
|
||||
f"in {job.elapsed_ms}ms"
|
||||
)
|
||||
logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
|
||||
except Exception as e:
|
||||
if not job.is_cancelled:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error = str(e)
|
||||
job.message = f"Failed: {e}"
|
||||
job.completed_at = time.time()
|
||||
logger.error(
|
||||
f"Worker {worker_id}: failed {job.id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
|
||||
|
||||
if job.is_cancelled and not job.completed_at:
|
||||
job.completed_at = time.time()
|
||||
|
||||
await _sync_processing_task(job)
|
||||
|
||||
# Fire completion callbacks
|
||||
for cb in self._completion_callbacks:
|
||||
try:
|
||||
await cb(job)
|
||||
except Exception as cb_err:
|
||||
logger.error(f"Completion callback error: {cb_err}", exc_info=True)
|
||||
|
||||
|
||||
# Singleton + job handlers
|
||||
async def _sync_processing_task(job: Job):
    """Persist latest job state into processing_tasks (if linked by job_id).

    Mirrors the in-memory Job's status/progress/message/error (and start and
    completion timestamps, when set) onto any ProcessingTask row whose job_id
    matches. Best-effort: failures are logged, never raised, so job execution
    is not disrupted by bookkeeping errors.
    """
    from datetime import datetime, timezone
    from sqlalchemy import update

    try:
        from app.db import async_session_factory
        from app.db.models import ProcessingTask

        values = {
            "status": job.status.value,
            "progress": float(job.progress),
            "message": job.message,
            "error": job.error,
        }
        if job.started_at:
            values["started_at"] = datetime.fromtimestamp(job.started_at, tz=timezone.utc)
        if job.completed_at:
            values["completed_at"] = datetime.fromtimestamp(job.completed_at, tz=timezone.utc)

        async with async_session_factory() as db:
            await db.execute(
                update(ProcessingTask)
                .where(ProcessingTask.job_id == job.id)
                .values(**values)
            )
            await db.commit()
    except Exception as e:
        logger.warning(f"Failed to sync processing task for job {job.id}: {e}")


# -- Singleton + job handlers --

job_queue = JobQueue(max_workers=5)
|
||||
|
||||
|
||||
async def _handle_triage(job: Job):
    """Triage handler - chains HOST_PROFILE after completion.

    Runs triage for the dataset, then looks up the dataset's hunt_id and
    submits a HOST_PROFILE job for it, persisting a linked ProcessingTask row
    so the chained job is visible in the processing pipeline. Chaining and
    persistence failures are logged but do not fail the triage job itself.
    """
    from app.services.triage import triage_dataset
    dataset_id = job.params.get("dataset_id")
    job.message = f"Triaging dataset {dataset_id}"
    await triage_dataset(dataset_id)

    # Chain: trigger host profiling now that triage results exist
    from app.db import async_session_factory
    from app.db.models import Dataset, ProcessingTask
    from sqlalchemy import select
    try:
        async with async_session_factory() as db:
            ds = await db.execute(select(Dataset.hunt_id).where(Dataset.id == dataset_id))
            row = ds.first()
            hunt_id = row[0] if row else None
        if hunt_id:
            hp_job = job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
            try:
                async with async_session_factory() as db:
                    # submit() may return an existing deduped job, so only
                    # create a ProcessingTask row if none is linked yet.
                    existing = await db.execute(
                        select(ProcessingTask.id).where(ProcessingTask.job_id == hp_job.id)
                    )
                    if existing.first() is None:
                        db.add(ProcessingTask(
                            hunt_id=hunt_id,
                            dataset_id=dataset_id,
                            job_id=hp_job.id,
                            stage="host_profile",
                            status="queued",
                            progress=0.0,
                            message="Queued",
                        ))
                        await db.commit()
            except Exception as persist_err:
                logger.warning(f"Failed to persist chained HOST_PROFILE task: {persist_err}")

            logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
    except Exception as e:
        logger.warning(f"Failed to chain host profile after triage: {e}")

    return {"dataset_id": dataset_id}
|
||||
|
||||
|
||||
async def _handle_host_profile(job: Job):
|
||||
"""Host profiling handler."""
|
||||
from app.services.host_profiler import profile_all_hosts, profile_host
|
||||
hunt_id = job.params.get("hunt_id")
|
||||
hostname = job.params.get("hostname")
|
||||
@@ -277,7 +413,6 @@ async def _handle_host_profile(job: Job):
|
||||
|
||||
|
||||
async def _handle_report(job: Job):
|
||||
"""Report generation handler."""
|
||||
from app.services.report_generator import generate_report
|
||||
hunt_id = job.params.get("hunt_id")
|
||||
job.message = f"Generating report for hunt {hunt_id}"
|
||||
@@ -286,7 +421,6 @@ async def _handle_report(job: Job):
|
||||
|
||||
|
||||
async def _handle_anomaly(job: Job):
|
||||
"""Anomaly detection handler."""
|
||||
from app.services.anomaly_detector import detect_anomalies
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
k = job.params.get("k", 3)
|
||||
@@ -297,7 +431,6 @@ async def _handle_anomaly(job: Job):
|
||||
|
||||
|
||||
async def _handle_query(job: Job):
|
||||
"""Data query handler (non-streaming)."""
|
||||
from app.services.data_query import query_dataset
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
question = job.params.get("question", "")
|
||||
@@ -307,10 +440,152 @@ async def _handle_query(job: Job):
|
||||
return {"answer": answer}
|
||||
|
||||
|
||||
async def _handle_host_inventory(job: Job):
    """Host inventory build handler.

    Builds the inventory for the hunt and publishes it to inventory_cache.
    On success, put() clears the 'building' flag; on failure the flag is
    cleared explicitly before re-raising so status() doesn't stick at
    'building'.

    Raises:
        ValueError: if the job params carry no hunt_id.
    """
    from app.db import async_session_factory
    from app.services.host_inventory import build_host_inventory, inventory_cache

    hunt_id = job.params.get("hunt_id")
    if not hunt_id:
        raise ValueError("hunt_id required")

    inventory_cache.set_building(hunt_id)
    job.message = f"Building host inventory for hunt {hunt_id}"

    try:
        async with async_session_factory() as db:
            result = await build_host_inventory(hunt_id, db)
        inventory_cache.put(hunt_id, result)
        job.message = f"Built inventory: {result['stats']['total_hosts']} hosts"
        return {"hunt_id": hunt_id, "total_hosts": result["stats"]["total_hosts"]}
    except Exception:
        inventory_cache.clear_building(hunt_id)
        raise
|
||||
|
||||
|
||||
async def _handle_keyword_scan(job: Job):
    """AUP keyword scan handler.

    Runs a dataset-only keyword scan and caches the result per dataset so
    later API reads can skip a full rescan.
    """
    from app.db import async_session_factory
    from app.services.scanner import KeywordScanner, keyword_scan_cache

    dataset_id = job.params.get("dataset_id")
    job.message = f"Running AUP keyword scan on dataset {dataset_id}"

    async with async_session_factory() as db:
        scanner = KeywordScanner(db)
        result = await scanner.scan(dataset_ids=[dataset_id])

    # Cache dataset-only result for fast API reuse
    if dataset_id:
        keyword_scan_cache.put(dataset_id, result)

    hits = result.get("total_hits", 0)
    job.message = f"Keyword scan complete: {hits} hits"
    logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
    return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
|
||||
|
||||
|
||||
async def _handle_ioc_extract(job: Job):
    """IOC extraction handler.

    Extracts indicators of compromise from the dataset and returns a total
    count plus a per-type breakdown (extract_iocs_from_dataset returns a
    mapping of IOC type -> list of values).
    """
    from app.db import async_session_factory
    from app.services.ioc_extractor import extract_iocs_from_dataset

    dataset_id = job.params.get("dataset_id")
    job.message = f"Extracting IOCs from dataset {dataset_id}"

    async with async_session_factory() as db:
        iocs = await extract_iocs_from_dataset(dataset_id, db)

    total = sum(len(v) for v in iocs.values())
    job.message = f"IOC extraction complete: {total} IOCs found"
    logger.info(f"IOC extract for {dataset_id}: {total} IOCs")
    return {"dataset_id": dataset_id, "total_iocs": total, "breakdown": {k: len(v) for k, v in iocs.items()}}
|
||||
|
||||
|
||||
async def _on_pipeline_job_complete(job: Job):
    """Update Dataset.processing_status when all pipeline jobs finish.

    Completion callback: runs after every job, but only acts when the
    finished job belongs to the upload pipeline and every pipeline job for
    the same dataset has reached a terminal state. Sets the dataset to
    'completed' or 'completed_with_errors' depending on failures.
    """
    if job.job_type not in PIPELINE_JOB_TYPES:
        return

    dataset_id = job.params.get("dataset_id")
    if not dataset_id:
        return

    pipeline_jobs = job_queue.find_pipeline_jobs(dataset_id)
    if not pipeline_jobs:
        return

    terminal = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
    if not all(j.status in terminal for j in pipeline_jobs):
        # Some sibling jobs still queued/running; a later callback will act.
        return

    any_failed = any(j.status == JobStatus.FAILED for j in pipeline_jobs)
    new_status = "completed_with_errors" if any_failed else "completed"

    try:
        from app.db import async_session_factory
        from app.db.models import Dataset
        from sqlalchemy import update

        async with async_session_factory() as db:
            await db.execute(
                update(Dataset)
                .where(Dataset.id == dataset_id)
                .values(processing_status=new_status)
            )
            await db.commit()
        logger.info(f"Dataset {dataset_id} processing_status -> {new_status}")
    except Exception as e:
        logger.error(f"Failed to update processing_status for {dataset_id}: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
async def reconcile_stale_processing_tasks() -> int:
    """Mark queued/running processing tasks from prior runs as failed.

    Called at startup: the in-memory job queue does not survive a restart, so
    any ProcessingTask still 'queued' or 'running' can never complete and is
    flipped to 'failed' with an explanatory message.

    Returns:
        Number of rows updated (0 on error; errors are logged, not raised).
    """
    from datetime import datetime, timezone
    from sqlalchemy import update

    try:
        from app.db import async_session_factory
        from app.db.models import ProcessingTask

        now = datetime.now(timezone.utc)
        async with async_session_factory() as db:
            result = await db.execute(
                update(ProcessingTask)
                .where(ProcessingTask.status.in_(["queued", "running"]))
                .values(
                    status="failed",
                    error="Recovered after service restart before task completion",
                    message="Recovered stale task after restart",
                    completed_at=now,
                )
            )
            await db.commit()
            updated = int(result.rowcount or 0)

        if updated:
            logger.warning(
                "Reconciled %d stale processing tasks (queued/running -> failed) during startup",
                updated,
            )
        return updated
    except Exception as e:
        logger.warning(f"Failed to reconcile stale processing tasks: {e}")
        return 0
|
||||
|
||||
|
||||
def register_all_handlers():
    """Register all job handlers and completion callbacks.

    Fix: the QUERY handler was registered twice; the duplicate call was
    redundant (it simply overwrote the entry with the same handler) and has
    been removed.
    """
    job_queue.register_handler(JobType.TRIAGE, _handle_triage)
    job_queue.register_handler(JobType.HOST_PROFILE, _handle_host_profile)
    job_queue.register_handler(JobType.REPORT, _handle_report)
    job_queue.register_handler(JobType.ANOMALY, _handle_anomaly)
    job_queue.register_handler(JobType.QUERY, _handle_query)
    job_queue.register_handler(JobType.HOST_INVENTORY, _handle_host_inventory)
    job_queue.register_handler(JobType.KEYWORD_SCAN, _handle_keyword_scan)
    job_queue.register_handler(JobType.IOC_EXTRACT, _handle_ioc_extract)
    # Pipeline completion callback updates Dataset.processing_status once the
    # last upload-pipeline job for a dataset reaches a terminal state.
    job_queue.on_completion(_on_pipeline_job_complete)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""AUP Keyword Scanner — searches dataset rows, hunts, annotations, and
|
||||
"""AUP Keyword Scanner searches dataset rows, hunts, annotations, and
|
||||
messages for keyword matches.
|
||||
|
||||
Scanning is done in Python (not SQL LIKE on JSON columns) for portability
|
||||
@@ -8,24 +8,49 @@ across SQLite / PostgreSQL and to provide per-cell match context.
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
|
||||
from app.db.models import (
|
||||
KeywordTheme,
|
||||
Keyword,
|
||||
DatasetRow,
|
||||
Dataset,
|
||||
Hunt,
|
||||
Annotation,
|
||||
Message,
|
||||
Conversation,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BATCH_SIZE = 500
|
||||
BATCH_SIZE = 200
|
||||
|
||||
|
||||
def _infer_hostname_and_user(data: dict) -> tuple[str | None, str | None]:
|
||||
"""Best-effort extraction of hostname and user from a dataset row."""
|
||||
if not data:
|
||||
return None, None
|
||||
|
||||
host_keys = (
|
||||
'hostname', 'host_name', 'host', 'computer_name', 'computer',
|
||||
'fqdn', 'client_id', 'agent_id', 'endpoint_id',
|
||||
)
|
||||
user_keys = (
|
||||
'username', 'user_name', 'user', 'account_name',
|
||||
'logged_in_user', 'samaccountname', 'sam_account_name',
|
||||
)
|
||||
|
||||
def pick(keys):
|
||||
for k in keys:
|
||||
for actual_key, v in data.items():
|
||||
if actual_key.lower() == k and v not in (None, ''):
|
||||
return str(v)
|
||||
return None
|
||||
|
||||
return pick(host_keys), pick(user_keys)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -39,6 +64,8 @@ class ScanHit:
|
||||
matched_value: str
|
||||
row_index: int | None = None
|
||||
dataset_name: str | None = None
|
||||
hostname: str | None = None
|
||||
username: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -50,21 +77,54 @@ class ScanResult:
|
||||
rows_scanned: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeywordScanCacheEntry:
|
||||
dataset_id: str
|
||||
result: dict
|
||||
built_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
|
||||
|
||||
class KeywordScanCache:
    """In-memory per-dataset cache for dataset-only keyword scans.

    This enables fast-path reads when users run AUP scans against datasets
    that were already scanned during upload pipeline processing.
    """

    def __init__(self):
        # Keyed by dataset_id; values are KeywordScanCacheEntry records.
        self._by_dataset: "dict[str, KeywordScanCacheEntry]" = {}

    def get(self, dataset_id: str) -> "KeywordScanCacheEntry | None":
        """Return the cached entry for *dataset_id*, or None on a miss."""
        return self._by_dataset.get(dataset_id)

    def put(self, dataset_id: str, result: dict):
        """Store (or overwrite) the cached scan result for *dataset_id*."""
        entry = KeywordScanCacheEntry(dataset_id=dataset_id, result=result)
        self._by_dataset[dataset_id] = entry

    def invalidate_dataset(self, dataset_id: str):
        """Drop a single dataset's cached result; silently no-op if absent."""
        self._by_dataset.pop(dataset_id, None)

    def clear(self):
        """Discard every cached entry."""
        self._by_dataset.clear()
|
||||
|
||||
|
||||
# Module-level singleton cache shared by the upload pipeline and scan endpoints.
keyword_scan_cache = KeywordScanCache()
|
||||
|
||||
|
||||
class KeywordScanner:
|
||||
"""Scans multiple data sources for keyword/regex matches."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────
|
||||
# Public API
|
||||
|
||||
async def scan(
|
||||
self,
|
||||
dataset_ids: list[str] | None = None,
|
||||
theme_ids: list[str] | None = None,
|
||||
scan_hunts: bool = True,
|
||||
scan_annotations: bool = True,
|
||||
scan_messages: bool = True,
|
||||
scan_hunts: bool = False,
|
||||
scan_annotations: bool = False,
|
||||
scan_messages: bool = False,
|
||||
) -> dict:
|
||||
"""Run a full AUP scan and return dict matching ScanResponse."""
|
||||
# Load themes + keywords
|
||||
@@ -103,7 +163,7 @@ class KeywordScanner:
|
||||
"rows_scanned": result.rows_scanned,
|
||||
}
|
||||
|
||||
# ── Internal ──────────────────────────────────────────────────────
|
||||
# Internal
|
||||
|
||||
async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
|
||||
q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712
|
||||
@@ -143,6 +203,8 @@ class KeywordScanner:
|
||||
hits: list[ScanHit],
|
||||
row_index: int | None = None,
|
||||
dataset_name: str | None = None,
|
||||
hostname: str | None = None,
|
||||
username: str | None = None,
|
||||
) -> None:
|
||||
"""Check text against all compiled patterns, append hits."""
|
||||
if not text:
|
||||
@@ -150,8 +212,7 @@ class KeywordScanner:
|
||||
for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
|
||||
for kw_value, pat in keyword_patterns:
|
||||
if pat.search(text):
|
||||
# Truncate matched_value for display
|
||||
matched_preview = text[:200] + ("…" if len(text) > 200 else "")
|
||||
matched_preview = text[:200] + ("" if len(text) > 200 else "")
|
||||
hits.append(ScanHit(
|
||||
theme_name=theme_name,
|
||||
theme_color=theme_color,
|
||||
@@ -162,13 +223,14 @@ class KeywordScanner:
|
||||
matched_value=matched_preview,
|
||||
row_index=row_index,
|
||||
dataset_name=dataset_name,
|
||||
hostname=hostname,
|
||||
username=username,
|
||||
))
|
||||
|
||||
async def _scan_datasets(
|
||||
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
|
||||
) -> None:
|
||||
"""Scan dataset rows in batches."""
|
||||
# Build dataset name lookup
|
||||
"""Scan dataset rows in batches using keyset pagination (no OFFSET)."""
|
||||
ds_q = select(Dataset.id, Dataset.name)
|
||||
if dataset_ids:
|
||||
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
|
||||
@@ -178,37 +240,66 @@ class KeywordScanner:
|
||||
if not ds_map:
|
||||
return
|
||||
|
||||
# Iterate rows in batches
|
||||
offset = 0
|
||||
row_q_base = select(DatasetRow).where(
|
||||
DatasetRow.dataset_id.in_(list(ds_map.keys()))
|
||||
).order_by(DatasetRow.id)
|
||||
import asyncio
|
||||
|
||||
while True:
|
||||
rows_result = await self.db.execute(
|
||||
row_q_base.offset(offset).limit(BATCH_SIZE)
|
||||
max_rows = max(0, int(settings.SCANNER_MAX_ROWS_PER_SCAN))
|
||||
budget_reached = False
|
||||
|
||||
for ds_id, ds_name in ds_map.items():
|
||||
if max_rows and result.rows_scanned >= max_rows:
|
||||
budget_reached = True
|
||||
break
|
||||
|
||||
last_id = 0
|
||||
while True:
|
||||
if max_rows and result.rows_scanned >= max_rows:
|
||||
budget_reached = True
|
||||
break
|
||||
rows_result = await self.db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == ds_id)
|
||||
.where(DatasetRow.id > last_id)
|
||||
.order_by(DatasetRow.id)
|
||||
.limit(BATCH_SIZE)
|
||||
)
|
||||
rows = rows_result.scalars().all()
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for row in rows:
|
||||
result.rows_scanned += 1
|
||||
data = row.data or {}
|
||||
hostname, username = _infer_hostname_and_user(data)
|
||||
for col_name, cell_value in data.items():
|
||||
if cell_value is None:
|
||||
continue
|
||||
text = str(cell_value)
|
||||
self._match_text(
|
||||
text,
|
||||
patterns,
|
||||
"dataset_row",
|
||||
row.id,
|
||||
col_name,
|
||||
result.hits,
|
||||
row_index=row.row_index,
|
||||
dataset_name=ds_name,
|
||||
hostname=hostname,
|
||||
username=username,
|
||||
)
|
||||
|
||||
last_id = rows[-1].id
|
||||
await asyncio.sleep(0)
|
||||
if len(rows) < BATCH_SIZE:
|
||||
break
|
||||
|
||||
if budget_reached:
|
||||
break
|
||||
|
||||
if budget_reached:
|
||||
logger.warning(
|
||||
"AUP scan row budget reached (%d rows). Returning partial results.",
|
||||
result.rows_scanned,
|
||||
)
|
||||
rows = rows_result.scalars().all()
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for row in rows:
|
||||
result.rows_scanned += 1
|
||||
data = row.data or {}
|
||||
for col_name, cell_value in data.items():
|
||||
if cell_value is None:
|
||||
continue
|
||||
text = str(cell_value)
|
||||
self._match_text(
|
||||
text, patterns, "dataset_row", row.id,
|
||||
col_name, result.hits,
|
||||
row_index=row.row_index,
|
||||
dataset_name=ds_map.get(row.dataset_id),
|
||||
)
|
||||
|
||||
offset += BATCH_SIZE
|
||||
if len(rows) < BATCH_SIZE:
|
||||
break
|
||||
|
||||
async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
|
||||
"""Scan hunt names and descriptions."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""
|
||||
"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -15,7 +15,7 @@ from app.db.models import Dataset, DatasetRow, TriageResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M"
|
||||
DEFAULT_FAST_MODEL = settings.DEFAULT_FAST_MODEL
|
||||
ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate"
|
||||
|
||||
ARTIFACT_FOCUS = {
|
||||
@@ -80,7 +80,7 @@ async def triage_dataset(dataset_id: str) -> None:
|
||||
rows_result = await db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == dataset_id)
|
||||
.order_by(DatasetRow.row_number)
|
||||
.order_by(DatasetRow.row_index)
|
||||
.offset(offset)
|
||||
.limit(batch_size)
|
||||
)
|
||||
@@ -167,4 +167,4 @@ Be precise. Only flag genuinely suspicious items. Respond with valid JSON only."
|
||||
|
||||
offset += batch_size
|
||||
|
||||
logger.info("Triage complete for dataset %s", dataset_id)
|
||||
logger.info("Triage complete for dataset %s", dataset_id)
|
||||
|
||||
Reference in New Issue
Block a user