Files
ThreatHunt/backend/app/services/job_queue.py
mblanke 5a2ad8ec1c feat: Add Playbook Manager, Saved Searches, and Timeline View components
- Implemented PlaybookManager for creating and managing investigation playbooks with templates.
- Added SavedSearches component for managing bookmarked queries and recurring scans.
- Introduced TimelineView for visualizing forensic event timelines with zoomable charts.
- Enhanced backend processing with auto-queued jobs for dataset uploads and improved database concurrency.
- Updated frontend components for better user experience and performance optimizations.
- Documented changes in update log for future reference.
2026-02-23 14:23:07 -05:00

592 lines
21 KiB
Python

"""Async job queue for background AI tasks.
Manages triage, profiling, report generation, anomaly detection,
keyword scanning, IOC extraction, and data queries as trackable
jobs with status, progress, and cancellation support.
"""
from __future__ import annotations
import asyncio
import logging
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Coroutine, Optional
from app.config import settings
logger = logging.getLogger(__name__)
class JobStatus(str, Enum):
    """Lifecycle states of a background job.

    QUEUED/RUNNING are active states; COMPLETED/FAILED/CANCELLED are
    terminal. Inherits ``str`` so values serialize directly to JSON.
    """
    QUEUED = "queued"        # submitted, waiting for a worker
    RUNNING = "running"      # picked up by a worker
    COMPLETED = "completed"  # handler returned successfully
    FAILED = "failed"        # handler raised, or no handler registered
    CANCELLED = "cancelled"  # cancelled via Job.cancel()
class JobType(str, Enum):
    """Kinds of background work this queue executes (one handler each).

    Inherits ``str`` so values serialize directly to JSON.
    """
    TRIAGE = "triage"                  # dataset triage (chains HOST_PROFILE)
    HOST_PROFILE = "host_profile"      # profile one host or all hosts in a hunt
    REPORT = "report"                  # hunt report generation
    ANOMALY = "anomaly"                # anomaly detection over a dataset
    QUERY = "query"                    # ad-hoc data query against a dataset
    HOST_INVENTORY = "host_inventory"  # build + cache host inventory
    KEYWORD_SCAN = "keyword_scan"      # AUP keyword scan
    IOC_EXTRACT = "ioc_extract"        # IOC extraction from a dataset
# Job types auto-queued for every dataset upload; when all of a dataset's
# jobs of these types reach a terminal state, the dataset's
# processing_status is finalized (see _on_pipeline_job_complete).
PIPELINE_JOB_TYPES = frozenset((
    JobType.TRIAGE,
    JobType.ANOMALY,
    JobType.KEYWORD_SCAN,
    JobType.IOC_EXTRACT,
))
@dataclass
class Job:
    """A tracked unit of background work with progress and cancellation.

    Field order defines the generated ``__init__`` signature and must not
    change. ``_cancel_event`` is the cooperative cancellation flag that
    handlers may poll via ``is_cancelled``.
    """
    id: str
    job_type: JobType
    status: JobStatus = JobStatus.QUEUED
    progress: float = 0.0  # 0-100
    message: str = ""
    result: Any = None
    error: str | None = None
    created_at: float = field(default_factory=time.time)
    started_at: float | None = None
    completed_at: float | None = None
    params: dict = field(default_factory=dict)
    _cancel_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)

    @property
    def elapsed_ms(self) -> int:
        """Milliseconds from start (or creation) until completion (or now)."""
        finish = self.completed_at or time.time()
        begin = self.started_at or self.created_at
        return int(1000 * (finish - begin))

    def to_dict(self) -> dict:
        """JSON-safe snapshot of the job's public state (excludes result)."""
        return dict(
            id=self.id,
            job_type=self.job_type.value,
            status=self.status.value,
            progress=round(self.progress, 1),
            message=self.message,
            error=self.error,
            created_at=self.created_at,
            started_at=self.started_at,
            completed_at=self.completed_at,
            elapsed_ms=self.elapsed_ms,
            params=self.params,
        )

    @property
    def is_cancelled(self) -> bool:
        """True once cancellation has been requested."""
        return self._cancel_event.is_set()

    def cancel(self):
        """Flag the job as cancelled and stamp its terminal state."""
        self._cancel_event.set()
        self.completed_at = time.time()
        self.status = JobStatus.CANCELLED
        self.message = "Cancelled by user"
