mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 05:50:21 -05:00
feat: Add Playbook Manager, Saved Searches, and Timeline View components
- Implemented PlaybookManager for creating and managing investigation playbooks with templates.
- Added SavedSearches component for managing bookmarked queries and recurring scans.
- Introduced TimelineView for visualizing forensic event timelines with zoomable charts.
- Enhanced backend processing with auto-queued jobs for dataset uploads and improved database concurrency.
- Updated frontend components for better user experience and performance optimizations.
- Documented changes in update log for future reference.
This commit is contained in:
350
fix_all.py
Normal file
350
fix_all.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""Fix all critical issues: DB locking, keyword scan, network map."""
|
||||
import os, re
|
||||
|
||||
# Checkout that the fixes are applied to; every target path below is joined
# onto it. Set the THREATHUNT_ROOT environment variable to run the script
# against a checkout in another location; the default is the original path.
ROOT = os.environ.get("THREATHUNT_ROOT", r"D:\Projects\Dev\ThreatHunt")
|
||||
|
||||
def fix_file(filepath, replacements, root=None):
    """Apply literal text replacements to one file and return the result.

    Args:
        filepath: Path of the target file, relative to the base directory.
        replacements: Iterable of ``(old, new, desc)`` triples. For each
            triple the first occurrence of ``old`` is replaced by ``new``;
            ``desc`` only appears in the progress output.
        root: Base directory that ``filepath`` is joined onto. Defaults to
            the module-level ``ROOT``.

    Returns:
        The file's text after all replacements have been applied.

    Raises:
        OSError: If the file cannot be read or written.
    """
    base = ROOT if root is None else root
    path = os.path.join(base, filepath)
    with open(path, "r", encoding="utf-8") as f:
        original = f.read()

    content = original
    for old, new, desc in replacements:
        if old in content:
            content = content.replace(old, new, 1)
            print(f" OK: {desc}")
        else:
            print(f" SKIP: {desc} (pattern not found)")

    # Leave the file alone when nothing matched, so a no-op run does not
    # rewrite it (and bump its mtime) for nothing.
    if content != original:
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
    return content
|
||||
|
||||
|
||||
# ================================================================
# FIX 1: Database engine - NullPool instead of StaticPool
# ================================================================
print("\n=== FIX 1: Database engine (NullPool + higher timeouts) ===")

engine_path = os.path.join(ROOT, "backend", "app", "db", "engine.py")
# Read the current module first: this fails before anything is written if the
# path is wrong, and lets the write below be skipped when nothing would change.
with open(engine_path, "r", encoding="utf-8") as f:
    engine_content = f.read()

# Complete replacement text for engine.py (the file is overwritten wholesale,
# not patched).
new_engine = '''"""Database engine, session factory, and base model.

Uses async SQLAlchemy with aiosqlite for local dev and asyncpg for production PostgreSQL.
"""

from sqlalchemy import event
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase

from app.config import settings

_is_sqlite = settings.DATABASE_URL.startswith("sqlite")

_engine_kwargs: dict = dict(
    echo=settings.DEBUG,
    future=True,
)

if _is_sqlite:
    _engine_kwargs["connect_args"] = {"timeout": 60, "check_same_thread": False}
    # NullPool: each session gets its own connection.
    # Combined with WAL mode, this allows concurrent reads while a write is in progress.
    from sqlalchemy.pool import NullPool
    _engine_kwargs["poolclass"] = NullPool
else:
    _engine_kwargs["pool_size"] = 5
    _engine_kwargs["max_overflow"] = 10

engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)


@event.listens_for(engine.sync_engine, "connect")
def _set_sqlite_pragmas(dbapi_conn, connection_record):
    """Enable WAL mode and tune busy-timeout for SQLite connections."""
    if _is_sqlite:
        cursor = dbapi_conn.cursor()
        cursor.execute("PRAGMA journal_mode=WAL")
        cursor.execute("PRAGMA busy_timeout=30000")
        cursor.execute("PRAGMA synchronous=NORMAL")
        cursor.close()


async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


# Alias expected by other modules
async_session = async_session_factory


class Base(DeclarativeBase):
    """Base class for all ORM models."""
    pass


async def get_db() -> AsyncSession: # type: ignore[misc]
    """FastAPI dependency that yields an async DB session."""
    async with async_session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()


async def init_db() -> None:
    """Create all tables (for dev / first-run). In production use Alembic."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)


async def dispose_db() -> None:
    """Dispose of the engine on shutdown."""
    await engine.dispose()