class JobQueue:
    """In-memory async job queue with concurrency control.

    Jobs live in a dict keyed by id; queued job ids flow through an
    asyncio.Queue to a fixed pool of worker tasks. ``submit`` dedupes
    against queued/running jobs with the same key params, and a periodic
    cleanup loop prunes old terminal jobs. State is process-local only.
    """
    def __init__(self, max_workers: int = 3):
        self._jobs: dict[str, Job] = {}  # all known jobs, incl. retained terminal ones
        self._queue: asyncio.Queue[str] = asyncio.Queue()  # FIFO of pending job ids
        self._max_workers = max_workers
        self._workers: list[asyncio.Task] = []
        self._handlers: dict[JobType, Callable] = {}  # one handler coroutine per type
        self._started = False
        self._completion_callbacks: list[Callable[[Job], Coroutine]] = []
        self._cleanup_task: asyncio.Task | None = None

    def register_handler(self, job_type: JobType, handler: Callable[[Job], Coroutine]):
        """Register (or replace) the coroutine that executes jobs of job_type."""
        self._handlers[job_type] = handler
        logger.info(f"Registered handler for {job_type.value}")

    def on_completion(self, callback: Callable[[Job], Coroutine]):
        """Register a callback invoked after any job completes or fails."""
        self._completion_callbacks.append(callback)

    async def start(self):
        """Spawn the worker pool and cleanup loop; idempotent."""
        if self._started:
            return
        self._started = True
        for i in range(self._max_workers):
            task = asyncio.create_task(self._worker(i))
            self._workers.append(task)
        # Recreate the cleanup task if it was never started or has died.
        if not self._cleanup_task or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info(f"Job queue started with {self._max_workers} workers")

    async def stop(self):
        """Cancel workers and the cleanup loop, awaiting their shutdown."""
        self._started = False
        for w in self._workers:
            w.cancel()
        # return_exceptions=True swallows the CancelledErrors raised above.
        await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        if self._cleanup_task:
            self._cleanup_task.cancel()
            await asyncio.gather(self._cleanup_task, return_exceptions=True)
            self._cleanup_task = None
        logger.info("Job queue stopped")

    def submit(self, job_type: JobType, **params) -> Job:
        """Queue a new job, or return an existing queued/running duplicate.

        Backlog above settings.JOB_QUEUE_MAX_BACKLOG is logged as a
        warning but the job is still accepted.
        """
        # Soft backpressure: prefer dedupe over queue amplification
        dedupe_job = self._find_active_duplicate(job_type, params)
        if dedupe_job is not None:
            logger.info(
                f"Job deduped: reusing {dedupe_job.id} ({job_type.value}) params={params}"
            )
            return dedupe_job
        if self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG:
            logger.warning(
                "Job queue backlog high (%d >= %d). Accepting job but system may be degraded.",
                self._queue.qsize(), settings.JOB_QUEUE_MAX_BACKLOG,
            )
        job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
        self._jobs[job.id] = job
        self._queue.put_nowait(job.id)
        logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
        return job

    def get_job(self, job_id: str) -> Job | None:
        """Look up a job by id; None if unknown or already cleaned up."""
        return self._jobs.get(job_id)

    def _find_active_duplicate(self, job_type: JobType, params: dict) -> Job | None:
        """Return queued/running job with same key workload to prevent duplicate storms."""
        # Only these params identify a workload; other params are ignored
        # for dedupe purposes. None-valued keys do not participate.
        key_fields = ["dataset_id", "hunt_id", "hostname", "question", "mode"]
        sig = tuple((k, params.get(k)) for k in key_fields if params.get(k) is not None)
        if not sig:
            return None  # nothing identifying to match on
        for j in self._jobs.values():
            if j.job_type != job_type:
                continue
            if j.status not in (JobStatus.QUEUED, JobStatus.RUNNING):
                continue
            other_sig = tuple((k, j.params.get(k)) for k in key_fields if j.params.get(k) is not None)
            if sig == other_sig:
                return j
        return None

    def cancel_job(self, job_id: str) -> bool:
        """Request cancellation; False if unknown or already terminal."""
        job = self._jobs.get(job_id)
        if not job:
            return False
        if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
            return False
        job.cancel()
        return True

    def list_jobs(self, status: JobStatus | None = None, job_type: JobType | None = None, limit: int = 50) -> list[dict]:
        """Return up to `limit` job dicts, newest first, optionally filtered."""
        jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
        if status:
            jobs = [j for j in jobs if j.status == status]
        if job_type:
            jobs = [j for j in jobs if j.job_type == job_type]
        return [j.to_dict() for j in jobs[:limit]]

    def get_stats(self) -> dict:
        """Aggregate counts: total jobs, queue depth, per-status, workers."""
        by_status = {}
        for j in self._jobs.values():
            by_status[j.status.value] = by_status.get(j.status.value, 0) + 1
        return {
            "total": len(self._jobs),
            "queued": self._queue.qsize(),
            "by_status": by_status,
            "workers": self._max_workers,
            "active_workers": sum(1 for j in self._jobs.values() if j.status == JobStatus.RUNNING),
        }

    def is_backlogged(self) -> bool:
        """True when the pending queue has reached the configured backlog cap."""
        return self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG

    def can_accept(self, reserve: int = 0) -> bool:
        """True if `reserve` more jobs would still fit under the backlog cap."""
        return (self._queue.qsize() + max(0, reserve)) < settings.JOB_QUEUE_MAX_BACKLOG

    def cleanup(self, max_age_seconds: float = 3600):
        """Drop terminal jobs older than max_age_seconds, plus any beyond
        the settings.JOB_QUEUE_RETAIN_COMPLETED retention cap (newest kept)."""
        now = time.time()
        terminal_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
        to_remove = [
            jid for jid, j in self._jobs.items()
            if j.status in terminal_states and (now - j.created_at) > max_age_seconds
        ]
        # Also cap retained terminal jobs to avoid unbounded memory growth
        terminal_jobs = sorted(
            [j for j in self._jobs.values() if j.status in terminal_states],
            key=lambda j: j.created_at,
            reverse=True,
        )
        overflow = terminal_jobs[settings.JOB_QUEUE_RETAIN_COMPLETED :]
        to_remove.extend([j.id for j in overflow])
        removed = 0
        # set() dedupes ids present in both the age and overflow lists.
        for jid in set(to_remove):
            if jid in self._jobs:
                del self._jobs[jid]
                removed += 1
        if removed:
            logger.info(f"Cleaned up {removed} old jobs")

    async def _cleanup_loop(self):
        """Run cleanup() on an interval while the queue is started."""
        # Floor of 10s guards against a misconfigured tiny interval.
        interval = max(10, settings.JOB_QUEUE_CLEANUP_INTERVAL_SECONDS)
        while self._started:
            try:
                self.cleanup(max_age_seconds=settings.JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS)
            except Exception as e:
                logger.warning(f"Job queue cleanup loop error: {e}")
            await asyncio.sleep(interval)

    def find_pipeline_jobs(self, dataset_id: str) -> list[Job]:
        """Find all pipeline jobs for a given dataset_id."""
        return [
            j for j in self._jobs.values()
            if j.job_type in PIPELINE_JOB_TYPES
            and j.params.get("dataset_id") == dataset_id
        ]

    async def _worker(self, worker_id: int):
        """Worker loop: pull job ids, run the registered handler, record
        the outcome, sync DB state, then fire completion callbacks.

        The 5s poll timeout lets the loop notice _started flipping to
        False during shutdown.
        """
        logger.info(f"Worker {worker_id} started")
        while self._started:
            try:
                job_id = await asyncio.wait_for(self._queue.get(), timeout=5.0)
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break
            job = self._jobs.get(job_id)
            # Skip jobs cancelled while still queued, or already cleaned up.
            if not job or job.is_cancelled:
                continue
            handler = self._handlers.get(job.job_type)
            if not handler:
                # NOTE(review): this failure path skips _sync_processing_task
                # and the completion callbacks below — confirm that is intended.
                job.status = JobStatus.FAILED
                job.error = f"No handler for {job.job_type.value}"
                job.completed_at = time.time()
                logger.error(f"No handler for job type {job.job_type.value}")
                continue
            job.status = JobStatus.RUNNING
            job.started_at = time.time()
            if job.progress <= 0:
                job.progress = 5.0
            job.message = "Running..."
            await _sync_processing_task(job)
            logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
            try:
                result = await handler(job)
                # A job cancelled mid-run keeps the state set by cancel().
                if not job.is_cancelled:
                    job.status = JobStatus.COMPLETED
                    job.progress = 100.0
                    job.result = result
                    job.message = "Completed"
                    job.completed_at = time.time()
                    logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
            except Exception as e:
                if not job.is_cancelled:
                    job.status = JobStatus.FAILED
                    job.error = str(e)
                    job.message = f"Failed: {e}"
                    job.completed_at = time.time()
                    logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
            if job.is_cancelled and not job.completed_at:
                job.completed_at = time.time()
            await _sync_processing_task(job)
            # Fire completion callbacks
            for cb in self._completion_callbacks:
                try:
                    await cb(job)
                except Exception as cb_err:
                    logger.error(f"Completion callback error: {cb_err}", exc_info=True)
async def _sync_processing_task(job: Job):
    """Best-effort mirror of a job's state into processing_tasks.

    Rows are matched on ProcessingTask.job_id; any failure is logged at
    warning level and swallowed so DB trouble never breaks the workers.
    """
    from datetime import datetime, timezone
    from sqlalchemy import update
    try:
        from app.db import async_session_factory
        from app.db.models import ProcessingTask
        payload = {
            "status": job.status.value,
            "progress": float(job.progress),
            "message": job.message,
            "error": job.error,
        }
        # Only stamp timestamps the job has actually reached.
        for attr in ("started_at", "completed_at"):
            ts = getattr(job, attr)
            if ts:
                payload[attr] = datetime.fromtimestamp(ts, tz=timezone.utc)
        stmt = (
            update(ProcessingTask)
            .where(ProcessingTask.job_id == job.id)
            .values(**payload)
        )
        async with async_session_factory() as db:
            await db.execute(stmt)
            await db.commit()
    except Exception as e:
        logger.warning(f"Failed to sync processing task for job {job.id}: {e}")
# -- Singleton + job handlers --
# Process-wide queue instance shared by the handlers below; workers run
# once start() is awaited (see JobQueue.start).
job_queue = JobQueue(max_workers=5)
async def _handle_triage(job: Job):
    """Triage handler - chains HOST_PROFILE after completion.

    Runs triage for job.params["dataset_id"], then best-effort: looks up
    the dataset's hunt, submits a HOST_PROFILE job for it, and persists a
    ProcessingTask tracking row unless one already exists (submit() may
    have deduped to an already-tracked job). Chaining failures are logged
    and never fail the triage job. Returns {"dataset_id": ...}.
    """
    from app.services.triage import triage_dataset
    dataset_id = job.params.get("dataset_id")
    job.message = f"Triaging dataset {dataset_id}"
    await triage_dataset(dataset_id)
    # Chain: trigger host profiling now that triage results exist.
    # Imports are hoisted here once (previously `select` was imported twice)
    # and sessions are kept short-lived instead of nesting one inside another.
    from sqlalchemy import select
    from app.db import async_session_factory
    from app.db.models import Dataset, ProcessingTask
    try:
        # Resolve the owning hunt in its own short-lived session.
        async with async_session_factory() as db:
            ds = await db.execute(select(Dataset.hunt_id).where(Dataset.id == dataset_id))
            row = ds.first()
        hunt_id = row[0] if row else None
        if hunt_id:
            hp_job = job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
            try:
                # Second session for the tracking row; no shadowing of the
                # lookup session above.
                async with async_session_factory() as db:
                    existing = await db.execute(
                        select(ProcessingTask.id).where(ProcessingTask.job_id == hp_job.id)
                    )
                    if existing.first() is None:
                        db.add(ProcessingTask(
                            hunt_id=hunt_id,
                            dataset_id=dataset_id,
                            job_id=hp_job.id,
                            stage="host_profile",
                            status="queued",
                            progress=0.0,
                            message="Queued",
                        ))
                        await db.commit()
            except Exception as persist_err:
                logger.warning(f"Failed to persist chained HOST_PROFILE task: {persist_err}")
            logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
    except Exception as e:
        logger.warning(f"Failed to chain host profile after triage: {e}")
    return {"dataset_id": dataset_id}