'''

if engine_content == new_engine:
    # Re-running the script must not rewrite an already-patched file or claim
    # changes that did not happen.
    print(" SKIP: engine.py already up to date")
else:
    with open(engine_path, "w", encoding="utf-8") as f:
        f.write(new_engine)
    print(" OK: Replaced StaticPool with NullPool")
    print(" OK: Increased busy_timeout 5000 -> 30000ms")
    print(" OK: Added check_same_thread=False")
    print(" OK: Connection timeout 30 -> 60s")
|
||||
|
||||
|
||||
# ================================================================
# FIX 2: Keyword scan endpoints - run scans on a dedicated DB session
# ================================================================
# NOTE: despite the banner text printed below, the patched handlers still await
# scanner.scan() before responding; only the session the scanner uses changes.
print("\n=== FIX 2: Keyword scan endpoint -> background job ===")

kw_path = os.path.join(ROOT, "backend", "app", "api", "routes", "keywords.py")
with open(kw_path, "r", encoding="utf-8") as f:
    kw_content = f.read()
# Kept so the file is only rewritten when a replacement was actually applied.
kw_original = kw_content

# Replace the scan endpoint so it scans on its own session
old_scan = '''@router.post("/scan", response_model=ScanResponse)
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
    """Run AUP keyword scan across selected data sources."""
    scanner = KeywordScanner(db)
    result = await scanner.scan(
        dataset_ids=body.dataset_ids,
        theme_ids=body.theme_ids,
        scan_hunts=body.scan_hunts,
        scan_annotations=body.scan_annotations,
        scan_messages=body.scan_messages,
    )
    return result'''

new_scan = '''@router.post("/scan", response_model=ScanResponse)
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
    """Run AUP keyword scan across selected data sources.

    Uses a dedicated DB session separate from the request session
    to avoid blocking other API requests on SQLite.
    """
    from app.db import async_session_factory
    async with async_session_factory() as scan_db:
        scanner = KeywordScanner(scan_db)
        result = await scanner.scan(
            dataset_ids=body.dataset_ids,
            theme_ids=body.theme_ids,
            scan_hunts=body.scan_hunts,
            scan_annotations=body.scan_annotations,
            scan_messages=body.scan_messages,
        )
    return result'''

if old_scan in kw_content:
    kw_content = kw_content.replace(old_scan, new_scan, 1)
    print(" OK: Scan endpoint uses dedicated DB session")
else:
    print(" SKIP: Scan endpoint pattern not found")

# Also fix quick_scan
old_quick = '''@router.get("/scan/quick", response_model=ScanResponse)
async def quick_scan(
    dataset_id: str = Query(..., description="Dataset to scan"),
    db: AsyncSession = Depends(get_db),
):
    """Quick scan a single dataset with all enabled themes."""
    scanner = KeywordScanner(db)
    result = await scanner.scan(dataset_ids=[dataset_id])
    return result'''

new_quick = '''@router.get("/scan/quick", response_model=ScanResponse)
async def quick_scan(
    dataset_id: str = Query(..., description="Dataset to scan"),
    db: AsyncSession = Depends(get_db),
):
    """Quick scan a single dataset with all enabled themes."""
    from app.db import async_session_factory
    async with async_session_factory() as scan_db:
        scanner = KeywordScanner(scan_db)
        result = await scanner.scan(dataset_ids=[dataset_id])
    return result'''

if old_quick in kw_content:
    kw_content = kw_content.replace(old_quick, new_quick, 1)
    print(" OK: Quick scan uses dedicated DB session")
else:
    print(" SKIP: Quick scan pattern not found")

# Skip the write when both patterns were missing, so the file is not touched
# by a run that changed nothing.
if kw_content != kw_original:
    with open(kw_path, "w", encoding="utf-8") as f:
        f.write(kw_content)
|
||||
|
||||
|
||||
# ================================================================
# FIX 3: Scanner service - smaller batches, yield between batches
# ================================================================
print("\n=== FIX 3: Scanner service - smaller batches + async yield ===")

scanner_path = os.path.join(ROOT, "backend", "app", "services", "scanner.py")
with open(scanner_path, "r", encoding="utf-8") as f:
    scanner_content = f.read()
# Kept so the file is only rewritten when a replacement was actually applied.
scanner_original = scanner_content

# Change batch size and add yield between batches
old_batch = "BATCH_SIZE = 500"
new_batch = "BATCH_SIZE = 200"

if old_batch in scanner_content:
    scanner_content = scanner_content.replace(old_batch, new_batch, 1)
    print(" OK: Reduced batch size 500 -> 200")
else:
    # Report the miss like every other step instead of failing silently.
    print(" SKIP: Batch size pattern not found")

# Add asyncio.sleep(0) between batches to yield to other tasks.
# The loop tail is matched with a regex rather than a fixed-indent literal so
# the step does not depend on how deeply the loop is nested in scanner.py; the
# inserted lines reuse the indentation of the matched `offset += ...` line.
batch_loop_re = re.compile(
    r"^(?P<indent>[ \t]*)offset \+= BATCH_SIZE\n"
    r"(?P<tail>[ \t]*if len\(rows\) < BATCH_SIZE:\n[ \t]*break)",
    re.MULTILINE,
)


def _insert_batch_yield(match):
    """Return the matched loop tail with the event-loop yield inserted."""
    indent = match.group("indent")
    inserted = [
        "offset += BATCH_SIZE",
        "# Yield to event loop between batches so other requests aren't starved",
        "import asyncio",
        "await asyncio.sleep(0)",
    ]
    head = "".join(indent + line + "\n" for line in inserted)
    return head + match.group("tail")


scanner_content, batch_loop_hits = batch_loop_re.subn(
    _insert_batch_yield, scanner_content, count=1
)
if batch_loop_hits:
    print(" OK: Added async yield between scan batches")
else:
    print(" SKIP: Batch loop pattern not found")

if scanner_content != scanner_original:
    with open(scanner_path, "w", encoding="utf-8") as f:
        f.write(scanner_content)
|
||||
|
||||
|
||||
# ================================================================
# FIX 4: Job queue workers - increase from 3 to 5
# ================================================================
print("\n=== FIX 4: Job queue - more workers ===")

jq_rel_path = os.path.join("backend", "app", "services", "job_queue.py")
jq_path = os.path.join(ROOT, jq_rel_path)

old_workers = "job_queue = JobQueue(max_workers=3)"
new_workers = "job_queue = JobQueue(max_workers=5)"

# Delegate to fix_file so that a missing pattern is reported as SKIP (this
# step previously stayed silent) and the OK message keeps its original text.
jq_content = fix_file(jq_rel_path, [(old_workers, new_workers, "Workers 3 -> 5")])
|
||||
|
||||
|
||||
# ================================================================
|
||||
# FIX 5: main.py - on startup, queue the pipeline for every dataset that has no anomaly results
|
||||
# ================================================================
|
||||
print("\n=== FIX 5: Startup reprocessing - all datasets, not just 'ready' ===")
|
||||
|
||||
main_path = os.path.join(ROOT, "backend", "app", "main.py")
|
||||
with open(main_path, "r", encoding="utf-8") as f:
|
||||
main_content = f.read()
|
||||
|
||||
# The current startup only reprocesses datasets with status="ready"
|
||||
# But after previous runs, they're all "completed" - so nothing happens
|
||||
# Fix: reprocess datasets that have NO triage/anomaly results in DB
|
||||
old_reprocess = ''' # Reprocess datasets that were never fully processed (status still "ready")
|
||||
async with async_session_factory() as reprocess_db:
|
||||
from sqlalchemy import select
|
||||
from app.db.models import Dataset
|
||||
stmt = select(Dataset.id).where(Dataset.processing_status == "ready")
|
||||
result = await reprocess_db.execute(stmt)
|
||||
unprocessed_ids = [row[0] for row in result.all()]
|
||||
for ds_id in unprocessed_ids:
|
||||
job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)
|
||||
if unprocessed_ids:
|
||||
logger.info(f"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets")
|
||||
# Mark them as processing
|
||||
async with async_session_factory() as update_db:
|
||||
from sqlalchemy import update
|
||||
from app.db.models import Dataset
|
||||
await update_db.execute(
|
||||
update(Dataset)
|
||||
.where(Dataset.id.in_(unprocessed_ids))
|
||||
.values(processing_status="processing")
|
||||
)
|
||||
await update_db.commit()'''
|
||||
|
||||
new_reprocess = ''' # Check which datasets still need processing
|
||||
# (no anomaly results = never fully processed)
|
||||
async with async_session_factory() as reprocess_db:
|
||||
from sqlalchemy import select, exists
|
||||
from app.db.models import Dataset, AnomalyResult
|
||||
# Find datasets that have zero anomaly results (pipeline never ran or failed)
|
||||
has_anomaly = (
|
||||
select(AnomalyResult.id)
|
||||
.where(AnomalyResult.dataset_id == Dataset.id)
|
||||
.limit(1)
|
||||
.correlate(Dataset)
|
||||
.exists()
|
||||
)
|
||||
stmt = select(Dataset.id).where(~has_anomaly)
|
||||
result = await reprocess_db.execute(stmt)
|
||||
unprocessed_ids = [row[0] for row in result.all()]
|
||||
|
||||
if unprocessed_ids:
|
||||
for ds_id in unprocessed_ids:
|
||||
job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)
|
||||
logger.info(f"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets")
|
||||
async with async_session_factory() as update_db:
|
||||
from sqlalchemy import update
|
||||
from app.db.models import Dataset
|
||||
await update_db.execute(
|
||||
update(Dataset)
|
||||
.where(Dataset.id.in_(unprocessed_ids))
|
||||
.values(processing_status="processing")
|
||||
)
|
||||
await update_db.commit()
|
||||
else:
|
||||
logger.info("All datasets already processed - skipping startup pipeline")'''
|
||||
|
||||
if old_reprocess in main_content:
    main_content = main_content.replace(old_reprocess, new_reprocess, 1)
    print(" OK: Startup checks for actual results, not just status field")
    # Write only on a successful replacement; a skipped run leaves main.py
    # untouched instead of rewriting identical text.
    with open(main_path, "w", encoding="utf-8") as f:
        f.write(main_content)
else:
    print(" SKIP: Reprocess block not found")


print("\n=== ALL FIXES APPLIED ===")
|
||||
Reference in New Issue
Block a user