async def _handle_host_profile(job: Job):
    """Profile a single host when 'hostname' is given, else every host
    in the hunt identified by 'hunt_id'."""
    from app.services.host_profiler import profile_all_hosts, profile_host
    hunt_id = job.params.get("hunt_id")
    hostname = job.params.get("hostname")
    if not hostname:
        job.message = f"Profiling all hosts in hunt {hunt_id}"
        await profile_all_hosts(hunt_id)
        return {"hunt_id": hunt_id}
    job.message = f"Profiling host {hostname}"
    await profile_host(hunt_id, hostname)
    return {"hostname": hostname}
async def _handle_report(job: Job):
    """Generate a report for the hunt; returns its id (None if not created)."""
    from app.services.report_generator import generate_report
    hunt_id = job.params.get("hunt_id")
    job.message = f"Generating report for hunt {hunt_id}"
    report = await generate_report(hunt_id)
    report_id = report.id if report else None
    return {"report_id": report_id}
async def _handle_anomaly(job: Job):
    """Run anomaly detection on a dataset; returns the anomaly count.

    Optional params: 'k' (default 3) and 'threshold' (default 0.35).
    """
    from app.services.anomaly_detector import detect_anomalies
    params = job.params
    dataset_id = params.get("dataset_id")
    job.message = f"Detecting anomalies in dataset {dataset_id}"
    results = await detect_anomalies(
        dataset_id,
        k=params.get("k", 3),
        outlier_threshold=params.get("threshold", 0.35),
    )
    count = len(results) if results else 0
    return {"count": count}
async def _handle_query(job: Job):
    """Answer a question against a dataset; mode defaults to 'quick'."""
    from app.services.data_query import query_dataset
    params = job.params
    dataset_id = params.get("dataset_id")
    job.message = f"Querying dataset {dataset_id}"
    answer = await query_dataset(
        dataset_id,
        params.get("question", ""),
        params.get("mode", "quick"),
    )
    return {"answer": answer}
async def _handle_host_inventory(job: Job):
    """Build and cache the host inventory for a hunt.

    Raises ValueError when 'hunt_id' is missing; clears the cache's
    'building' flag again if construction fails.
    """
    from app.db import async_session_factory
    from app.services.host_inventory import build_host_inventory, inventory_cache
    hunt_id = job.params.get("hunt_id")
    if not hunt_id:
        raise ValueError("hunt_id required")
    inventory_cache.set_building(hunt_id)
    job.message = f"Building host inventory for hunt {hunt_id}"
    try:
        async with async_session_factory() as db:
            result = await build_host_inventory(hunt_id, db)
            inventory_cache.put(hunt_id, result)
            total = result["stats"]["total_hosts"]
            job.message = f"Built inventory: {total} hosts"
            return {"hunt_id": hunt_id, "total_hosts": total}
    except Exception:
        inventory_cache.clear_building(hunt_id)
        raise
async def _handle_keyword_scan(job: Job):
    """AUP keyword scan handler."""
    from app.db import async_session_factory
    from app.services.scanner import KeywordScanner, keyword_scan_cache
    dataset_id = job.params.get("dataset_id")
    job.message = f"Running AUP keyword scan on dataset {dataset_id}"
    async with async_session_factory() as db:
        result = await KeywordScanner(db).scan(dataset_ids=[dataset_id])
        # Cache dataset-only result for fast API reuse
        if dataset_id:
            keyword_scan_cache.put(dataset_id, result)
        hits = result.get("total_hits", 0)
        rows = result.get("rows_scanned", 0)
        job.message = f"Keyword scan complete: {hits} hits"
        logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {rows} rows")
        return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": rows}
async def _handle_ioc_extract(job: Job):
    """IOC extraction handler."""
    from app.db import async_session_factory
    from app.services.ioc_extractor import extract_iocs_from_dataset
    dataset_id = job.params.get("dataset_id")
    job.message = f"Extracting IOCs from dataset {dataset_id}"
    async with async_session_factory() as db:
        iocs = await extract_iocs_from_dataset(dataset_id, db)
        # Per-category counts double as both the total and the breakdown.
        breakdown = {category: len(items) for category, items in iocs.items()}
        total = sum(breakdown.values())
        job.message = f"IOC extraction complete: {total} IOCs found"
        logger.info(f"IOC extract for {dataset_id}: {total} IOCs")
        return {"dataset_id": dataset_id, "total_iocs": total, "breakdown": breakdown}
async def _on_pipeline_job_complete(job: Job):
    """Completion callback: finalize Dataset.processing_status once every
    pipeline job for the dataset has reached a terminal state.

    Status becomes 'completed_with_errors' if any sibling FAILED,
    otherwise 'completed'. DB errors are logged, never raised.
    """
    if job.job_type not in PIPELINE_JOB_TYPES:
        return
    dataset_id = job.params.get("dataset_id")
    if not dataset_id:
        return
    pipeline_jobs = job_queue.find_pipeline_jobs(dataset_id)
    if not pipeline_jobs:
        return
    terminal = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
    if any(j.status not in terminal for j in pipeline_jobs):
        return  # a sibling job is still queued/running
    had_failure = any(j.status == JobStatus.FAILED for j in pipeline_jobs)
    new_status = "completed_with_errors" if had_failure else "completed"
    try:
        from sqlalchemy import update
        from app.db import async_session_factory
        from app.db.models import Dataset
        stmt = (
            update(Dataset)
            .where(Dataset.id == dataset_id)
            .values(processing_status=new_status)
        )
        async with async_session_factory() as db:
            await db.execute(stmt)
            await db.commit()
        logger.info(f"Dataset {dataset_id} processing_status -> {new_status}")
    except Exception as e:
        logger.error(f"Failed to update processing_status for {dataset_id}: {e}")
async def reconcile_stale_processing_tasks() -> int:
    """Mark queued/running processing tasks from prior runs as failed.

    Intended for startup: any row still 'queued' or 'running' must be
    left over from a previous process. Returns the number of rows
    updated; 0 on any error (errors are logged, not raised).
    """
    from datetime import datetime, timezone
    from sqlalchemy import update
    try:
        from app.db import async_session_factory
        from app.db.models import ProcessingTask
        now = datetime.now(timezone.utc)
        stmt = (
            update(ProcessingTask)
            .where(ProcessingTask.status.in_(["queued", "running"]))
            .values(
                status="failed",
                error="Recovered after service restart before task completion",
                message="Recovered stale task after restart",
                completed_at=now,
            )
        )
        async with async_session_factory() as db:
            result = await db.execute(stmt)
            await db.commit()
        updated = int(result.rowcount or 0)
        if updated:
            logger.warning(
                "Reconciled %d stale processing tasks (queued/running -> failed) during startup",
                updated,
            )
        return updated
    except Exception as e:
        logger.warning(f"Failed to reconcile stale processing tasks: {e}")
        return 0
def register_all_handlers():
    """Register all job handlers and completion callbacks."""
    # Dispatch table keeps each JobType paired with its handler in one place;
    # insertion order matches the previous explicit registration order.
    handlers = {
        JobType.TRIAGE: _handle_triage,
        JobType.HOST_PROFILE: _handle_host_profile,
        JobType.REPORT: _handle_report,
        JobType.ANOMALY: _handle_anomaly,
        JobType.QUERY: _handle_query,
        JobType.HOST_INVENTORY: _handle_host_inventory,
        JobType.KEYWORD_SCAN: _handle_keyword_scan,
        JobType.IOC_EXTRACT: _handle_ioc_extract,
    }
    for job_type, handler in handlers.items():
        job_queue.register_handler(job_type, handler)
    job_queue.on_completion(_on_pipeline_job_complete)