4 Commits

Author SHA1 Message Date
483176c06b chore: checkpoint all local changes 2026-02-23 14:36:33 -05:00
13bd9ec9e0 feat: Add Playbook Manager, Saved Searches, and Timeline View components
- Implemented PlaybookManager for creating and managing investigation playbooks with templates.
- Added SavedSearches component for managing bookmarked queries and recurring scans.
- Introduced TimelineView for visualizing forensic event timelines with zoomable charts.
- Enhanced backend processing with auto-queued jobs for dataset uploads and improved database concurrency.
- Updated frontend components for better user experience and performance optimizations.
- Documented changes in update log for future reference.
2026-02-23 14:35:49 -05:00
5a2ad8ec1c feat: Add Playbook Manager, Saved Searches, and Timeline View components
- Implemented PlaybookManager for creating and managing investigation playbooks with templates.
- Added SavedSearches component for managing bookmarked queries and recurring scans.
- Introduced TimelineView for visualizing forensic event timelines with zoomable charts.
- Enhanced backend processing with auto-queued jobs for dataset uploads and improved database concurrency.
- Updated frontend components for better user experience and performance optimizations.
- Documented changes in update log for future reference.
2026-02-23 14:23:07 -05:00
37a9584d0c docs: update changelog and add robust dev-up startup script 2026-02-23 14:22:17 -05:00
114 changed files with 11331 additions and 905 deletions

26
.gitignore vendored
View File

@@ -1,4 +1,4 @@
# ── Python ──────────────────────────────────── # ── Python ────────────────────────────────────
__pycache__/ __pycache__/
*.py[cod] *.py[cod]
*$py.class *$py.class
@@ -8,34 +8,34 @@ build/
*.egg *.egg
.eggs/ .eggs/
# ── Virtual environments ───────────────────── # ── Virtual environments ─────────────────────
venv/ venv/
.venv/ .venv/
env/ env/
# ── IDE / Editor ───────────────────────────── # ── IDE / Editor ─────────────────────────────
.vscode/ .vscode/
.idea/ .idea/
*.swp *.swp
*.swo *.swo
*~ *~
# ── OS ──────────────────────────────────────── # ── OS ────────────────────────────────────────
.DS_Store .DS_Store
Thumbs.db Thumbs.db
# ── Environment / Secrets ──────────────────── # ── Environment / Secrets ────────────────────
.env .env
*.env.local *.env.local
# ── Database ───────────────────────────────── # ── Database ─────────────────────────────────
*.db *.db
*.sqlite3 *.sqlite3
# ── Uploads ────────────────────────────────── # ── Uploads ──────────────────────────────────
uploads/ uploads/
# ── Node / Frontend ────────────────────────── # ── Node / Frontend ──────────────────────────
node_modules/ node_modules/
frontend/build/ frontend/build/
frontend/.env.local frontend/.env.local
@@ -43,14 +43,18 @@ npm-debug.log*
yarn-debug.log* yarn-debug.log*
yarn-error.log* yarn-error.log*
# ── Docker ─────────────────────────────────── # ── Docker ───────────────────────────────────
docker-compose.override.yml docker-compose.override.yml
# ── Test / Coverage ────────────────────────── # ── Test / Coverage ──────────────────────────
.coverage .coverage
htmlcov/ htmlcov/
.pytest_cache/ .pytest_cache/
.mypy_cache/ .mypy_cache/
# ── Alembic ────────────────────────────────── # ── Alembic ──────────────────────────────────
alembic/versions/*.pyc alembic/versions/*.pyc
*.db-wal
*.db-shm

View File

@@ -17,7 +17,7 @@ COPY frontend/tsconfig.json ./
# Build application # Build application
RUN npm run build RUN npm run build
# Production stage nginx reverse-proxy + static files # Production stage — nginx reverse-proxy + static files
FROM nginx:alpine FROM nginx:alpine
# Copy built React app # Copy built React app

View File

@@ -0,0 +1,148 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
t=p.read_text(encoding='utf-8')
# 1) Add label mode type near graph types
marker="interface GEdge { source: string; target: string; weight: number }\ninterface Graph { nodes: GNode[]; edges: GEdge[] }\n"
if marker in t and "type LabelMode" not in t:
t=t.replace(marker, marker+"\ntype LabelMode = 'all' | 'highlight' | 'none';\n")
# 2) extend drawLabels signature
old_sig="""function drawLabels(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
search: string, matchSet: Set<string>, vp: Viewport,
simplify: boolean,
) {
"""
new_sig="""function drawLabels(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
search: string, matchSet: Set<string>, vp: Viewport,
simplify: boolean, labelMode: LabelMode,
) {
"""
if old_sig in t:
t=t.replace(old_sig,new_sig)
# 3) label mode guards inside drawLabels
old_guard=""" const dimmed = search.length > 0;
if (simplify && !search && !hovered && !selected) {
return;
}
"""
new_guard=""" if (labelMode === 'none') return;
const dimmed = search.length > 0;
if (labelMode === 'highlight' && !search && !hovered && !selected) return;
if (simplify && labelMode !== 'all' && !search && !hovered && !selected) {
return;
}
"""
if old_guard in t:
t=t.replace(old_guard,new_guard)
old_show=""" const isHighlight = hovered === n.id || selected === n.id || matchSet.has(n.id);
const show = isHighlight || n.meta.type === 'host' || n.count >= 2;
if (!show) continue;
"""
new_show=""" const isHighlight = hovered === n.id || selected === n.id || matchSet.has(n.id);
const show = labelMode === 'all'
? (isHighlight || n.meta.type === 'host' || n.count >= 2)
: isHighlight;
if (!show) continue;
"""
if old_show in t:
t=t.replace(old_show,new_show)
# 4) drawGraph signature and call site
old_graph_sig="""function drawGraph(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null, search: string,
vp: Viewport, animTime: number, dpr: number,
) {
"""
new_graph_sig="""function drawGraph(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null, search: string,
vp: Viewport, animTime: number, dpr: number, labelMode: LabelMode,
) {
"""
if old_graph_sig in t:
t=t.replace(old_graph_sig,new_graph_sig)
old_drawlabels_call="drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify);"
new_drawlabels_call="drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify, labelMode);"
if old_drawlabels_call in t:
t=t.replace(old_drawlabels_call,new_drawlabels_call)
# 5) state for label mode
state_anchor=" const [selectedNode, setSelectedNode] = useState<GNode | null>(null);\n const [search, setSearch] = useState('');\n"
state_new=" const [selectedNode, setSelectedNode] = useState<GNode | null>(null);\n const [search, setSearch] = useState('');\n const [labelMode, setLabelMode] = useState<LabelMode>('highlight');\n"
if state_anchor in t:
t=t.replace(state_anchor,state_new)
# 6) pass labelMode in draw calls
old_tick_draw="drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? null, searchRef.current, vpRef.current, ts, dpr);"
new_tick_draw="drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? null, searchRef.current, vpRef.current, ts, dpr, labelMode);"
if old_tick_draw in t:
t=t.replace(old_tick_draw,new_tick_draw)
old_redraw_draw="if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? null, search, vpRef.current, animTimeRef.current, dpr);"
new_redraw_draw="if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? null, search, vpRef.current, animTimeRef.current, dpr, labelMode);"
if old_redraw_draw in t:
t=t.replace(old_redraw_draw,new_redraw_draw)
# 7) include labelMode in redraw deps
old_redraw_dep="] , [graph, hovered, selectedNode, search]);"
if old_redraw_dep in t:
t=t.replace(old_redraw_dep, "] , [graph, hovered, selectedNode, search, labelMode]);")
else:
t=t.replace(" }, [graph, hovered, selectedNode, search]);"," }, [graph, hovered, selectedNode, search, labelMode]);")
# 8) Add toolbar selector after search field
search_block=""" <TextField
size="small"
placeholder="Search hosts, IPs, users\u2026"
value={search}
onChange={e => setSearch(e.target.value)}
sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }}
slotProps={{
input: {
startAdornment: <SearchIcon sx={{ mr: 0.5, fontSize: 18, color: 'text.secondary' }} />,
},
}}
/>
"""
label_block=""" <TextField
size="small"
placeholder="Search hosts, IPs, users\u2026"
value={search}
onChange={e => setSearch(e.target.value)}
sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }}
slotProps={{
input: {
startAdornment: <SearchIcon sx={{ mr: 0.5, fontSize: 18, color: 'text.secondary' }} />,
},
}}
/>
<FormControl size="small" sx={{ minWidth: 140 }}>
<InputLabel id="label-mode-selector">Labels</InputLabel>
<Select
labelId="label-mode-selector"
value={labelMode}
label="Labels"
onChange={e => setLabelMode(e.target.value as LabelMode)}
sx={{ '& .MuiSelect-select': { py: 0.8 } }}
>
<MenuItem value="none">None</MenuItem>
<MenuItem value="highlight">Selected/Search</MenuItem>
<MenuItem value="all">All</MenuItem>
</Select>
</FormControl>
"""
if search_block in t:
t=t.replace(search_block,label_block)
p.write_text(t,encoding='utf-8')
print('added network map label filter control and renderer modes')

View File

@@ -0,0 +1,18 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
t=p.read_text(encoding='utf-8')
old=''' # -- Scanner settings -----------------------------------------------
SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch")
'''
new=''' # -- Scanner settings -----------------------------------------------
SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch")
SCANNER_MAX_ROWS_PER_SCAN: int = Field(
default=300000,
description="Global row budget for a single AUP scan request (0 = unlimited)",
)
'''
if old not in t:
raise SystemExit('scanner settings block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('added SCANNER_MAX_ROWS_PER_SCAN config')

View File

@@ -0,0 +1,46 @@
from pathlib import Path
root = Path(r"d:\Projects\Dev\ThreatHunt")
# -------- client.ts --------
client = root / "frontend/src/api/client.ts"
text = client.read_text(encoding="utf-8")
if "export interface NetworkSummary" not in text:
insert_after = "export interface InventoryStatus {\n hunt_id: string;\n status: 'ready' | 'building' | 'none';\n}\n"
addition = insert_after + "\nexport interface NetworkSummaryHost {\n id: string;\n hostname: string;\n row_count: number;\n ip_count: number;\n user_count: number;\n}\n\nexport interface NetworkSummary {\n stats: InventoryStats;\n top_hosts: NetworkSummaryHost[];\n top_edges: InventoryConnection[];\n status?: 'building' | 'deferred';\n message?: string;\n}\n"
text = text.replace(insert_after, addition)
net_old = """export const network = {\n hostInventory: (huntId: string, force = false) =>\n api<HostInventory>(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`),\n inventoryStatus: (huntId: string) =>\n api<InventoryStatus>(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`),\n rebuildInventory: (huntId: string) =>\n api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }),\n};"""
net_new = """export const network = {\n hostInventory: (huntId: string, force = false) =>\n api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`),\n summary: (huntId: string, topN = 20) =>\n api<NetworkSummary | { status: 'building' | 'deferred'; message?: string }>(`/api/network/summary?hunt_id=${encodeURIComponent(huntId)}&top_n=${topN}`),\n subgraph: (huntId: string, maxHosts = 250, maxEdges = 1500, nodeId?: string) => {\n let qs = `/api/network/subgraph?hunt_id=${encodeURIComponent(huntId)}&max_hosts=${maxHosts}&max_edges=${maxEdges}`;\n if (nodeId) qs += `&node_id=${encodeURIComponent(nodeId)}`;\n return api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(qs);\n },\n inventoryStatus: (huntId: string) =>\n api<InventoryStatus>(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`),\n rebuildInventory: (huntId: string) =>\n api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }),\n};"""
if net_old in text:
text = text.replace(net_old, net_new)
client.write_text(text, encoding="utf-8")
# -------- NetworkMap.tsx --------
nm = root / "frontend/src/components/NetworkMap.tsx"
text = nm.read_text(encoding="utf-8")
# add constants
if "LARGE_HUNT_HOST_THRESHOLD" not in text:
text = text.replace("let lastSelectedHuntId = '';\n", "let lastSelectedHuntId = '';\nconst LARGE_HUNT_HOST_THRESHOLD = 400;\nconst LARGE_HUNT_SUBGRAPH_HOSTS = 350;\nconst LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n")
# inject helper in component after sleep
marker = " const sleep = (ms: number) => new Promise<void>(resolve => setTimeout(resolve, ms));\n"
if "loadScaleAwareGraph" not in text:
helper = marker + "\n const loadScaleAwareGraph = useCallback(async (huntId: string, forceRefresh = false) => {\n setLoading(true); setError(''); setGraph(null); setStats(null);\n setSelectedNode(null); setPopoverAnchor(null);\n\n const waitReadyThen = async <T,>(fn: () => Promise<T>): Promise<T> => {\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const out: any = await fn();\n if (out && !out.status) return out as T;\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') {\n const out2: any = await fn();\n if (out2 && !out2.status) return out2 as T;\n }\n if (Date.now() - startedAt > 5 * 60 * 1000) throw new Error('Network data build timed out after 5 minutes');\n const jitter = Math.floor(Math.random() * 250);\n await sleep(delayMs + jitter);\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n };\n\n try {\n setProgress('Loading network summary');\n const summary: any = await waitReadyThen(() => network.summary(huntId, 20));\n const totalHosts = summary?.stats?.total_hosts || 0;\n\n if (totalHosts > LARGE_HUNT_HOST_THRESHOLD) {\n setProgress(`Large hunt detected (${totalHosts} hosts). Loading focused subgraph`);\n const sub: any = await waitReadyThen(() => network.subgraph(huntId, LARGE_HUNT_SUBGRAPH_HOSTS, LARGE_HUNT_SUBGRAPH_EDGES));\n if (!sub?.hosts || sub.hosts.length === 0) {\n setError('No hosts found for subgraph.');\n return;\n }\n const { w, h } = canvasSizeRef.current;\n const g = buildGraphFromInventory(sub.hosts, sub.connections || [], w, h);\n simulate(g, w / 2, h / 2, 60);\n simAlphaRef.current = 0.3;\n setStats(summary.stats);\n graphCache.set(huntId, { graph: g, stats: summary.stats, ts: Date.now() });\n setGraph(g);\n return;\n }\n\n // Small/medium hunts: load full inventory\n setProgress('Loading host inventory');\n const inv: any = await waitReadyThen(() => network.hostInventory(huntId, forceRefresh));\n if (!inv?.hosts || inv.hosts.length === 0) {\n setError('No hosts found. Upload CSV files with host-identifying columns (ClientId, Fqdn, Hostname) to this hunt.');\n return;\n }\n const { w, h } = canvasSizeRef.current;\n const g = buildGraphFromInventory(inv.hosts, inv.connections || [], w, h);\n simulate(g, w / 2, h / 2, 60);\n simAlphaRef.current = 0.3;\n setStats(summary.stats || inv.stats);\n graphCache.set(huntId, { graph: g, stats: summary.stats || inv.stats, ts: Date.now() });\n setGraph(g);\n } catch (e: any) {\n console.error('[NetworkMap] scale-aware load error:', e);\n setError(e.message || 'Failed to load network data');\n } finally {\n setLoading(false);\n setProgress('');\n }\n }, []);\n"
text = text.replace(marker, helper)
# simplify existing loadGraph function body to delegate
pattern_start = text.find(" // Load host inventory for selected hunt (with cache).")
if pattern_start != -1:
# replace the whole loadGraph useCallback block by simple delegator
import re
block_re = re.compile(r" // Load host inventory for selected hunt \(with cache\)\.[\s\S]*?\n \}, \[\]\); // Stable - reads canvasSizeRef, no state deps\n", re.M)
repl = " // Load graph data for selected hunt (delegates to scale-aware loader).\n const loadGraph = useCallback(async (huntId: string, forceRefresh = false) => {\n if (!huntId) return;\n\n // Check module-level cache first (5 min TTL)\n if (!forceRefresh) {\n const cached = graphCache.get(huntId);\n if (cached && Date.now() - cached.ts < 5 * 60 * 1000) {\n setGraph(cached.graph);\n setStats(cached.stats);\n setError('');\n simAlphaRef.current = 0;\n return;\n }\n }\n\n await loadScaleAwareGraph(huntId, forceRefresh);\n // eslint-disable-next-line react-hooks/exhaustive-deps\n }, []); // Stable - reads canvasSizeRef, no state deps\n"
text = block_re.sub(repl, text, count=1)
nm.write_text(text, encoding="utf-8")
print("Patched frontend client + NetworkMap for scale-aware loading")

206
_apply_phase1_patch.py Normal file
View File

@@ -0,0 +1,206 @@
from pathlib import Path
root = Path(r"d:\Projects\Dev\ThreatHunt")
# 1) config.py additions
cfg = root / "backend/app/config.py"
text = cfg.read_text(encoding="utf-8")
needle = " # -- Scanner settings -----------------------------------------------\n SCANNER_BATCH_SIZE: int = Field(default=500, description=\"Rows per scanner batch\")\n"
insert = " # -- Scanner settings -----------------------------------------------\n SCANNER_BATCH_SIZE: int = Field(default=500, description=\"Rows per scanner batch\")\n\n # -- Job queue settings ----------------------------------------------\n JOB_QUEUE_MAX_BACKLOG: int = Field(\n default=2000, description=\"Soft cap for queued background jobs\"\n )\n JOB_QUEUE_RETAIN_COMPLETED: int = Field(\n default=3000, description=\"Maximum completed/failed jobs to retain in memory\"\n )\n JOB_QUEUE_CLEANUP_INTERVAL_SECONDS: int = Field(\n default=60, description=\"How often to run in-memory job cleanup\"\n )\n JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(\n default=3600, description=\"Age threshold for in-memory completed job cleanup\"\n )\n"
if needle in text:
text = text.replace(needle, insert)
cfg.write_text(text, encoding="utf-8")
# 2) scanner.py default scope = dataset-only
scanner = root / "backend/app/services/scanner.py"
text = scanner.read_text(encoding="utf-8")
text = text.replace(" scan_hunts: bool = True,", " scan_hunts: bool = False,")
text = text.replace(" scan_annotations: bool = True,", " scan_annotations: bool = False,")
text = text.replace(" scan_messages: bool = True,", " scan_messages: bool = False,")
scanner.write_text(text, encoding="utf-8")
# 3) keywords.py defaults = dataset-only
kw = root / "backend/app/api/routes/keywords.py"
text = kw.read_text(encoding="utf-8")
text = text.replace(" scan_hunts: bool = True", " scan_hunts: bool = False")
text = text.replace(" scan_annotations: bool = True", " scan_annotations: bool = False")
text = text.replace(" scan_messages: bool = True", " scan_messages: bool = False")
kw.write_text(text, encoding="utf-8")
# 4) job_queue.py dedupe + periodic cleanup
jq = root / "backend/app/services/job_queue.py"
text = jq.read_text(encoding="utf-8")
text = text.replace(
"from typing import Any, Callable, Coroutine, Optional\n",
"from typing import Any, Callable, Coroutine, Optional\n\nfrom app.config import settings\n"
)
text = text.replace(
" self._completion_callbacks: list[Callable[[Job], Coroutine]] = []\n",
" self._completion_callbacks: list[Callable[[Job], Coroutine]] = []\n self._cleanup_task: asyncio.Task | None = None\n"
)
start_old = ''' async def start(self):
if self._started:
return
self._started = True
for i in range(self._max_workers):
task = asyncio.create_task(self._worker(i))
self._workers.append(task)
logger.info(f"Job queue started with {self._max_workers} workers")
'''
start_new = ''' async def start(self):
if self._started:
return
self._started = True
for i in range(self._max_workers):
task = asyncio.create_task(self._worker(i))
self._workers.append(task)
if not self._cleanup_task or self._cleanup_task.done():
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info(f"Job queue started with {self._max_workers} workers")
'''
text = text.replace(start_old, start_new)
stop_old = ''' async def stop(self):
self._started = False
for w in self._workers:
w.cancel()
await asyncio.gather(*self._workers, return_exceptions=True)
self._workers.clear()
logger.info("Job queue stopped")
'''
stop_new = ''' async def stop(self):
self._started = False
for w in self._workers:
w.cancel()
await asyncio.gather(*self._workers, return_exceptions=True)
self._workers.clear()
if self._cleanup_task:
self._cleanup_task.cancel()
await asyncio.gather(self._cleanup_task, return_exceptions=True)
self._cleanup_task = None
logger.info("Job queue stopped")
'''
text = text.replace(stop_old, stop_new)
submit_old = ''' def submit(self, job_type: JobType, **params) -> Job:
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
self._jobs[job.id] = job
self._queue.put_nowait(job.id)
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
return job
'''
submit_new = ''' def submit(self, job_type: JobType, **params) -> Job:
# Soft backpressure: prefer dedupe over queue amplification
dedupe_job = self._find_active_duplicate(job_type, params)
if dedupe_job is not None:
logger.info(
f"Job deduped: reusing {dedupe_job.id} ({job_type.value}) params={params}"
)
return dedupe_job
if self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG:
logger.warning(
"Job queue backlog high (%d >= %d). Accepting job but system may be degraded.",
self._queue.qsize(), settings.JOB_QUEUE_MAX_BACKLOG,
)
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
self._jobs[job.id] = job
self._queue.put_nowait(job.id)
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
return job
'''
text = text.replace(submit_old, submit_new)
insert_methods_after = " def get_job(self, job_id: str) -> Job | None:\n return self._jobs.get(job_id)\n"
new_methods = ''' def get_job(self, job_id: str) -> Job | None:
return self._jobs.get(job_id)
def _find_active_duplicate(self, job_type: JobType, params: dict) -> Job | None:
"""Return queued/running job with same key workload to prevent duplicate storms."""
key_fields = ["dataset_id", "hunt_id", "hostname", "question", "mode"]
sig = tuple((k, params.get(k)) for k in key_fields if params.get(k) is not None)
if not sig:
return None
for j in self._jobs.values():
if j.job_type != job_type:
continue
if j.status not in (JobStatus.QUEUED, JobStatus.RUNNING):
continue
other_sig = tuple((k, j.params.get(k)) for k in key_fields if j.params.get(k) is not None)
if sig == other_sig:
return j
return None
'''
text = text.replace(insert_methods_after, new_methods)
cleanup_old = ''' def cleanup(self, max_age_seconds: float = 3600):
now = time.time()
to_remove = [
jid for jid, j in self._jobs.items()
if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
and (now - j.created_at) > max_age_seconds
]
for jid in to_remove:
del self._jobs[jid]
if to_remove:
logger.info(f"Cleaned up {len(to_remove)} old jobs")
'''
cleanup_new = ''' def cleanup(self, max_age_seconds: float = 3600):
now = time.time()
terminal_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
to_remove = [
jid for jid, j in self._jobs.items()
if j.status in terminal_states and (now - j.created_at) > max_age_seconds
]
# Also cap retained terminal jobs to avoid unbounded memory growth
terminal_jobs = sorted(
[j for j in self._jobs.values() if j.status in terminal_states],
key=lambda j: j.created_at,
reverse=True,
)
overflow = terminal_jobs[settings.JOB_QUEUE_RETAIN_COMPLETED :]
to_remove.extend([j.id for j in overflow])
removed = 0
for jid in set(to_remove):
if jid in self._jobs:
del self._jobs[jid]
removed += 1
if removed:
logger.info(f"Cleaned up {removed} old jobs")
async def _cleanup_loop(self):
interval = max(10, settings.JOB_QUEUE_CLEANUP_INTERVAL_SECONDS)
while self._started:
try:
self.cleanup(max_age_seconds=settings.JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS)
except Exception as e:
logger.warning(f"Job queue cleanup loop error: {e}")
await asyncio.sleep(interval)
'''
text = text.replace(cleanup_old, cleanup_new)
jq.write_text(text, encoding="utf-8")
# 5) NetworkMap polling backoff/jitter max wait
nm = root / "frontend/src/components/NetworkMap.tsx"
text = nm.read_text(encoding="utf-8")
text = text.replace(
" // Poll until ready, then re-fetch\n for (;;) {\n await new Promise(r => setTimeout(r, 2000));\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') break;\n }\n",
" // Poll until ready (exponential backoff), then re-fetch\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const jitter = Math.floor(Math.random() * 250);\n await new Promise(r => setTimeout(r, delayMs + jitter));\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') break;\n if (Date.now() - startedAt > 5 * 60 * 1000) {\n throw new Error('Host inventory build timed out after 5 minutes');\n }\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n"
)
text = text.replace(
" const waitUntilReady = async (): Promise<boolean> => {\n // Poll inventory-status every 2s until 'ready' (or cancelled)\n setProgress('Host inventory is being prepared in the background');\n setLoading(true);\n for (;;) {\n await new Promise(r => setTimeout(r, 2000));\n if (cancelled) return false;\n try {\n const st = await network.inventoryStatus(selectedHuntId);\n if (cancelled) return false;\n if (st.status === 'ready') return true;\n // still building or none (job may not have started yet) - keep polling\n } catch { if (cancelled) return false; }\n }\n };\n",
" const waitUntilReady = async (): Promise<boolean> => {\n // Poll inventory-status with exponential backoff until 'ready' (or cancelled)\n setProgress('Host inventory is being prepared in the background');\n setLoading(true);\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const jitter = Math.floor(Math.random() * 250);\n await new Promise(r => setTimeout(r, delayMs + jitter));\n if (cancelled) return false;\n try {\n const st = await network.inventoryStatus(selectedHuntId);\n if (cancelled) return false;\n if (st.status === 'ready') return true;\n if (Date.now() - startedAt > 5 * 60 * 1000) {\n setError('Host inventory build timed out. Please retry.');\n return false;\n }\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n // still building or none (job may not have started yet) - keep polling\n } catch {\n if (cancelled) return false;\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n }\n };\n"
)
nm.write_text(text, encoding="utf-8")
print("Patched: config.py, scanner.py, keywords.py, job_queue.py, NetworkMap.tsx")

207
_apply_phase2_patch.py Normal file
View File

@@ -0,0 +1,207 @@
from pathlib import Path
import re
root = Path(r"d:\Projects\Dev\ThreatHunt")
# ---------- config.py ----------
cfg = root / "backend/app/config.py"
text = cfg.read_text(encoding="utf-8")
marker = " JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(\n default=3600, description=\"Age threshold for in-memory completed job cleanup\"\n )\n"
add = marker + "\n # -- Startup throttling ------------------------------------------------\n STARTUP_WARMUP_MAX_HUNTS: int = Field(\n default=5, description=\"Max hunts to warm inventory cache for at startup\"\n )\n STARTUP_REPROCESS_MAX_DATASETS: int = Field(\n default=25, description=\"Max unprocessed datasets to enqueue at startup\"\n )\n\n # -- Network API scale guards -----------------------------------------\n NETWORK_SUBGRAPH_MAX_HOSTS: int = Field(\n default=400, description=\"Hard cap for hosts returned by network subgraph endpoint\"\n )\n NETWORK_SUBGRAPH_MAX_EDGES: int = Field(\n default=3000, description=\"Hard cap for edges returned by network subgraph endpoint\"\n )\n"
if marker in text and "STARTUP_WARMUP_MAX_HUNTS" not in text:
text = text.replace(marker, add)
cfg.write_text(text, encoding="utf-8")
# ---------- job_queue.py ----------
jq = root / "backend/app/services/job_queue.py"
text = jq.read_text(encoding="utf-8")
# add helper methods after get_stats
anchor = " def get_stats(self) -> dict:\n by_status = {}\n for j in self._jobs.values():\n by_status[j.status.value] = by_status.get(j.status.value, 0) + 1\n return {\n \"total\": len(self._jobs),\n \"queued\": self._queue.qsize(),\n \"by_status\": by_status,\n \"workers\": self._max_workers,\n \"active_workers\": sum(1 for j in self._jobs.values() if j.status == JobStatus.RUNNING),\n }\n"
if "def is_backlogged(" not in text:
insert = anchor + "\n def is_backlogged(self) -> bool:\n return self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG\n\n def can_accept(self, reserve: int = 0) -> bool:\n return (self._queue.qsize() + max(0, reserve)) < settings.JOB_QUEUE_MAX_BACKLOG\n"
text = text.replace(anchor, insert)
jq.write_text(text, encoding="utf-8")
# ---------- host_inventory.py keyset pagination ----------
hi = root / "backend/app/services/host_inventory.py"
text = hi.read_text(encoding="utf-8")
old = ''' batch_size = 5000
offset = 0
while True:
rr = await db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == ds.id)
.order_by(DatasetRow.row_index)
.offset(offset).limit(batch_size)
)
rows = rr.scalars().all()
if not rows:
break
'''
new = ''' batch_size = 5000
last_row_index = -1
while True:
rr = await db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == ds.id)
.where(DatasetRow.row_index > last_row_index)
.order_by(DatasetRow.row_index)
.limit(batch_size)
)
rows = rr.scalars().all()
if not rows:
break
'''
if old in text:
text = text.replace(old, new)
text = text.replace(" offset += batch_size\n if len(rows) < batch_size:\n break\n", " last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n")
hi.write_text(text, encoding="utf-8")
# ---------- network.py add summary/subgraph + backpressure ----------
net = root / "backend/app/api/routes/network.py"
text = net.read_text(encoding="utf-8")
text = text.replace("from fastapi import APIRouter, Depends, HTTPException, Query", "from fastapi import APIRouter, Depends, HTTPException, Query")
if "from app.config import settings" not in text:
text = text.replace("from app.db import get_db\n", "from app.config import settings\nfrom app.db import get_db\n")
# add helpers and endpoints before inventory-status endpoint
if "def _build_summary" not in text:
helper_block = '''
def _build_summary(inv: dict, top_n: int = 20) -> dict:
hosts = inv.get("hosts", [])
conns = inv.get("connections", [])
top_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:top_n]
top_edges = sorted(conns, key=lambda c: c.get("count", 0), reverse=True)[:top_n]
return {
"stats": inv.get("stats", {}),
"top_hosts": [
{
"id": h.get("id"),
"hostname": h.get("hostname"),
"row_count": h.get("row_count", 0),
"ip_count": len(h.get("ips", [])),
"user_count": len(h.get("users", [])),
}
for h in top_hosts
],
"top_edges": top_edges,
}
def _build_subgraph(inv: dict, node_id: str | None, max_hosts: int, max_edges: int) -> dict:
hosts = inv.get("hosts", [])
conns = inv.get("connections", [])
max_hosts = max(1, min(max_hosts, settings.NETWORK_SUBGRAPH_MAX_HOSTS))
max_edges = max(1, min(max_edges, settings.NETWORK_SUBGRAPH_MAX_EDGES))
if node_id:
rel_edges = [c for c in conns if c.get("source") == node_id or c.get("target") == node_id]
rel_edges = sorted(rel_edges, key=lambda c: c.get("count", 0), reverse=True)[:max_edges]
ids = {node_id}
for c in rel_edges:
ids.add(c.get("source"))
ids.add(c.get("target"))
rel_hosts = [h for h in hosts if h.get("id") in ids][:max_hosts]
else:
rel_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:max_hosts]
allowed = {h.get("id") for h in rel_hosts}
rel_edges = [
c for c in sorted(conns, key=lambda c: c.get("count", 0), reverse=True)
if c.get("source") in allowed and c.get("target") in allowed
][:max_edges]
return {
"hosts": rel_hosts,
"connections": rel_edges,
"stats": {
**inv.get("stats", {}),
"subgraph_hosts": len(rel_hosts),
"subgraph_connections": len(rel_edges),
"truncated": len(rel_hosts) < len(hosts) or len(rel_edges) < len(conns),
},
}
@router.get("/summary")
async def get_inventory_summary(
hunt_id: str = Query(..., description="Hunt ID"),
top_n: int = Query(20, ge=1, le=200),
):
"""Return a lightweight summary view for large hunts."""
cached = inventory_cache.get(hunt_id)
if cached is None:
if not inventory_cache.is_building(hunt_id):
if job_queue.is_backlogged():
return JSONResponse(
status_code=202,
content={"status": "deferred", "message": "Queue busy, retry shortly"},
)
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(status_code=202, content={"status": "building"})
return _build_summary(cached, top_n=top_n)
@router.get("/subgraph")
async def get_inventory_subgraph(
hunt_id: str = Query(..., description="Hunt ID"),
node_id: str | None = Query(None, description="Optional focal node"),
max_hosts: int = Query(200, ge=1, le=5000),
max_edges: int = Query(1500, ge=1, le=20000),
):
"""Return a bounded subgraph for scale-safe rendering."""
cached = inventory_cache.get(hunt_id)
if cached is None:
if not inventory_cache.is_building(hunt_id):
if job_queue.is_backlogged():
return JSONResponse(
status_code=202,
content={"status": "deferred", "message": "Queue busy, retry shortly"},
)
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(status_code=202, content={"status": "building"})
return _build_subgraph(cached, node_id=node_id, max_hosts=max_hosts, max_edges=max_edges)
'''
text = text.replace("\n\n@router.get(\"/inventory-status\")", helper_block + "\n\n@router.get(\"/inventory-status\")")
# add backpressure in host-inventory enqueue points
text = text.replace(
" if not inventory_cache.is_building(hunt_id):\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)",
" if not inventory_cache.is_building(hunt_id):\n if job_queue.is_backlogged():\n return JSONResponse(status_code=202, content={\"status\": \"deferred\", \"message\": \"Queue busy, retry shortly\"})\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)"
)
text = text.replace(
" if not inventory_cache.is_building(hunt_id):\n logger.info(f\"Cache miss for {hunt_id}, triggering background build\")\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)",
" if not inventory_cache.is_building(hunt_id):\n logger.info(f\"Cache miss for {hunt_id}, triggering background build\")\n if job_queue.is_backlogged():\n return JSONResponse(status_code=202, content={\"status\": \"deferred\", \"message\": \"Queue busy, retry shortly\"})\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)"
)
net.write_text(text, encoding="utf-8")
# ---------- analysis.py backpressure on manual submit ----------
analysis = root / "backend/app/api/routes/analysis.py"
text = analysis.read_text(encoding="utf-8")
text = text.replace(
" job = job_queue.submit(jt, **params)\n return {\"job_id\": job.id, \"status\": job.status.value, \"job_type\": job_type}",
" if not job_queue.can_accept():\n raise HTTPException(status_code=429, detail=\"Job queue is busy. Retry shortly.\")\n job = job_queue.submit(jt, **params)\n return {\"job_id\": job.id, \"status\": job.status.value, \"job_type\": job_type}"
)
analysis.write_text(text, encoding="utf-8")
# ---------- main.py startup throttles ----------
main = root / "backend/app/main.py"
text = main.read_text(encoding="utf-8")
text = text.replace(
" for hid in hunt_ids:\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)\n if hunt_ids:\n logger.info(f\"Queued host inventory warm-up for {len(hunt_ids)} hunts\")",
" warm_hunts = hunt_ids[: settings.STARTUP_WARMUP_MAX_HUNTS]\n for hid in warm_hunts:\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)\n if warm_hunts:\n logger.info(f\"Queued host inventory warm-up for {len(warm_hunts)} hunts (total hunts with data: {len(hunt_ids)})\")"
)
text = text.replace(
" if unprocessed_ids:\n for ds_id in unprocessed_ids:\n job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)\n job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)\n job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)\n job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)\n logger.info(f\"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets\")\n async with async_session_factory() as update_db:\n from sqlalchemy import update\n from app.db.models import Dataset\n await update_db.execute(\n update(Dataset)\n .where(Dataset.id.in_(unprocessed_ids))\n .values(processing_status=\"processing\")\n )\n await update_db.commit()",
" if unprocessed_ids:\n to_reprocess = unprocessed_ids[: settings.STARTUP_REPROCESS_MAX_DATASETS]\n for ds_id in to_reprocess:\n job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)\n job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)\n job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)\n job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)\n logger.info(f\"Queued processing pipeline for {len(to_reprocess)} datasets at startup (unprocessed total: {len(unprocessed_ids)})\")\n async with async_session_factory() as update_db:\n from sqlalchemy import update\n from app.db.models import Dataset\n await update_db.execute(\n update(Dataset)\n .where(Dataset.id.in_(to_reprocess))\n .values(processing_status=\"processing\")\n )\n await update_db.commit()"
)
main.write_text(text, encoding="utf-8")
print("Patched Phase 2 files")

View File

@@ -0,0 +1,75 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
t=p.read_text(encoding='utf-8')
# default selection when hunt changes: first 3 datasets instead of all
old=''' datasets.list(0, 500, selectedHuntId).then(res => {
if (cancelled) return;
setDsList(res.datasets);
setSelectedDs(new Set(res.datasets.map(d => d.id)));
}).catch(() => {});
'''
new=''' datasets.list(0, 500, selectedHuntId).then(res => {
if (cancelled) return;
setDsList(res.datasets);
setSelectedDs(new Set(res.datasets.slice(0, 3).map(d => d.id)));
}).catch(() => {});
'''
if old not in t:
raise SystemExit('hunt-change dataset init block not found')
t=t.replace(old,new)
# insert dataset scope multi-select under hunt info
anchor=''' {!selectedHuntId && (
<Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
All datasets will be scanned if no hunt is selected
</Typography>
)}
</Box>
{/* Theme selector */}
'''
insert=''' {!selectedHuntId && (
<Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
Select a hunt to enable scoped scanning
</Typography>
)}
<FormControl size="small" fullWidth sx={{ mt: 1.2 }} disabled={!selectedHuntId || dsList.length === 0}>
<InputLabel id="aup-dataset-label">Datasets</InputLabel>
<Select
labelId="aup-dataset-label"
multiple
value={Array.from(selectedDs)}
label="Datasets"
renderValue={(selected) => `${(selected as string[]).length} selected`}
onChange={(e) => setSelectedDs(new Set(e.target.value as string[]))}
>
{dsList.map(d => (
<MenuItem key={d.id} value={d.id}>
<Checkbox size="small" checked={selectedDs.has(d.id)} />
<Typography variant="body2" sx={{ ml: 0.5 }}>
{d.name} ({d.row_count.toLocaleString()} rows)
</Typography>
</MenuItem>
))}
</Select>
</FormControl>
{selectedHuntId && dsList.length > 0 && (
<Stack direction="row" spacing={1} sx={{ mt: 1 }}>
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.slice(0, 3).map(d => d.id)))}>Top 3</Button>
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.map(d => d.id)))}>All</Button>
<Button size="small" onClick={() => setSelectedDs(new Set())}>Clear</Button>
</Stack>
)}
</Box>
{/* Theme selector */}
'''
if anchor not in t:
raise SystemExit('dataset scope anchor not found')
t=t.replace(anchor,insert)
p.write_text(t,encoding='utf-8')
print('added AUP dataset multi-select scoping and safer defaults')

View File

@@ -0,0 +1,182 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
t=p.read_text(encoding='utf-8')
# 1) Extend ScanHit dataclass
old='''@dataclass
class ScanHit:
theme_name: str
theme_color: str
keyword: str
source_type: str # dataset_row | hunt | annotation | message
source_id: str | int
field: str
matched_value: str
row_index: int | None = None
dataset_name: str | None = None
'''
new='''@dataclass
class ScanHit:
theme_name: str
theme_color: str
keyword: str
source_type: str # dataset_row | hunt | annotation | message
source_id: str | int
field: str
matched_value: str
row_index: int | None = None
dataset_name: str | None = None
hostname: str | None = None
username: str | None = None
'''
if old not in t:
raise SystemExit('ScanHit dataclass block not found')
t=t.replace(old,new)
# 2) Add helper to infer hostname/user from a row
insert_after='''BATCH_SIZE = 200
@dataclass
class ScanHit:
'''
helper='''BATCH_SIZE = 200
def _infer_hostname_and_user(data: dict) -> tuple[str | None, str | None]:
"""Best-effort extraction of hostname and user from a dataset row."""
if not data:
return None, None
host_keys = (
'hostname', 'host_name', 'host', 'computer_name', 'computer',
'fqdn', 'client_id', 'agent_id', 'endpoint_id',
)
user_keys = (
'username', 'user_name', 'user', 'account_name',
'logged_in_user', 'samaccountname', 'sam_account_name',
)
def pick(keys):
for k in keys:
for actual_key, v in data.items():
if actual_key.lower() == k and v not in (None, ''):
return str(v)
return None
return pick(host_keys), pick(user_keys)
@dataclass
class ScanHit:
'''
if insert_after in t and '_infer_hostname_and_user' not in t:
t=t.replace(insert_after,helper)
# 3) Extend _match_text signature and ScanHit construction
old_sig=''' def _match_text(
self,
text: str,
patterns: dict,
source_type: str,
source_id: str | int,
field_name: str,
hits: list[ScanHit],
row_index: int | None = None,
dataset_name: str | None = None,
) -> None:
'''
new_sig=''' def _match_text(
self,
text: str,
patterns: dict,
source_type: str,
source_id: str | int,
field_name: str,
hits: list[ScanHit],
row_index: int | None = None,
dataset_name: str | None = None,
hostname: str | None = None,
username: str | None = None,
) -> None:
'''
if old_sig not in t:
raise SystemExit('_match_text signature not found')
t=t.replace(old_sig,new_sig)
old_hit=''' hits.append(ScanHit(
theme_name=theme_name,
theme_color=theme_color,
keyword=kw_value,
source_type=source_type,
source_id=source_id,
field=field_name,
matched_value=matched_preview,
row_index=row_index,
dataset_name=dataset_name,
))
'''
new_hit=''' hits.append(ScanHit(
theme_name=theme_name,
theme_color=theme_color,
keyword=kw_value,
source_type=source_type,
source_id=source_id,
field=field_name,
matched_value=matched_preview,
row_index=row_index,
dataset_name=dataset_name,
hostname=hostname,
username=username,
))
'''
if old_hit not in t:
raise SystemExit('ScanHit append block not found')
t=t.replace(old_hit,new_hit)
# 4) Pass inferred hostname/username in dataset scan path
old_call=''' for row in rows:
result.rows_scanned += 1
data = row.data or {}
for col_name, cell_value in data.items():
if cell_value is None:
continue
text = str(cell_value)
self._match_text(
text,
patterns,
"dataset_row",
row.id,
col_name,
result.hits,
row_index=row.row_index,
dataset_name=ds_name,
)
'''
new_call=''' for row in rows:
result.rows_scanned += 1
data = row.data or {}
hostname, username = _infer_hostname_and_user(data)
for col_name, cell_value in data.items():
if cell_value is None:
continue
text = str(cell_value)
self._match_text(
text,
patterns,
"dataset_row",
row.id,
col_name,
result.hits,
row_index=row.row_index,
dataset_name=ds_name,
hostname=hostname,
username=username,
)
'''
if old_call not in t:
raise SystemExit('dataset _match_text call block not found')
t=t.replace(old_call,new_call)
p.write_text(t,encoding='utf-8')
print('updated scanner hits with hostname+username context')

View File

@@ -0,0 +1,32 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
t=p.read_text(encoding='utf-8')
old='''class ScanHit(BaseModel):
theme_name: str
theme_color: str
keyword: str
source_type: str
source_id: str | int
field: str
matched_value: str
row_index: int | None = None
dataset_name: str | None = None
'''
new='''class ScanHit(BaseModel):
theme_name: str
theme_color: str
keyword: str
source_type: str
source_id: str | int
field: str
matched_value: str
row_index: int | None = None
dataset_name: str | None = None
hostname: str | None = None
username: str | None = None
'''
if old not in t:
raise SystemExit('ScanHit pydantic model block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('extended API ScanHit model with hostname+username')

View File

@@ -0,0 +1,21 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/api/client.ts')
t=p.read_text(encoding='utf-8')
old='''export interface ScanHit {
theme_name: string; theme_color: string; keyword: string;
source_type: string; source_id: string | number; field: string;
matched_value: string; row_index: number | null; dataset_name: string | null;
}
'''
new='''export interface ScanHit {
theme_name: string; theme_color: string; keyword: string;
source_type: string; source_id: string | number; field: string;
matched_value: string; row_index: number | null; dataset_name: string | null;
hostname?: string | null; username?: string | null;
}
'''
if old not in t:
raise SystemExit('frontend ScanHit interface block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('extended frontend ScanHit type with hostname+username')

View File

@@ -0,0 +1,57 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
t=p.read_text(encoding='utf-8')
# add fast guard against unscoped global dataset scans
insert_after='''async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):\n scanner = KeywordScanner(db)\n\n'''
if insert_after not in t:
raise SystemExit('run_scan header block not found')
if 'Select at least one dataset' not in t:
guard=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:\n raise HTTPException(400, "Select at least one dataset or enable additional sources (hunts/annotations/messages)")\n\n'''
t=t.replace(insert_after, insert_after+guard)
old=''' if missing:
missing_entries: list[dict] = []
for dataset_id in missing:
partial = await scanner.scan(dataset_ids=[dataset_id], theme_ids=body.theme_ids)
keyword_scan_cache.put(dataset_id, partial)
missing_entries.append({"result": partial, "built_at": None})
merged = _merge_cached_results(
cached_entries + missing_entries,
allowed_theme_names if body.theme_ids else None,
)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": len(cached_entries) > 0,
"cache_status": "partial" if cached_entries else "miss",
"cached_at": merged["cached_at"],
}
'''
new=''' if missing:
partial = await scanner.scan(dataset_ids=missing, theme_ids=body.theme_ids)
merged = _merge_cached_results(
cached_entries + [{"result": partial, "built_at": None}],
allowed_theme_names if body.theme_ids else None,
)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": len(cached_entries) > 0,
"cache_status": "partial" if cached_entries else "miss",
"cached_at": merged["cached_at"],
}
'''
if old not in t:
raise SystemExit('partial-cache missing block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('hardened keywords scan scope + optimized missing-cache path')

18
_aup_reduce_budget.py Normal file
View File

@@ -0,0 +1,18 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
t=p.read_text(encoding='utf-8')
old=''' SCANNER_MAX_ROWS_PER_SCAN: int = Field(
default=300000,
description="Global row budget for a single AUP scan request (0 = unlimited)",
)
'''
new=''' SCANNER_MAX_ROWS_PER_SCAN: int = Field(
default=120000,
description="Global row budget for a single AUP scan request (0 = unlimited)",
)
'''
if old not in t:
raise SystemExit('SCANNER_MAX_ROWS_PER_SCAN block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('reduced SCANNER_MAX_ROWS_PER_SCAN default to 120000')

View File

@@ -0,0 +1,42 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
t=p.read_text(encoding='utf-8')
old='''const RESULT_COLUMNS: GridColDef[] = [
{
field: 'theme_name', headerName: 'Theme', width: 140,
renderCell: (params) => (
<Chip label={params.value} size="small"
sx={{ bgcolor: params.row.theme_color, color: '#fff', fontWeight: 600 }} />
),
},
{ field: 'keyword', headerName: 'Keyword', width: 140 },
{ field: 'source_type', headerName: 'Source', width: 120 },
{ field: 'dataset_name', headerName: 'Dataset', width: 150 },
{ field: 'field', headerName: 'Field', width: 130 },
{ field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 200 },
{ field: 'row_index', headerName: 'Row #', width: 80, type: 'number' },
];
'''
new='''const RESULT_COLUMNS: GridColDef[] = [
{
field: 'theme_name', headerName: 'Theme', width: 140,
renderCell: (params) => (
<Chip label={params.value} size="small"
sx={{ bgcolor: params.row.theme_color, color: '#fff', fontWeight: 600 }} />
),
},
{ field: 'keyword', headerName: 'Keyword', width: 140 },
{ field: 'dataset_name', headerName: 'Dataset', width: 170 },
{ field: 'hostname', headerName: 'Hostname', width: 170, valueGetter: (v, row) => row.hostname || '' },
{ field: 'username', headerName: 'User', width: 160, valueGetter: (v, row) => row.username || '' },
{ field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 220 },
{ field: 'field', headerName: 'Field', width: 130 },
{ field: 'source_type', headerName: 'Source', width: 120 },
{ field: 'row_index', headerName: 'Row #', width: 90, type: 'number' },
];
'''
if old not in t:
raise SystemExit('RESULT_COLUMNS block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated AUP results grid columns with dataset/hostname/user/matched value focus')

40
_edit_aup.py Normal file
View File

@@ -0,0 +1,40 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
t=p.read_text(encoding='utf-8')
t=t.replace(' const [scanHunts, setScanHunts] = useState(true);',' const [scanHunts, setScanHunts] = useState(false);')
t=t.replace(' const [scanAnnotations, setScanAnnotations] = useState(true);',' const [scanAnnotations, setScanAnnotations] = useState(false);')
t=t.replace(' const [scanMessages, setScanMessages] = useState(true);',' const [scanMessages, setScanMessages] = useState(false);')
t=t.replace(' scan_messages: scanMessages,\n });',' scan_messages: scanMessages,\n prefer_cache: true,\n });')
# add cache chip in summary alert
old=''' {scanResult && (
<Alert severity={scanResult.total_hits > 0 ? 'warning' : 'success'} sx={{ py: 0.5 }}>
<strong>{scanResult.total_hits}</strong> hits across{' '}
<strong>{scanResult.rows_scanned}</strong> rows |{' '}
{scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned
</Alert>
)}
'''
new=''' {scanResult && (
<Alert severity={scanResult.total_hits > 0 ? 'warning' : 'success'} sx={{ py: 0.5 }}>
<strong>{scanResult.total_hits}</strong> hits across{' '}
<strong>{scanResult.rows_scanned}</strong> rows |{' '}
{scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned
{scanResult.cache_status && (
<Chip
size="small"
label={scanResult.cache_status === 'hit' ? 'Cached' : 'Live'}
sx={{ ml: 1, height: 20 }}
color={scanResult.cache_status === 'hit' ? 'success' : 'default'}
variant="outlined"
/>
)}
</Alert>
)}
'''
if old in t:
t=t.replace(old,new)
else:
print('warning: summary block not replaced')
p.write_text(t,encoding='utf-8')
print('updated AUPScanner.tsx')

36
_edit_client.py Normal file
View File

@@ -0,0 +1,36 @@
from pathlib import Path
import re
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/api/client.ts')
t=p.read_text(encoding='utf-8')
# Add HuntProgress interface after Hunt interface
if 'export interface HuntProgress' not in t:
insert = '''export interface HuntProgress {
hunt_id: string;
status: 'idle' | 'processing' | 'ready';
progress_percent: number;
dataset_total: number;
dataset_completed: number;
dataset_processing: number;
dataset_errors: number;
active_jobs: number;
queued_jobs: number;
network_status: 'none' | 'building' | 'ready';
stages: Record<string, any>;
}
'''
t=t.replace('export interface Hunt {\n id: string; name: string; description: string | null; status: string;\n owner_id: string | null; created_at: string; updated_at: string;\n dataset_count: number; hypothesis_count: number;\n}\n\n', 'export interface Hunt {\n id: string; name: string; description: string | null; status: string;\n owner_id: string | null; created_at: string; updated_at: string;\n dataset_count: number; hypothesis_count: number;\n}\n\n'+insert)
# Add hunts.progress method
if 'progress: (id: string)' not in t:
t=t.replace(" delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),\n};", " delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),\n progress: (id: string) => api<HuntProgress>(`/api/hunts/${id}/progress`),\n};")
# Extend ScanResponse
if 'cache_used?: boolean' not in t:
t=t.replace('export interface ScanResponse {\n total_hits: number; hits: ScanHit[]; themes_scanned: number;\n keywords_scanned: number; rows_scanned: number;\n}\n', 'export interface ScanResponse {\n total_hits: number; hits: ScanHit[]; themes_scanned: number;\n keywords_scanned: number; rows_scanned: number;\n cache_used?: boolean; cache_status?: string; cached_at?: string | null;\n}\n')
# Extend keywords.scan opts
t=t.replace(' scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;\n }) =>', ' scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;\n prefer_cache?: boolean; force_rescan?: boolean;\n }) =>')
p.write_text(t,encoding='utf-8')
print('updated client.ts')

20
_edit_config_reconcile.py Normal file
View File

@@ -0,0 +1,20 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
t=p.read_text(encoding='utf-8')
anchor=''' STARTUP_REPROCESS_MAX_DATASETS: int = Field(
default=25, description="Max unprocessed datasets to enqueue at startup"
)
'''
insert=''' STARTUP_REPROCESS_MAX_DATASETS: int = Field(
default=25, description="Max unprocessed datasets to enqueue at startup"
)
STARTUP_RECONCILE_STALE_TASKS: bool = Field(
default=True,
description="Mark stale queued/running processing tasks as failed on startup",
)
'''
if anchor not in t:
raise SystemExit('startup anchor not found')
t=t.replace(anchor,insert)
p.write_text(t,encoding='utf-8')
print('updated config with STARTUP_RECONCILE_STALE_TASKS')

39
_edit_datasets.py Normal file
View File

@@ -0,0 +1,39 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py')
t=p.read_text(encoding='utf-8')
if 'from app.services.scanner import keyword_scan_cache' not in t:
t=t.replace('from app.services.host_inventory import inventory_cache','from app.services.host_inventory import inventory_cache\nfrom app.services.scanner import keyword_scan_cache')
old='''@router.delete(
"/{dataset_id}",
summary="Delete a dataset",
)
async def delete_dataset(
dataset_id: str,
db: AsyncSession = Depends(get_db),
):
repo = DatasetRepository(db)
deleted = await repo.delete_dataset(dataset_id)
if not deleted:
raise HTTPException(status_code=404, detail="Dataset not found")
return {"message": "Dataset deleted", "id": dataset_id}
'''
new='''@router.delete(
"/{dataset_id}",
summary="Delete a dataset",
)
async def delete_dataset(
dataset_id: str,
db: AsyncSession = Depends(get_db),
):
repo = DatasetRepository(db)
deleted = await repo.delete_dataset(dataset_id)
if not deleted:
raise HTTPException(status_code=404, detail="Dataset not found")
keyword_scan_cache.invalidate_dataset(dataset_id)
return {"message": "Dataset deleted", "id": dataset_id}
'''
if old not in t:
raise SystemExit('delete block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated datasets.py')

110
_edit_datasets_tasks.py Normal file
View File

@@ -0,0 +1,110 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py')
t=p.read_text(encoding='utf-8')
if 'ProcessingTask' not in t:
t=t.replace('from app.db.models import', 'from app.db.models import ProcessingTask\n# from app.db.models import')
t=t.replace('from app.services.scanner import keyword_scan_cache','from app.services.scanner import keyword_scan_cache')
# clean import replacement to proper single line
if '# from app.db.models import' in t:
t=t.replace('from app.db.models import ProcessingTask\n# from app.db.models import', 'from app.db.models import ProcessingTask')
old=''' # 1. AI Triage (chains to HOST_PROFILE automatically on completion)
job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
jobs_queued.append("triage")
# 2. Anomaly detection (embedding-based outlier detection)
job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
jobs_queued.append("anomaly")
# 3. AUP keyword scan
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
jobs_queued.append("keyword_scan")
# 4. IOC extraction
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
jobs_queued.append("ioc_extract")
# 5. Host inventory (network map) - requires hunt_id
if hunt_id:
inventory_cache.invalidate(hunt_id)
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
jobs_queued.append("host_inventory")
'''
new=''' task_rows: list[ProcessingTask] = []
# 1. AI Triage (chains to HOST_PROFILE automatically on completion)
triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
jobs_queued.append("triage")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=triage_job.id,
stage="triage",
status="queued",
progress=0.0,
message="Queued",
))
# 2. Anomaly detection (embedding-based outlier detection)
anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
jobs_queued.append("anomaly")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=anomaly_job.id,
stage="anomaly",
status="queued",
progress=0.0,
message="Queued",
))
# 3. AUP keyword scan
kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
jobs_queued.append("keyword_scan")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=kw_job.id,
stage="keyword_scan",
status="queued",
progress=0.0,
message="Queued",
))
# 4. IOC extraction
ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
jobs_queued.append("ioc_extract")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=ioc_job.id,
stage="ioc_extract",
status="queued",
progress=0.0,
message="Queued",
))
# 5. Host inventory (network map) - requires hunt_id
if hunt_id:
inventory_cache.invalidate(hunt_id)
inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
jobs_queued.append("host_inventory")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=inv_job.id,
stage="host_inventory",
status="queued",
progress=0.0,
message="Queued",
))
if task_rows:
db.add_all(task_rows)
await db.flush()
'''
if old not in t:
raise SystemExit('queue block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated datasets upload queue + processing tasks')

254
_edit_hunts.py Normal file
View File

@@ -0,0 +1,254 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/hunts.py')
new='''"""API routes for hunt management."""
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Hunt, Dataset
from app.services.job_queue import job_queue
from app.services.host_inventory import inventory_cache
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/hunts", tags=["hunts"])
class HuntCreate(BaseModel):
name: str = Field(..., max_length=256)
description: str | None = None
class HuntUpdate(BaseModel):
name: str | None = None
description: str | None = None
status: str | None = None
class HuntResponse(BaseModel):
id: str
name: str
description: str | None
status: str
owner_id: str | None
created_at: str
updated_at: str
dataset_count: int = 0
hypothesis_count: int = 0
class HuntListResponse(BaseModel):
hunts: list[HuntResponse]
total: int
class HuntProgressResponse(BaseModel):
hunt_id: str
status: str
progress_percent: float
dataset_total: int
dataset_completed: int
dataset_processing: int
dataset_errors: int
active_jobs: int
queued_jobs: int
network_status: str
stages: dict
@router.post("", response_model=HuntResponse, summary="Create a new hunt")
async def create_hunt(body: HuntCreate, db: AsyncSession = Depends(get_db)):
hunt = Hunt(name=body.name, description=body.description)
db.add(hunt)
await db.flush()
return HuntResponse(
id=hunt.id,
name=hunt.name,
description=hunt.description,
status=hunt.status,
owner_id=hunt.owner_id,
created_at=hunt.created_at.isoformat(),
updated_at=hunt.updated_at.isoformat(),
)
@router.get("", response_model=HuntListResponse, summary="List hunts")
async def list_hunts(
status: str | None = Query(None),
limit: int = Query(50, ge=1, le=500),
offset: int = Query(0, ge=0),
db: AsyncSession = Depends(get_db),
):
stmt = select(Hunt).order_by(Hunt.updated_at.desc())
if status:
stmt = stmt.where(Hunt.status == status)
stmt = stmt.limit(limit).offset(offset)
result = await db.execute(stmt)
hunts = result.scalars().all()
count_stmt = select(func.count(Hunt.id))
if status:
count_stmt = count_stmt.where(Hunt.status == status)
total = (await db.execute(count_stmt)).scalar_one()
return HuntListResponse(
hunts=[
HuntResponse(
id=h.id,
name=h.name,
description=h.description,
status=h.status,
owner_id=h.owner_id,
created_at=h.created_at.isoformat(),
updated_at=h.updated_at.isoformat(),
dataset_count=len(h.datasets) if h.datasets else 0,
hypothesis_count=len(h.hypotheses) if h.hypotheses else 0,
)
for h in hunts
],
total=total,
)
@router.get("/{hunt_id}", response_model=HuntResponse, summary="Get hunt details")
async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
hunt = result.scalar_one_or_none()
if not hunt:
raise HTTPException(status_code=404, detail="Hunt not found")
return HuntResponse(
id=hunt.id,
name=hunt.name,
description=hunt.description,
status=hunt.status,
owner_id=hunt.owner_id,
created_at=hunt.created_at.isoformat(),
updated_at=hunt.updated_at.isoformat(),
dataset_count=len(hunt.datasets) if hunt.datasets else 0,
hypothesis_count=len(hunt.hypotheses) if hunt.hypotheses else 0,
)
@router.get("/{hunt_id}/progress", response_model=HuntProgressResponse, summary="Get hunt processing progress")
async def get_hunt_progress(hunt_id: str, db: AsyncSession = Depends(get_db)):
hunt = await db.get(Hunt, hunt_id)
if not hunt:
raise HTTPException(status_code=404, detail="Hunt not found")
ds_rows = await db.execute(
select(Dataset.id, Dataset.processing_status)
.where(Dataset.hunt_id == hunt_id)
)
datasets = ds_rows.all()
dataset_ids = {row[0] for row in datasets}
dataset_total = len(datasets)
dataset_completed = sum(1 for _, st in datasets if st == "completed")
dataset_errors = sum(1 for _, st in datasets if st == "completed_with_errors")
dataset_processing = max(0, dataset_total - dataset_completed - dataset_errors)
jobs = job_queue.list_jobs(limit=5000)
relevant_jobs = [
j for j in jobs
if j.get("params", {}).get("hunt_id") == hunt_id
or j.get("params", {}).get("dataset_id") in dataset_ids
]
active_jobs = sum(1 for j in relevant_jobs if j.get("status") == "running")
queued_jobs = sum(1 for j in relevant_jobs if j.get("status") == "queued")
if inventory_cache.get(hunt_id) is not None:
network_status = "ready"
network_ratio = 1.0
elif inventory_cache.is_building(hunt_id):
network_status = "building"
network_ratio = 0.5
else:
network_status = "none"
network_ratio = 0.0
dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
progress_percent = round(overall_ratio * 100.0, 1)
status = "ready"
if dataset_total == 0:
status = "idle"
elif progress_percent < 100:
status = "processing"
stages = {
"datasets": {
"total": dataset_total,
"completed": dataset_completed,
"processing": dataset_processing,
"errors": dataset_errors,
"percent": round(dataset_ratio * 100.0, 1),
},
"network": {
"status": network_status,
"percent": round(network_ratio * 100.0, 1),
},
"jobs": {
"active": active_jobs,
"queued": queued_jobs,
"total_seen": len(relevant_jobs),
},
}
return HuntProgressResponse(
hunt_id=hunt_id,
status=status,
progress_percent=progress_percent,
dataset_total=dataset_total,
dataset_completed=dataset_completed,
dataset_processing=dataset_processing,
dataset_errors=dataset_errors,
active_jobs=active_jobs,
queued_jobs=queued_jobs,
network_status=network_status,
stages=stages,
)
@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
async def update_hunt(
hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)
):
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
hunt = result.scalar_one_or_none()
if not hunt:
raise HTTPException(status_code=404, detail="Hunt not found")
if body.name is not None:
hunt.name = body.name
if body.description is not None:
hunt.description = body.description
if body.status is not None:
hunt.status = body.status
await db.flush()
return HuntResponse(
id=hunt.id,
name=hunt.name,
description=hunt.description,
status=hunt.status,
owner_id=hunt.owner_id,
created_at=hunt.created_at.isoformat(),
updated_at=hunt.updated_at.isoformat(),
)
@router.delete("/{hunt_id}", summary="Delete a hunt")
async def delete_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
hunt = result.scalar_one_or_none()
if not hunt:
raise HTTPException(status_code=404, detail="Hunt not found")
await db.delete(hunt)
return {"message": "Hunt deleted", "id": hunt_id}
'''
p.write_text(new,encoding='utf-8')
print('updated hunts.py')

View File

@@ -0,0 +1,102 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/hunts.py')
t=p.read_text(encoding='utf-8')
if 'ProcessingTask' not in t:
t=t.replace('from app.db.models import Hunt, Dataset','from app.db.models import Hunt, Dataset, ProcessingTask')
old=''' jobs = job_queue.list_jobs(limit=5000)
relevant_jobs = [
j for j in jobs
if j.get("params", {}).get("hunt_id") == hunt_id
or j.get("params", {}).get("dataset_id") in dataset_ids
]
active_jobs = sum(1 for j in relevant_jobs if j.get("status") == "running")
queued_jobs = sum(1 for j in relevant_jobs if j.get("status") == "queued")
if inventory_cache.get(hunt_id) is not None:
'''
new=''' jobs = job_queue.list_jobs(limit=5000)
relevant_jobs = [
j for j in jobs
if j.get("params", {}).get("hunt_id") == hunt_id
or j.get("params", {}).get("dataset_id") in dataset_ids
]
active_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "running")
queued_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "queued")
task_rows = await db.execute(
select(ProcessingTask.stage, ProcessingTask.status, ProcessingTask.progress)
.where(ProcessingTask.hunt_id == hunt_id)
)
tasks = task_rows.all()
task_total = len(tasks)
task_done = sum(1 for _, st, _ in tasks if st in ("completed", "failed", "cancelled"))
task_running = sum(1 for _, st, _ in tasks if st == "running")
task_queued = sum(1 for _, st, _ in tasks if st == "queued")
task_ratio = (task_done / task_total) if task_total > 0 else None
active_jobs = max(active_jobs_mem, task_running)
queued_jobs = max(queued_jobs_mem, task_queued)
stage_rollup: dict[str, dict] = {}
for stage, status, progress in tasks:
bucket = stage_rollup.setdefault(stage, {"total": 0, "done": 0, "running": 0, "queued": 0, "progress_sum": 0.0})
bucket["total"] += 1
if status in ("completed", "failed", "cancelled"):
bucket["done"] += 1
elif status == "running":
bucket["running"] += 1
elif status == "queued":
bucket["queued"] += 1
bucket["progress_sum"] += float(progress or 0.0)
for stage_name, bucket in stage_rollup.items():
total = max(1, bucket["total"])
bucket["percent"] = round(bucket["progress_sum"] / total, 1)
if inventory_cache.get(hunt_id) is not None:
'''
if old not in t:
raise SystemExit('job block not found')
t=t.replace(old,new)
old2=''' dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
progress_percent = round(overall_ratio * 100.0, 1)
'''
new2=''' dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
if task_ratio is None:
overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
else:
overall_ratio = min(1.0, (dataset_ratio * 0.50) + (task_ratio * 0.35) + (network_ratio * 0.15))
progress_percent = round(overall_ratio * 100.0, 1)
'''
if old2 not in t:
raise SystemExit('ratio block not found')
t=t.replace(old2,new2)
old3=''' "jobs": {
"active": active_jobs,
"queued": queued_jobs,
"total_seen": len(relevant_jobs),
},
}
'''
new3=''' "jobs": {
"active": active_jobs,
"queued": queued_jobs,
"total_seen": len(relevant_jobs),
"task_total": task_total,
"task_done": task_done,
"task_percent": round((task_ratio or 0.0) * 100.0, 1) if task_total else None,
},
"task_stages": stage_rollup,
}
'''
if old3 not in t:
raise SystemExit('stages jobs block not found')
t=t.replace(old3,new3)
p.write_text(t,encoding='utf-8')
print('updated hunt progress to merge persistent processing tasks')

46
_edit_job_queue.py Normal file
View File

@@ -0,0 +1,46 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
t=p.read_text(encoding='utf-8')
old='''async def _handle_keyword_scan(job: Job):
"""AUP keyword scan handler."""
from app.db import async_session_factory
from app.services.scanner import KeywordScanner
dataset_id = job.params.get("dataset_id")
job.message = f"Running AUP keyword scan on dataset {dataset_id}"
async with async_session_factory() as db:
scanner = KeywordScanner(db)
result = await scanner.scan(dataset_ids=[dataset_id])
hits = result.get("total_hits", 0)
job.message = f"Keyword scan complete: {hits} hits"
logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
'''
new='''async def _handle_keyword_scan(job: Job):
"""AUP keyword scan handler."""
from app.db import async_session_factory
from app.services.scanner import KeywordScanner, keyword_scan_cache
dataset_id = job.params.get("dataset_id")
job.message = f"Running AUP keyword scan on dataset {dataset_id}"
async with async_session_factory() as db:
scanner = KeywordScanner(db)
result = await scanner.scan(dataset_ids=[dataset_id])
# Cache dataset-only result for fast API reuse
if dataset_id:
keyword_scan_cache.put(dataset_id, result)
hits = result.get("total_hits", 0)
job.message = f"Keyword scan complete: {hits} hits"
logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
'''
if old not in t:
raise SystemExit('target block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated job_queue keyword scan handler')

View File

@@ -0,0 +1,13 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
t=p.read_text(encoding='utf-8')
marker='''def register_all_handlers():
"""Register all job handlers and completion callbacks."""
'''
ins='''\n\nasync def reconcile_stale_processing_tasks() -> int:\n """Mark queued/running processing tasks from prior runs as failed."""\n from datetime import datetime, timezone\n from sqlalchemy import update\n\n try:\n from app.db import async_session_factory\n from app.db.models import ProcessingTask\n\n now = datetime.now(timezone.utc)\n async with async_session_factory() as db:\n result = await db.execute(\n update(ProcessingTask)\n .where(ProcessingTask.status.in_([\"queued\", \"running\"]))\n .values(\n status=\"failed\",\n error=\"Recovered after service restart before task completion\",\n message=\"Recovered stale task after restart\",\n completed_at=now,\n )\n )\n await db.commit()\n updated = int(result.rowcount or 0)\n\n if updated:\n logger.warning(\n \"Reconciled %d stale processing tasks (queued/running -> failed) during startup\",\n updated,\n )\n return updated\n except Exception as e:\n logger.warning(f\"Failed to reconcile stale processing tasks: {e}\")\n return 0\n\n\n'''
if ins.strip() not in t:
if marker not in t:
raise SystemExit('register marker not found')
t=t.replace(marker,ins+marker)
p.write_text(t,encoding='utf-8')
print('added reconcile_stale_processing_tasks to job_queue')

64
_edit_jobqueue_sync.py Normal file
View File

@@ -0,0 +1,64 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
t=p.read_text(encoding='utf-8')
ins='''\n\nasync def _sync_processing_task(job: Job):\n """Persist latest job state into processing_tasks (if linked by job_id)."""\n from datetime import datetime, timezone\n from sqlalchemy import update\n\n try:\n from app.db import async_session_factory\n from app.db.models import ProcessingTask\n\n values = {\n "status": job.status.value,\n "progress": float(job.progress),\n "message": job.message,\n "error": job.error,\n }\n if job.started_at:\n values["started_at"] = datetime.fromtimestamp(job.started_at, tz=timezone.utc)\n if job.completed_at:\n values["completed_at"] = datetime.fromtimestamp(job.completed_at, tz=timezone.utc)\n\n async with async_session_factory() as db:\n await db.execute(\n update(ProcessingTask)\n .where(ProcessingTask.job_id == job.id)\n .values(**values)\n )\n await db.commit()\n except Exception as e:\n logger.warning(f"Failed to sync processing task for job {job.id}: {e}")\n'''
marker='\n\n# -- Singleton + job handlers --\n'
if ins.strip() not in t:
t=t.replace(marker, ins+marker)
old=''' job.status = JobStatus.RUNNING
job.started_at = time.time()
job.message = "Running..."
logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
try:
'''
new=''' job.status = JobStatus.RUNNING
job.started_at = time.time()
if job.progress <= 0:
job.progress = 5.0
job.message = "Running..."
await _sync_processing_task(job)
logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
try:
'''
if old not in t:
raise SystemExit('worker running block not found')
t=t.replace(old,new)
old2=''' job.completed_at = time.time()
logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
except Exception as e:
if not job.is_cancelled:
job.status = JobStatus.FAILED
job.error = str(e)
job.message = f"Failed: {e}"
job.completed_at = time.time()
logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
# Fire completion callbacks
'''
new2=''' job.completed_at = time.time()
logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
except Exception as e:
if not job.is_cancelled:
job.status = JobStatus.FAILED
job.error = str(e)
job.message = f"Failed: {e}"
job.completed_at = time.time()
logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
if job.is_cancelled and not job.completed_at:
job.completed_at = time.time()
await _sync_processing_task(job)
# Fire completion callbacks
'''
if old2 not in t:
raise SystemExit('worker completion block not found')
t=t.replace(old2,new2)
p.write_text(t, encoding='utf-8')
print('updated job_queue persistent task syncing')

View File

@@ -0,0 +1,39 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
t=p.read_text(encoding='utf-8')
old=''' if hunt_id:
job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
except Exception as e:
'''
new=''' if hunt_id:
hp_job = job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
try:
from sqlalchemy import select
from app.db.models import ProcessingTask
async with async_session_factory() as db:
existing = await db.execute(
select(ProcessingTask.id).where(ProcessingTask.job_id == hp_job.id)
)
if existing.first() is None:
db.add(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset_id,
job_id=hp_job.id,
stage="host_profile",
status="queued",
progress=0.0,
message="Queued",
))
await db.commit()
except Exception as persist_err:
logger.warning(f"Failed to persist chained HOST_PROFILE task: {persist_err}")
logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
except Exception as e:
'''
if old not in t:
raise SystemExit('triage chain block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated triage chain to persist host_profile task row')

321
_edit_keywords.py Normal file
View File

@@ -0,0 +1,321 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
new_text='''"""API routes for AUP keyword themes, keyword CRUD, and scanning."""
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import KeywordTheme, Keyword
from app.services.scanner import KeywordScanner, keyword_scan_cache
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/keywords", tags=["keywords"])
class ThemeCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=128)
color: str = Field(default="#9e9e9e", max_length=16)
enabled: bool = True
class ThemeUpdate(BaseModel):
name: str | None = None
color: str | None = None
enabled: bool | None = None
class KeywordOut(BaseModel):
id: int
theme_id: str
value: str
is_regex: bool
created_at: str
class ThemeOut(BaseModel):
id: str
name: str
color: str
enabled: bool
is_builtin: bool
created_at: str
keyword_count: int
keywords: list[KeywordOut]
class ThemeListResponse(BaseModel):
themes: list[ThemeOut]
total: int
class KeywordCreate(BaseModel):
value: str = Field(..., min_length=1, max_length=256)
is_regex: bool = False
class KeywordBulkCreate(BaseModel):
values: list[str] = Field(..., min_items=1)
is_regex: bool = False
class ScanRequest(BaseModel):
dataset_ids: list[str] | None = None
theme_ids: list[str] | None = None
scan_hunts: bool = False
scan_annotations: bool = False
scan_messages: bool = False
prefer_cache: bool = True
force_rescan: bool = False
class ScanHit(BaseModel):
theme_name: str
theme_color: str
keyword: str
source_type: str
source_id: str | int
field: str
matched_value: str
row_index: int | None = None
dataset_name: str | None = None
class ScanResponse(BaseModel):
total_hits: int
hits: list[ScanHit]
themes_scanned: int
keywords_scanned: int
rows_scanned: int
cache_used: bool = False
cache_status: str = "miss"
cached_at: str | None = None
def _theme_to_out(t: KeywordTheme) -> ThemeOut:
return ThemeOut(
id=t.id,
name=t.name,
color=t.color,
enabled=t.enabled,
is_builtin=t.is_builtin,
created_at=t.created_at.isoformat(),
keyword_count=len(t.keywords),
keywords=[
KeywordOut(
id=k.id,
theme_id=k.theme_id,
value=k.value,
is_regex=k.is_regex,
created_at=k.created_at.isoformat(),
)
for k in t.keywords
],
)
def _merge_cached_results(entries: list[dict], allowed_theme_names: set[str] | None = None) -> dict:
hits: list[dict] = []
total_rows = 0
cached_at: str | None = None
for entry in entries:
result = entry["result"]
total_rows += int(result.get("rows_scanned", 0) or 0)
if entry.get("built_at"):
if not cached_at or entry["built_at"] > cached_at:
cached_at = entry["built_at"]
for h in result.get("hits", []):
if allowed_theme_names is not None and h.get("theme_name") not in allowed_theme_names:
continue
hits.append(h)
return {
"total_hits": len(hits),
"hits": hits,
"rows_scanned": total_rows,
"cached_at": cached_at,
}
@router.get("/themes", response_model=ThemeListResponse)
async def list_themes(db: AsyncSession = Depends(get_db)):
result = await db.execute(select(KeywordTheme).order_by(KeywordTheme.name))
themes = result.scalars().all()
return ThemeListResponse(themes=[_theme_to_out(t) for t in themes], total=len(themes))
@router.post("/themes", response_model=ThemeOut, status_code=201)
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
exists = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == body.name))
if exists:
raise HTTPException(409, f"Theme '{body.name}' already exists")
theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
db.add(theme)
await db.flush()
await db.refresh(theme)
keyword_scan_cache.clear()
return _theme_to_out(theme)
@router.put("/themes/{theme_id}", response_model=ThemeOut)
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
theme = await db.get(KeywordTheme, theme_id)
if not theme:
raise HTTPException(404, "Theme not found")
if body.name is not None:
dup = await db.scalar(
select(KeywordTheme.id).where(KeywordTheme.name == body.name, KeywordTheme.id != theme_id)
)
if dup:
raise HTTPException(409, f"Theme '{body.name}' already exists")
theme.name = body.name
if body.color is not None:
theme.color = body.color
if body.enabled is not None:
theme.enabled = body.enabled
await db.flush()
await db.refresh(theme)
keyword_scan_cache.clear()
return _theme_to_out(theme)
@router.delete("/themes/{theme_id}", status_code=204)
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
theme = await db.get(KeywordTheme, theme_id)
if not theme:
raise HTTPException(404, "Theme not found")
await db.delete(theme)
keyword_scan_cache.clear()
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
theme = await db.get(KeywordTheme, theme_id)
if not theme:
raise HTTPException(404, "Theme not found")
kw = Keyword(theme_id=theme_id, value=body.value, is_regex=body.is_regex)
db.add(kw)
await db.flush()
await db.refresh(kw)
keyword_scan_cache.clear()
return KeywordOut(
id=kw.id, theme_id=kw.theme_id, value=kw.value,
is_regex=kw.is_regex, created_at=kw.created_at.isoformat(),
)
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
theme = await db.get(KeywordTheme, theme_id)
if not theme:
raise HTTPException(404, "Theme not found")
added = 0
for val in body.values:
val = val.strip()
if not val:
continue
db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex))
added += 1
await db.flush()
keyword_scan_cache.clear()
return {"added": added, "theme_id": theme_id}
@router.delete("/keywords/{keyword_id}", status_code=204)
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
kw = await db.get(Keyword, keyword_id)
if not kw:
raise HTTPException(404, "Keyword not found")
await db.delete(kw)
keyword_scan_cache.clear()
@router.post("/scan", response_model=ScanResponse)
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
scanner = KeywordScanner(db)
can_use_cache = (
body.prefer_cache
and not body.force_rescan
and bool(body.dataset_ids)
and not body.scan_hunts
and not body.scan_annotations
and not body.scan_messages
)
if can_use_cache:
themes = await scanner._load_themes(body.theme_ids)
allowed_theme_names = {t.name for t in themes}
keywords_scanned = sum(len(theme.keywords) for theme in themes)
cached_entries: list[dict] = []
missing: list[str] = []
for dataset_id in (body.dataset_ids or []):
entry = keyword_scan_cache.get(dataset_id)
if not entry:
missing.append(dataset_id)
continue
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
if not missing and cached_entries:
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": True,
"cache_status": "hit",
"cached_at": merged["cached_at"],
}
result = await scanner.scan(
dataset_ids=body.dataset_ids,
theme_ids=body.theme_ids,
scan_hunts=body.scan_hunts,
scan_annotations=body.scan_annotations,
scan_messages=body.scan_messages,
)
return {
**result,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}
@router.get("/scan/quick", response_model=ScanResponse)
async def quick_scan(
dataset_id: str = Query(..., description="Dataset to scan"),
db: AsyncSession = Depends(get_db),
):
entry = keyword_scan_cache.get(dataset_id)
if entry is not None:
result = entry.result
return {
**result,
"cache_used": True,
"cache_status": "hit",
"cached_at": entry.built_at,
}
scanner = KeywordScanner(db)
result = await scanner.scan(dataset_ids=[dataset_id])
keyword_scan_cache.put(dataset_id, result)
return {
**result,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}
'''
p.write_text(new_text,encoding='utf-8')
print('updated keywords.py')

31
_edit_main_reconcile.py Normal file
View File

@@ -0,0 +1,31 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/main.py')
t=p.read_text(encoding='utf-8')
old=''' # Start job queue
from app.services.job_queue import job_queue, register_all_handlers, JobType
register_all_handlers()
await job_queue.start()
logger.info("Job queue started (%d workers)", job_queue._max_workers)
'''
new=''' # Start job queue
from app.services.job_queue import (
job_queue,
register_all_handlers,
reconcile_stale_processing_tasks,
JobType,
)
if settings.STARTUP_RECONCILE_STALE_TASKS:
reconciled = await reconcile_stale_processing_tasks()
if reconciled:
logger.info("Startup reconciliation marked %d stale tasks", reconciled)
register_all_handlers()
await job_queue.start()
logger.info("Job queue started (%d workers)", job_queue._max_workers)
'''
if old not in t:
raise SystemExit('startup queue block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('wired startup reconciliation in main lifespan')

View File

@@ -0,0 +1,45 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/db/models.py')
t=p.read_text(encoding='utf-8')
if 'class ProcessingTask(Base):' in t:
print('processing task model already exists')
raise SystemExit(0)
insert='''
# -- Persistent Processing Tasks (Phase 2) ---
class ProcessingTask(Base):
__tablename__ = "processing_tasks"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
hunt_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True, index=True
)
dataset_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True, index=True
)
job_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
status: Mapped[str] = mapped_column(String(20), default="queued", index=True)
progress: Mapped[float] = mapped_column(Float, default=0.0)
message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
__table_args__ = (
Index("ix_processing_tasks_hunt_stage", "hunt_id", "stage"),
Index("ix_processing_tasks_dataset_stage", "dataset_id", "stage"),
)
'''
# insert before Playbook section
marker='\n\n# -- Playbook / Investigation Templates (Feature 3) ---\n'
if marker not in t:
raise SystemExit('marker not found for insertion')
t=t.replace(marker, insert+marker)
p.write_text(t,encoding='utf-8')
print('added ProcessingTask model')

59
_edit_networkmap_hit.py Normal file
View File

@@ -0,0 +1,59 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
t=p.read_text(encoding='utf-8')
insert='''
function isPointOnNodeLabel(node: GNode, wx: number, wy: number, vp: Viewport): boolean {
const fontSize = Math.max(9, Math.round(12 / vp.scale));
const approxCharW = Math.max(5, fontSize * 0.58);
const line1 = node.label || '';
const line2 = node.meta.ips.length > 0 ? node.meta.ips[0] : '';
const tw = Math.max(line1.length * approxCharW, line2 ? line2.length * approxCharW : 0);
const px = 5, py = 2;
const totalH = line2 ? fontSize * 2 + py * 2 : fontSize + py * 2;
const lx = node.x, ly = node.y - node.radius - 6;
const rx = lx - tw / 2 - px;
const ry = ly - totalH;
const rw = tw + px * 2;
const rh = totalH;
return wx >= rx && wx <= (rx + rw) && wy >= ry && wy <= (ry + rh);
}
'''
if 'function isPointOnNodeLabel' not in t:
t=t.replace('// == Hit-test =============================================================\n', '// == Hit-test =============================================================\n'+insert)
old='''function hitTest(
graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport,
): GNode | null {
const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp);
for (const n of graph.nodes) {
const dx = n.x - wx, dy = n.y - wy;
if (dx * dx + dy * dy < (n.radius + 5) ** 2) return n;
}
return null;
}
'''
new='''function hitTest(
graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport,
): GNode | null {
const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp);
// Node-circle hit has priority
for (const n of graph.nodes) {
const dx = n.x - wx, dy = n.y - wy;
if (dx * dx + dy * dy < (n.radius + 5) ** 2) return n;
}
// Then label hit (so clicking text works too)
for (const n of graph.nodes) {
if (isPointOnNodeLabel(n, wx, wy, vp)) return n;
}
return null;
}
'''
if old not in t:
raise SystemExit('hitTest block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated NetworkMap hit-test for labels')

272
_edit_scanner.py Normal file
View File

@@ -0,0 +1,272 @@
from pathlib import Path
p = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
text = p.read_text(encoding='utf-8')
new_text = '''"""AUP Keyword Scanner searches dataset rows, hunts, annotations, and
messages for keyword matches.
Scanning is done in Python (not SQL LIKE on JSON columns) for portability
across SQLite / PostgreSQL and to provide per-cell match context.
"""
import logging
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import (
KeywordTheme,
DatasetRow,
Dataset,
Hunt,
Annotation,
Message,
)
logger = logging.getLogger(__name__)
BATCH_SIZE = 200
@dataclass
class ScanHit:
theme_name: str
theme_color: str
keyword: str
source_type: str # dataset_row | hunt | annotation | message
source_id: str | int
field: str
matched_value: str
row_index: int | None = None
dataset_name: str | None = None
@dataclass
class ScanResult:
total_hits: int = 0
hits: list[ScanHit] = field(default_factory=list)
themes_scanned: int = 0
keywords_scanned: int = 0
rows_scanned: int = 0
@dataclass
class KeywordScanCacheEntry:
dataset_id: str
result: dict
built_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
class KeywordScanCache:
"""In-memory per-dataset cache for dataset-only keyword scans.
This enables fast-path reads when users run AUP scans against datasets that
were already scanned during upload pipeline processing.
"""
def __init__(self):
self._entries: dict[str, KeywordScanCacheEntry] = {}
def put(self, dataset_id: str, result: dict):
self._entries[dataset_id] = KeywordScanCacheEntry(dataset_id=dataset_id, result=result)
def get(self, dataset_id: str) -> KeywordScanCacheEntry | None:
return self._entries.get(dataset_id)
def invalidate_dataset(self, dataset_id: str):
self._entries.pop(dataset_id, None)
def clear(self):
self._entries.clear()
keyword_scan_cache = KeywordScanCache()
class KeywordScanner:
"""Scans multiple data sources for keyword/regex matches."""
def __init__(self, db: AsyncSession):
self.db = db
# Public API
async def scan(
self,
dataset_ids: list[str] | None = None,
theme_ids: list[str] | None = None,
scan_hunts: bool = False,
scan_annotations: bool = False,
scan_messages: bool = False,
) -> dict:
"""Run a full AUP scan and return dict matching ScanResponse."""
# Load themes + keywords
themes = await self._load_themes(theme_ids)
if not themes:
return ScanResult().__dict__
# Pre-compile patterns per theme
patterns = self._compile_patterns(themes)
result = ScanResult(
themes_scanned=len(themes),
keywords_scanned=sum(len(kws) for kws in patterns.values()),
)
# Scan dataset rows
await self._scan_datasets(patterns, result, dataset_ids)
# Scan hunts
if scan_hunts:
await self._scan_hunts(patterns, result)
# Scan annotations
if scan_annotations:
await self._scan_annotations(patterns, result)
# Scan messages
if scan_messages:
await self._scan_messages(patterns, result)
result.total_hits = len(result.hits)
return {
"total_hits": result.total_hits,
"hits": [h.__dict__ for h in result.hits],
"themes_scanned": result.themes_scanned,
"keywords_scanned": result.keywords_scanned,
"rows_scanned": result.rows_scanned,
}
# Internal
async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712
if theme_ids:
q = q.where(KeywordTheme.id.in_(theme_ids))
result = await self.db.execute(q)
return list(result.scalars().all())
def _compile_patterns(
self, themes: list[KeywordTheme]
) -> dict[tuple[str, str, str], list[tuple[str, re.Pattern]]]:
"""Returns {(theme_id, theme_name, theme_color): [(keyword_value, compiled_pattern), ...]}"""
patterns: dict[tuple[str, str, str], list[tuple[str, re.Pattern]]] = {}
for theme in themes:
key = (theme.id, theme.name, theme.color)
compiled = []
for kw in theme.keywords:
try:
if kw.is_regex:
pat = re.compile(kw.value, re.IGNORECASE)
else:
pat = re.compile(re.escape(kw.value), re.IGNORECASE)
compiled.append((kw.value, pat))
except re.error:
logger.warning("Invalid regex pattern '%s' in theme '%s', skipping",
kw.value, theme.name)
patterns[key] = compiled
return patterns
def _match_text(
self,
text: str,
patterns: dict,
source_type: str,
source_id: str | int,
field_name: str,
hits: list[ScanHit],
row_index: int | None = None,
dataset_name: str | None = None,
) -> None:
"""Check text against all compiled patterns, append hits."""
if not text:
return
for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
for kw_value, pat in keyword_patterns:
if pat.search(text):
matched_preview = text[:200] + ("" if len(text) > 200 else "")
hits.append(ScanHit(
theme_name=theme_name,
theme_color=theme_color,
keyword=kw_value,
source_type=source_type,
source_id=source_id,
field=field_name,
matched_value=matched_preview,
row_index=row_index,
dataset_name=dataset_name,
))
async def _scan_datasets(
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
) -> None:
"""Scan dataset rows in batches."""
ds_q = select(Dataset.id, Dataset.name)
if dataset_ids:
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
ds_result = await self.db.execute(ds_q)
ds_map = {r[0]: r[1] for r in ds_result.fetchall()}
if not ds_map:
return
offset = 0
row_q_base = select(DatasetRow).where(
DatasetRow.dataset_id.in_(list(ds_map.keys()))
).order_by(DatasetRow.id)
while True:
rows_result = await self.db.execute(
row_q_base.offset(offset).limit(BATCH_SIZE)
)
rows = rows_result.scalars().all()
if not rows:
break
for row in rows:
result.rows_scanned += 1
data = row.data or {}
for col_name, cell_value in data.items():
if cell_value is None:
continue
text = str(cell_value)
self._match_text(
text, patterns, "dataset_row", row.id,
col_name, result.hits,
row_index=row.row_index,
dataset_name=ds_map.get(row.dataset_id),
)
offset += BATCH_SIZE
import asyncio
await asyncio.sleep(0)
if len(rows) < BATCH_SIZE:
break
async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
"""Scan hunt names and descriptions."""
hunts_result = await self.db.execute(select(Hunt))
for hunt in hunts_result.scalars().all():
self._match_text(hunt.name, patterns, "hunt", hunt.id, "name", result.hits)
if hunt.description:
self._match_text(hunt.description, patterns, "hunt", hunt.id, "description", result.hits)
async def _scan_annotations(self, patterns: dict, result: ScanResult) -> None:
"""Scan annotation text."""
ann_result = await self.db.execute(select(Annotation))
for ann in ann_result.scalars().all():
self._match_text(ann.text, patterns, "annotation", ann.id, "text", result.hits)
async def _scan_messages(self, patterns: dict, result: ScanResult) -> None:
"""Scan conversation messages (user messages only)."""
msg_result = await self.db.execute(
select(Message).where(Message.role == "user")
)
for msg in msg_result.scalars().all():
self._match_text(msg.content, patterns, "message", msg.id, "content", result.hits)
'''
p.write_text(new_text, encoding='utf-8')
print('updated scanner.py')

31
_edit_test_api.py Normal file
View File

@@ -0,0 +1,31 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/tests/test_api.py')
t=p.read_text(encoding='utf-8')
insert='''
async def test_hunt_progress(self, client):
create = await client.post("/api/hunts", json={"name": "Progress Hunt"})
hunt_id = create.json()["id"]
# attach one dataset so progress has scope
from tests.conftest import SAMPLE_CSV
import io
files = {"file": ("progress.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
assert up.status_code == 200
res = await client.get(f"/api/hunts/{hunt_id}/progress")
assert res.status_code == 200
body = res.json()
assert body["hunt_id"] == hunt_id
assert "progress_percent" in body
assert "dataset_total" in body
assert "network_status" in body
'''
needle=''' async def test_get_nonexistent_hunt(self, client):
resp = await client.get("/api/hunts/nonexistent-id")
assert resp.status_code == 404
'''
if needle in t and 'test_hunt_progress' not in t:
t=t.replace(needle, needle+'\n'+insert)
p.write_text(t,encoding='utf-8')
print('updated test_api.py')

32
_edit_test_keywords.py Normal file
View File

@@ -0,0 +1,32 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/tests/test_keywords.py')
t=p.read_text(encoding='utf-8')
add='''
@pytest.mark.asyncio
async def test_quick_scan_cache_hit(client: AsyncClient):
"""Second quick scan should return cache hit metadata."""
theme_res = await client.post("/api/keywords/themes", json={"name": "Quick Cache Theme", "color": "#00aa00"})
tid = theme_res.json()["id"]
await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "chrome.exe"})
from tests.conftest import SAMPLE_CSV
import io
files = {"file": ("cache_quick.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
upload = await client.post("/api/datasets/upload", files=files)
ds_id = upload.json()["id"]
first = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
assert first.status_code == 200
assert first.json().get("cache_status") in ("miss", "hit")
second = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
assert second.status_code == 200
body = second.json()
assert body.get("cache_used") is True
assert body.get("cache_status") == "hit"
'''
if 'test_quick_scan_cache_hit' not in t:
t=t + add
p.write_text(t,encoding='utf-8')
print('updated test_keywords.py')

26
_edit_upload.py Normal file
View File

@@ -0,0 +1,26 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/FileUpload.tsx')
t=p.read_text(encoding='utf-8')
# import useEffect
t=t.replace("import React, { useState, useCallback, useRef } from 'react';","import React, { useState, useCallback, useRef, useEffect } from 'react';")
# import HuntProgress type
t=t.replace("import { datasets, hunts, type UploadResult, type Hunt } from '../api/client';","import { datasets, hunts, type UploadResult, type Hunt, type HuntProgress } from '../api/client';")
# add state
if 'const [huntProgress, setHuntProgress]' not in t:
t=t.replace(" const [huntList, setHuntList] = useState<Hunt[]>([]);\n const [huntId, setHuntId] = useState('');"," const [huntList, setHuntList] = useState<Hunt[]>([]);\n const [huntId, setHuntId] = useState('');\n const [huntProgress, setHuntProgress] = useState<HuntProgress | null>(null);")
# add polling effect after hunts list effect
marker=" React.useEffect(() => {\n hunts.list(0, 100).then(r => setHuntList(r.hunts)).catch(() => {});\n }, []);\n"
if marker in t and 'setInterval' not in t.split(marker,1)[1][:500]:
add='''\n useEffect(() => {\n let timer: any = null;\n let cancelled = false;\n\n const pull = async () => {\n if (!huntId) {\n if (!cancelled) setHuntProgress(null);\n return;\n }\n try {\n const p = await hunts.progress(huntId);\n if (!cancelled) setHuntProgress(p);\n } catch {\n if (!cancelled) setHuntProgress(null);\n }\n };\n\n pull();\n if (huntId) timer = setInterval(pull, 2000);\n return () => { cancelled = true; if (timer) clearInterval(timer); };\n }, [huntId, jobs.length]);\n'''
t=t.replace(marker, marker+add)
# insert master progress UI after overall summary
insert_after=''' {overallTotal > 0 && (\n <Stack direction="row" alignItems="center" spacing={1} sx={{ mt: 2 }}>\n <Typography variant="body2" color="text.secondary">\n {overallDone + overallErr} / {overallTotal} files processed\n {overallErr > 0 && ` ({overallErr} failed)`}\n </Typography>\n <Box sx={{ flexGrow: 1 }} />\n {overallDone + overallErr === overallTotal && overallTotal > 0 && (\n <Tooltip title="Clear completed">\n <IconButton size="small" onClick={clearCompleted}><ClearIcon fontSize="small" /></IconButton>\n </Tooltip>\n )}\n </Stack>\n )}\n'''
add_block='''\n {huntId && huntProgress && (\n <Paper sx={{ p: 1.5, mt: 1.5 }}>\n <Stack direction="row" alignItems="center" spacing={1} sx={{ mb: 0.8 }}>\n <Typography variant="body2" sx={{ fontWeight: 600 }}>\n Master Processing Progress\n </Typography>\n <Chip\n size="small"\n label={huntProgress.status.toUpperCase()}\n color={huntProgress.status === 'ready' ? 'success' : huntProgress.status === 'processing' ? 'warning' : 'default'}\n variant="outlined"\n />\n <Box sx={{ flexGrow: 1 }} />\n <Typography variant="caption" color="text.secondary">\n {huntProgress.progress_percent.toFixed(1)}%\n </Typography>\n </Stack>\n <LinearProgress\n variant="determinate"\n value={Math.max(0, Math.min(100, huntProgress.progress_percent))}\n sx={{ height: 8, borderRadius: 4 }}\n />\n <Stack direction="row" spacing={1} sx={{ mt: 1 }} flexWrap="wrap" useFlexGap>\n <Chip size="small" label={`Datasets ${huntProgress.dataset_completed}/${huntProgress.dataset_total}`} variant="outlined" />\n <Chip size="small" label={`Active jobs ${huntProgress.active_jobs}`} variant="outlined" />\n <Chip size="small" label={`Queued jobs ${huntProgress.queued_jobs}`} variant="outlined" />\n <Chip size="small" label={`Network ${huntProgress.network_status}`} variant="outlined" />\n </Stack>\n </Paper>\n )}\n'''
if insert_after in t:
t=t.replace(insert_after, insert_after+add_block)
else:
print('warning: summary block not found')
p.write_text(t,encoding='utf-8')
print('updated FileUpload.tsx')

42
_edit_upload2.py Normal file
View File

@@ -0,0 +1,42 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/FileUpload.tsx')
t=p.read_text(encoding='utf-8')
marker=''' {/* Per-file progress list */}
'''
add=''' {huntId && huntProgress && (
<Paper sx={{ p: 1.5, mt: 1.5 }}>
<Stack direction="row" alignItems="center" spacing={1} sx={{ mb: 0.8 }}>
<Typography variant="body2" sx={{ fontWeight: 600 }}>
Master Processing Progress
</Typography>
<Chip
size="small"
label={huntProgress.status.toUpperCase()}
color={huntProgress.status === 'ready' ? 'success' : huntProgress.status === 'processing' ? 'warning' : 'default'}
variant="outlined"
/>
<Box sx={{ flexGrow: 1 }} />
<Typography variant="caption" color="text.secondary">
{huntProgress.progress_percent.toFixed(1)}%
</Typography>
</Stack>
<LinearProgress
variant="determinate"
value={Math.max(0, Math.min(100, huntProgress.progress_percent))}
sx={{ height: 8, borderRadius: 4 }}
/>
<Stack direction="row" spacing={1} sx={{ mt: 1 }} flexWrap="wrap" useFlexGap>
<Chip size="small" label={`Datasets ${huntProgress.dataset_completed}/${huntProgress.dataset_total}`} variant="outlined" />
<Chip size="small" label={`Active jobs ${huntProgress.active_jobs}`} variant="outlined" />
<Chip size="small" label={`Queued jobs ${huntProgress.queued_jobs}`} variant="outlined" />
<Chip size="small" label={`Network ${huntProgress.network_status}`} variant="outlined" />
</Stack>
</Paper>
)}
'''
if marker not in t:
raise SystemExit('marker not found')
t=t.replace(marker, add+marker)
p.write_text(t,encoding='utf-8')
print('inserted master progress block')

View File

@@ -0,0 +1,55 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
t=p.read_text(encoding='utf-8')
if 'from app.config import settings' not in t:
t=t.replace('from sqlalchemy.ext.asyncio import AsyncSession\n','from sqlalchemy.ext.asyncio import AsyncSession\n\nfrom app.config import settings\n')
old=''' import asyncio
for ds_id, ds_name in ds_map.items():
last_id = 0
while True:
'''
new=''' import asyncio
max_rows = max(0, int(settings.SCANNER_MAX_ROWS_PER_SCAN))
budget_reached = False
for ds_id, ds_name in ds_map.items():
if max_rows and result.rows_scanned >= max_rows:
budget_reached = True
break
last_id = 0
while True:
if max_rows and result.rows_scanned >= max_rows:
budget_reached = True
break
'''
if old not in t:
raise SystemExit('scanner loop block not found')
t=t.replace(old,new)
old2=''' if len(rows) < BATCH_SIZE:
break
'''
new2=''' if len(rows) < BATCH_SIZE:
break
if budget_reached:
break
if budget_reached:
logger.warning(
"AUP scan row budget reached (%d rows). Returning partial results.",
result.rows_scanned,
)
'''
if old2 not in t:
raise SystemExit('scanner break block not found')
t=t.replace(old2,new2,1)
p.write_text(t,encoding='utf-8')
print('added scanner global row budget enforcement')

12
_fix_aup_dep.py Normal file
View File

@@ -0,0 +1,12 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
t=p.read_text(encoding='utf-8')
old=''' }, [selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);
'''
new=''' }, [selectedHuntId, selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);
'''
if old not in t:
raise SystemExit('runScan deps block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('fixed AUPScanner runScan dependency list')

7
_fix_import_datasets.py Normal file
View File

@@ -0,0 +1,7 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py')
t=p.read_text(encoding='utf-8')
if 'from app.db.models import ProcessingTask' not in t:
t=t.replace('from app.db import get_db\n', 'from app.db import get_db\nfrom app.db.models import ProcessingTask\n')
p.write_text(t, encoding='utf-8')
print('added ProcessingTask import')

View File

@@ -0,0 +1,25 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
t=p.read_text(encoding='utf-8')
old=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
raise HTTPException(400, "Select at least one dataset or enable additional sources (hunts/annotations/messages)")
'''
new=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
return {
"total_hits": 0,
"hits": [],
"themes_scanned": 0,
"keywords_scanned": 0,
"rows_scanned": 0,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}
'''
if old not in t:
raise SystemExit('scope guard block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('adjusted empty scan guard to return fast empty result (200)')

View File

@@ -0,0 +1,47 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
t=p.read_text(encoding='utf-8')
# Add label selector in toolbar before refresh button
insert_after=""" <TextField
size=\"small\"
placeholder=\"Search hosts, IPs, users\\u2026\"
value={search}
onChange={e => setSearch(e.target.value)}
sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }}
slotProps={{
input: {
startAdornment: <SearchIcon sx={{ mr: 0.5, fontSize: 18, color: 'text.secondary' }} />,
},
}}
/>
"""
label_ctrl="""
<FormControl size=\"small\" sx={{ minWidth: 150 }}>
<InputLabel id=\"label-mode-selector\">Labels</InputLabel>
<Select
labelId=\"label-mode-selector\"
value={labelMode}
label=\"Labels\"
onChange={e => setLabelMode(e.target.value as LabelMode)}
sx={{ '& .MuiSelect-select': { py: 0.8 } }}
>
<MenuItem value=\"none\">None</MenuItem>
<MenuItem value=\"highlight\">Selected/Search</MenuItem>
<MenuItem value=\"all\">All</MenuItem>
</Select>
</FormControl>
"""
if 'label-mode-selector' not in t:
if insert_after not in t:
raise SystemExit('search block not found for label selector insertion')
t=t.replace(insert_after, insert_after+label_ctrl)
# Fix useCallback dependency for startAnimLoop
old=' }, [canvasSize]);'
new=' }, [canvasSize, labelMode]);'
if old in t:
t=t.replace(old,new,1)
p.write_text(t,encoding='utf-8')
print('inserted label selector UI and fixed callback dependency')

View File

@@ -0,0 +1,10 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
t=p.read_text(encoding='utf-8')
count=t.count('}, [canvasSize]);')
if count:
t=t.replace('}, [canvasSize]);','}, [canvasSize, labelMode]);')
# In case formatter created spaced variant
t=t.replace('}, [canvasSize ]);','}, [canvasSize, labelMode]);')
p.write_text(t,encoding='utf-8')
print('patched remaining canvasSize callback deps:', count)

71
_harden_aup_scope_ui.py Normal file
View File

@@ -0,0 +1,71 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
t=p.read_text(encoding='utf-8')
# Auto-select first hunt with datasets after load
old=''' const [tRes, hRes] = await Promise.all([
keywords.listThemes(),
hunts.list(0, 200),
]);
setThemes(tRes.themes);
setHuntList(hRes.hunts);
'''
new=''' const [tRes, hRes] = await Promise.all([
keywords.listThemes(),
hunts.list(0, 200),
]);
setThemes(tRes.themes);
setHuntList(hRes.hunts);
if (!selectedHuntId && hRes.hunts.length > 0) {
const best = hRes.hunts.find(h => h.dataset_count > 0) || hRes.hunts[0];
setSelectedHuntId(best.id);
}
'''
if old not in t:
raise SystemExit('loadData block not found')
t=t.replace(old,new)
# Guard runScan
old2=''' const runScan = useCallback(async () => {
setScanning(true);
setScanResult(null);
try {
'''
new2=''' const runScan = useCallback(async () => {
if (!selectedHuntId) {
enqueueSnackbar('Please select a hunt before running AUP scan', { variant: 'warning' });
return;
}
if (selectedDs.size === 0) {
enqueueSnackbar('No datasets selected for this hunt', { variant: 'warning' });
return;
}
setScanning(true);
setScanResult(null);
try {
'''
if old2 not in t:
raise SystemExit('runScan header not found')
t=t.replace(old2,new2)
# update loadData deps
old3=''' }, [enqueueSnackbar]);
'''
new3=''' }, [enqueueSnackbar, selectedHuntId]);
'''
if old3 not in t:
raise SystemExit('loadData deps not found')
t=t.replace(old3,new3,1)
# disable button if no hunt or no datasets
old4=''' onClick={runScan} disabled={scanning}
'''
new4=''' onClick={runScan} disabled={scanning || !selectedHuntId || selectedDs.size === 0}
'''
if old4 not in t:
raise SystemExit('scan button props not found')
t=t.replace(old4,new4)
p.write_text(t,encoding='utf-8')
print('hardened AUPScanner to require explicit hunt/dataset scope')

View File

@@ -0,0 +1,84 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
t=p.read_text(encoding='utf-8')
old=''' if can_use_cache:
themes = await scanner._load_themes(body.theme_ids)
allowed_theme_names = {t.name for t in themes}
keywords_scanned = sum(len(theme.keywords) for theme in themes)
cached_entries: list[dict] = []
missing: list[str] = []
for dataset_id in (body.dataset_ids or []):
entry = keyword_scan_cache.get(dataset_id)
if not entry:
missing.append(dataset_id)
continue
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
if not missing and cached_entries:
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": True,
"cache_status": "hit",
"cached_at": merged["cached_at"],
}
'''
new=''' if can_use_cache:
themes = await scanner._load_themes(body.theme_ids)
allowed_theme_names = {t.name for t in themes}
keywords_scanned = sum(len(theme.keywords) for theme in themes)
cached_entries: list[dict] = []
missing: list[str] = []
for dataset_id in (body.dataset_ids or []):
entry = keyword_scan_cache.get(dataset_id)
if not entry:
missing.append(dataset_id)
continue
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
if not missing and cached_entries:
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": True,
"cache_status": "hit",
"cached_at": merged["cached_at"],
}
if missing:
missing_entries: list[dict] = []
for dataset_id in missing:
partial = await scanner.scan(dataset_ids=[dataset_id], theme_ids=body.theme_ids)
keyword_scan_cache.put(dataset_id, partial)
missing_entries.append({"result": partial, "built_at": None})
merged = _merge_cached_results(
cached_entries + missing_entries,
allowed_theme_names if body.theme_ids else None,
)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": len(cached_entries) > 0,
"cache_status": "partial" if cached_entries else "miss",
"cached_at": merged["cached_at"],
}
'''
if old not in t:
raise SystemExit('cache block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated keyword /scan to use partial cache + scan missing datasets only')

View File

@@ -0,0 +1,61 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
t=p.read_text(encoding='utf-8')
start=t.index(' async def _scan_datasets(')
end=t.index(' async def _scan_hunts', start)
new_func=''' async def _scan_datasets(
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
) -> None:
"""Scan dataset rows in batches using keyset pagination (no OFFSET)."""
ds_q = select(Dataset.id, Dataset.name)
if dataset_ids:
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
ds_result = await self.db.execute(ds_q)
ds_map = {r[0]: r[1] for r in ds_result.fetchall()}
if not ds_map:
return
import asyncio
for ds_id, ds_name in ds_map.items():
last_id = 0
while True:
rows_result = await self.db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == ds_id)
.where(DatasetRow.id > last_id)
.order_by(DatasetRow.id)
.limit(BATCH_SIZE)
)
rows = rows_result.scalars().all()
if not rows:
break
for row in rows:
result.rows_scanned += 1
data = row.data or {}
for col_name, cell_value in data.items():
if cell_value is None:
continue
text = str(cell_value)
self._match_text(
text,
patterns,
"dataset_row",
row.id,
col_name,
result.hits,
row_index=row.row_index,
dataset_name=ds_name,
)
last_id = rows[-1].id
await asyncio.sleep(0)
if len(rows) < BATCH_SIZE:
break
'''
out=t[:start]+new_func+t[end:]
p.write_text(out,encoding='utf-8')
print('optimized scanner _scan_datasets to keyset pagination')

36
_patch_inventory_stats.py Normal file
View File

@@ -0,0 +1,36 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
t=p.read_text(encoding='utf-8')
old=''' return {
"hosts": host_list,
"connections": conn_list,
"stats": {
"total_hosts": len(host_list),
"total_datasets_scanned": len(all_datasets),
"datasets_with_hosts": ds_with_hosts,
"total_rows_scanned": total_rows,
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
"hosts_with_users": sum(1 for h in host_list if h['users']),
},
}
'''
new=''' return {
"hosts": host_list,
"connections": conn_list,
"stats": {
"total_hosts": len(host_list),
"total_datasets_scanned": len(all_datasets),
"datasets_with_hosts": ds_with_hosts,
"total_rows_scanned": total_rows,
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
"hosts_with_users": sum(1 for h in host_list if h['users']),
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0,
},
}
'''
if old not in t:
raise SystemExit('return block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('patched inventory stats metadata')

View File

@@ -0,0 +1,10 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
t=p.read_text(encoding='utf-8')
needle=' "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n'
if '"row_budget_per_dataset"' not in t:
if needle not in t:
raise SystemExit('needle not found')
t=t.replace(needle, needle + ' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,\n "sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0,\n')
p.write_text(t,encoding='utf-8')
print('inserted inventory budget stats lines')

14
_patch_network_sleep.py Normal file
View File

@@ -0,0 +1,14 @@
from pathlib import Path
p = Path(r"d:\Projects\Dev\ThreatHunt\frontend\src\components\NetworkMap.tsx")
text = p.read_text(encoding="utf-8")
anchor = " useEffect(() => { canvasSizeRef.current = canvasSize; }, [canvasSize]);\n"
insert = anchor + "\n const sleep = (ms: number) => new Promise<void>(resolve => setTimeout(resolve, ms));\n"
if "const sleep = (ms: number)" not in text and anchor in text:
text = text.replace(anchor, insert)
text = text.replace("await new Promise(r => setTimeout(r, delayMs + jitter));", "await sleep(delayMs + jitter);")
p.write_text(text, encoding="utf-8")
print("Patched sleep helper + polling awaits")

37
_patch_network_wait.py Normal file
View File

@@ -0,0 +1,37 @@
from pathlib import Path
import re
p = Path(r"d:\Projects\Dev\ThreatHunt\frontend\src\components\NetworkMap.tsx")
text = p.read_text(encoding="utf-8")
pattern = re.compile(r"const waitUntilReady = async \(\): Promise<boolean> => \{[\s\S]*?\n\s*\};", re.M)
replacement = '''const waitUntilReady = async (): Promise<boolean> => {
// Poll inventory-status with exponential backoff until 'ready' (or cancelled)
setProgress('Host inventory is being prepared in the background');
setLoading(true);
let delayMs = 1500;
const startedAt = Date.now();
for (;;) {
const jitter = Math.floor(Math.random() * 250);
await new Promise(r => setTimeout(r, delayMs + jitter));
if (cancelled) return false;
try {
const st = await network.inventoryStatus(selectedHuntId);
if (cancelled) return false;
if (st.status === 'ready') return true;
if (Date.now() - startedAt > 5 * 60 * 1000) {
setError('Host inventory build timed out. Please retry.');
return false;
}
delayMs = Math.min(10000, Math.floor(delayMs * 1.5));
// still building or none (job may not have started yet) - keep polling
} catch {
if (cancelled) return false;
delayMs = Math.min(10000, Math.floor(delayMs * 1.5));
}
}
};'''
new_text, n = pattern.subn(replacement, text, count=1)
if n != 1:
raise SystemExit(f"Failed to patch waitUntilReady, matches={n}")
p.write_text(new_text, encoding="utf-8")
print("Patched waitUntilReady")

View File

@@ -0,0 +1,26 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
t=p.read_text(encoding='utf-8')
old=''' NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
default=25000,
description="Row budget per dataset when building host inventory (0 = unlimited)",
)
'''
new=''' NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
default=5000,
description="Row budget per dataset when building host inventory (0 = unlimited)",
)
NETWORK_INVENTORY_MAX_TOTAL_ROWS: int = Field(
default=120000,
description="Global row budget across all datasets for host inventory build (0 = unlimited)",
)
NETWORK_INVENTORY_MAX_CONNECTIONS: int = Field(
default=120000,
description="Max unique connection tuples retained during host inventory build",
)
'''
if old not in t:
raise SystemExit('network inventory block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated network inventory budgets in config')

View File

@@ -0,0 +1,164 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
t=p.read_text(encoding='utf-8')
# insert budget vars near existing counters
old=''' connections: dict[tuple, int] = defaultdict(int)
total_rows = 0
ds_with_hosts = 0
'''
new=''' connections: dict[tuple, int] = defaultdict(int)
total_rows = 0
ds_with_hosts = 0
sampled_dataset_count = 0
total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
global_budget_reached = False
dropped_connections = 0
'''
if old not in t:
raise SystemExit('counter block not found')
t=t.replace(old,new)
# update batch size and sampled count increments + global budget checks
old2=''' batch_size = 10000
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
rows_scanned_this_dataset = 0
sampled_dataset = False
last_row_index = -1
while True:
'''
new2=''' batch_size = 5000
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
rows_scanned_this_dataset = 0
sampled_dataset = False
last_row_index = -1
while True:
if total_row_budget and total_rows >= total_row_budget:
global_budget_reached = True
break
'''
if old2 not in t:
raise SystemExit('batch block not found')
t=t.replace(old2,new2)
old3=''' if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
sampled_dataset = True
break
data = ro.data or {}
total_rows += 1
rows_scanned_this_dataset += 1
'''
new3=''' if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
sampled_dataset = True
break
if total_row_budget and total_rows >= total_row_budget:
sampled_dataset = True
global_budget_reached = True
break
data = ro.data or {}
total_rows += 1
rows_scanned_this_dataset += 1
'''
if old3 not in t:
raise SystemExit('row scan block not found')
t=t.replace(old3,new3)
# cap connection map growth
old4=''' for c in cols['remote_ip']:
rip = _clean(data.get(c))
if _is_valid_ip(rip):
rport = ''
for pc in cols['remote_port']:
rport = _clean(data.get(pc))
if rport:
break
connections[(host_key, rip, rport)] += 1
'''
new4=''' for c in cols['remote_ip']:
rip = _clean(data.get(c))
if _is_valid_ip(rip):
rport = ''
for pc in cols['remote_port']:
rport = _clean(data.get(pc))
if rport:
break
conn_key = (host_key, rip, rport)
if max_connections and len(connections) >= max_connections and conn_key not in connections:
dropped_connections += 1
continue
connections[conn_key] += 1
'''
if old4 not in t:
raise SystemExit('connection block not found')
t=t.replace(old4,new4)
# sampled_dataset counter
old5=''' if sampled_dataset:
logger.info(
"Host inventory row budget reached for dataset %s (%d rows)",
ds.id,
rows_scanned_this_dataset,
)
break
'''
new5=''' if sampled_dataset:
sampled_dataset_count += 1
logger.info(
"Host inventory row budget reached for dataset %s (%d rows)",
ds.id,
rows_scanned_this_dataset,
)
break
'''
if old5 not in t:
raise SystemExit('sampled block not found')
t=t.replace(old5,new5)
# break dataset loop if global budget reached
old6=''' if len(rows) < batch_size:
break
# Post-process hosts
'''
new6=''' if len(rows) < batch_size:
break
if global_budget_reached:
logger.info(
"Host inventory global row budget reached for hunt %s at %d rows",
hunt_id,
total_rows,
)
break
# Post-process hosts
'''
if old6 not in t:
raise SystemExit('post-process boundary block not found')
t=t.replace(old6,new6)
# add stats
old7=''' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0,
},
}
'''
new7=''' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
"row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
"connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
"sampled_datasets": sampled_dataset_count,
"global_budget_reached": global_budget_reached,
"dropped_connections": dropped_connections,
},
}
'''
if old7 not in t:
raise SystemExit('stats block not found')
t=t.replace(old7,new7)
p.write_text(t,encoding='utf-8')
print('updated host inventory with global row and connection budgets')

View File

@@ -0,0 +1,39 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
t=p.read_text(encoding='utf-8')
repls={
"const LARGE_HUNT_SUBGRAPH_HOSTS = 350;":"const LARGE_HUNT_SUBGRAPH_HOSTS = 220;",
"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;":"const LARGE_HUNT_SUBGRAPH_EDGES = 1200;",
"const RENDER_SIMPLIFY_NODE_THRESHOLD = 220;":"const RENDER_SIMPLIFY_NODE_THRESHOLD = 120;",
"const RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;":"const RENDER_SIMPLIFY_EDGE_THRESHOLD = 500;",
"const EDGE_DRAW_TARGET = 1000;":"const EDGE_DRAW_TARGET = 600;"
}
for a,b in repls.items():
if a not in t:
raise SystemExit(f'missing constant: {a}')
t=t.replace(a,b)
old=''' // Then label hit (so clicking text works too)
for (const n of graph.nodes) {
if (isPointOnNodeLabel(n, wx, wy, vp)) return n;
}
'''
new=''' // Then label hit (so clicking text works too on manageable graph sizes)
if (graph.nodes.length <= 220) {
for (const n of graph.nodes) {
if (isPointOnNodeLabel(n, wx, wy, vp)) return n;
}
}
'''
if old not in t:
raise SystemExit('label hit block not found')
t=t.replace(old,new)
old2='simulate(g, w / 2, h / 2, 60);'
if t.count(old2) < 2:
raise SystemExit('expected two simulate calls')
t=t.replace(old2,'simulate(g, w / 2, h / 2, 20);',1)
t=t.replace(old2,'simulate(g, w / 2, h / 2, 30);',1)
p.write_text(t,encoding='utf-8')
print('tightened network map rendering + load limits')

107
_perf_patch_backend.py Normal file
View File

@@ -0,0 +1,107 @@
from pathlib import Path
# config updates
cfg=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
t=cfg.read_text(encoding='utf-8')
anchor=''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
)
'''
ins=''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
)
NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
default=200000,
description="Row budget per dataset when building host inventory (0 = unlimited)",
)
'''
if 'NETWORK_INVENTORY_MAX_ROWS_PER_DATASET' not in t:
if anchor not in t:
raise SystemExit('config network anchor not found')
t=t.replace(anchor,ins)
cfg.write_text(t,encoding='utf-8')
# host inventory updates
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
t=p.read_text(encoding='utf-8')
if 'from app.config import settings' not in t:
t=t.replace('from app.db.models import Dataset, DatasetRow\n', 'from app.db.models import Dataset, DatasetRow\nfrom app.config import settings\n')
t=t.replace(' batch_size = 5000\n last_row_index = -1\n while True:\n', ' batch_size = 10000\n max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))\n rows_scanned_this_dataset = 0\n sampled_dataset = False\n last_row_index = -1\n while True:\n')
old=''' for ro in rows:
data = ro.data or {}
total_rows += 1
fqdn = ''
'''
new=''' for ro in rows:
if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
sampled_dataset = True
break
data = ro.data or {}
total_rows += 1
rows_scanned_this_dataset += 1
fqdn = ''
'''
if old not in t:
raise SystemExit('row loop anchor not found')
t=t.replace(old,new)
old2=''' last_row_index = rows[-1].row_index
if len(rows) < batch_size:
break
'''
new2=''' if sampled_dataset:
logger.info(
"Host inventory row budget reached for dataset %s (%d rows)",
ds.id,
rows_scanned_this_dataset,
)
break
last_row_index = rows[-1].row_index
if len(rows) < batch_size:
break
'''
if old2 not in t:
raise SystemExit('batch loop end anchor not found')
t=t.replace(old2,new2)
old3=''' return {
"hosts": host_list,
"connections": conn_list,
"stats": {
"total_hosts": len(host_list),
"total_datasets_scanned": len(all_datasets),
"datasets_with_hosts": ds_with_hosts,
"total_rows_scanned": total_rows,
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
"hosts_with_users": sum(1 for h in host_list if h['users']),
},
}
'''
new3=''' sampled = settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0
return {
"hosts": host_list,
"connections": conn_list,
"stats": {
"total_hosts": len(host_list),
"total_datasets_scanned": len(all_datasets),
"datasets_with_hosts": ds_with_hosts,
"total_rows_scanned": total_rows,
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
"hosts_with_users": sum(1 for h in host_list if h['users']),
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
"sampled_mode": sampled,
},
}
'''
if old3 not in t:
raise SystemExit('return stats anchor not found')
t=t.replace(old3,new3)
p.write_text(t,encoding='utf-8')
print('patched config + host inventory row budget')

38
_perf_patch_backend2.py Normal file
View File

@@ -0,0 +1,38 @@
from pathlib import Path
cfg=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
t=cfg.read_text(encoding='utf-8')
if 'NETWORK_INVENTORY_MAX_ROWS_PER_DATASET' not in t:
t=t.replace(
''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
)
''',
''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
)
NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
default=200000,
description="Row budget per dataset when building host inventory (0 = unlimited)",
)
''')
cfg.write_text(t,encoding='utf-8')
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
t=p.read_text(encoding='utf-8')
if 'from app.config import settings' not in t:
t=t.replace('from app.db.models import Dataset, DatasetRow\n','from app.db.models import Dataset, DatasetRow\nfrom app.config import settings\n')
t=t.replace(' batch_size = 5000\n last_row_index = -1\n while True:\n',
' batch_size = 10000\n max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))\n rows_scanned_this_dataset = 0\n sampled_dataset = False\n last_row_index = -1\n while True:\n')
t=t.replace(' for ro in rows:\n data = ro.data or {}\n total_rows += 1\n\n',
' for ro in rows:\n if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:\n sampled_dataset = True\n break\n\n data = ro.data or {}\n total_rows += 1\n rows_scanned_this_dataset += 1\n\n')
t=t.replace(' last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n',
' if sampled_dataset:\n logger.info(\n "Host inventory row budget reached for dataset %s (%d rows)",\n ds.id,\n rows_scanned_this_dataset,\n )\n break\n\n last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n')
t=t.replace(' return {\n "hosts": host_list,\n "connections": conn_list,\n "stats": {\n "total_hosts": len(host_list),\n "total_datasets_scanned": len(all_datasets),\n "datasets_with_hosts": ds_with_hosts,\n "total_rows_scanned": total_rows,\n "hosts_with_ips": sum(1 for h in host_list if h[\'ips\']),\n "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n },\n }\n',
' sampled = settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0\n\n return {\n "hosts": host_list,\n "connections": conn_list,\n "stats": {\n "total_hosts": len(host_list),\n "total_datasets_scanned": len(all_datasets),\n "datasets_with_hosts": ds_with_hosts,\n "total_rows_scanned": total_rows,\n "hosts_with_ips": sum(1 for h in host_list if h[\'ips\']),\n "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,\n "sampled_mode": sampled,\n },\n }\n')
p.write_text(t,encoding='utf-8')
print('patched backend inventory performance settings')

220
_perf_patch_networkmap.py Normal file
View File

@@ -0,0 +1,220 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
t=p.read_text(encoding='utf-8')
# constants
if 'RENDER_SIMPLIFY_NODE_THRESHOLD' not in t:
t=t.replace(
"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n",
"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\nconst RENDER_SIMPLIFY_NODE_THRESHOLD = 220;\nconst RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;\nconst EDGE_DRAW_TARGET = 1000;\n")
# drawBackground signature
t_old='''function drawBackground(
ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,
) {
'''
if t_old in t:
t=t.replace(t_old,
'''function drawBackground(
ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,
simplify: boolean,
) {
''')
# skip grid when simplify
if 'if (!simplify) {' not in t:
t=t.replace(
''' ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
const startX = -vp.x / vp.scale - GRID_SPACING;
const startY = -vp.y / vp.scale - GRID_SPACING;
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
ctx.fillStyle = GRID_DOT_COLOR;
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
}
}
ctx.restore();
''',
''' if (!simplify) {
ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
const startX = -vp.x / vp.scale - GRID_SPACING;
const startY = -vp.y / vp.scale - GRID_SPACING;
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
ctx.fillStyle = GRID_DOT_COLOR;
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
}
}
ctx.restore();
}
''')
# drawEdges signature
t=t.replace('''function drawEdges(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
nodeMap: Map<string, GNode>, animTime: number,
) {
for (const e of graph.edges) {
''',
'''function drawEdges(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
nodeMap: Map<string, GNode>, animTime: number,
simplify: boolean,
) {
const edgeStep = simplify ? Math.max(1, Math.ceil(graph.edges.length / EDGE_DRAW_TARGET)) : 1;
for (let ei = 0; ei < graph.edges.length; ei += edgeStep) {
const e = graph.edges[ei];
''')
# simplify edge path
t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);',
'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }')
t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);',
'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }')
# reduce glow when simplify
t=t.replace(''' ctx.save();
ctx.shadowColor = 'rgba(96,165,250,0.5)'; ctx.shadowBlur = 8;
ctx.strokeStyle = 'rgba(96,165,250,0.3)';
ctx.lineWidth = Math.min(5, 2 + e.weight * 0.2);
ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }
ctx.stroke(); ctx.restore();
''',
''' if (!simplify) {
ctx.save();
ctx.shadowColor = 'rgba(96,165,250,0.5)'; ctx.shadowBlur = 8;
ctx.strokeStyle = 'rgba(96,165,250,0.3)';
ctx.lineWidth = Math.min(5, 2 + e.weight * 0.2);
ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }
ctx.stroke(); ctx.restore();
}
''')
# drawLabels signature and early return
t=t.replace('''function drawLabels(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
search: string, matchSet: Set<string>, vp: Viewport,
) {
''',
'''function drawLabels(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
search: string, matchSet: Set<string>, vp: Viewport,
simplify: boolean,
) {
''')
if 'if (simplify && !search && !hovered && !selected) {' not in t:
t=t.replace(' const dimmed = search.length > 0;\n',
' const dimmed = search.length > 0;\n if (simplify && !search && !hovered && !selected) {\n return;\n }\n')
# drawGraph adapt
t=t.replace(''' drawBackground(ctx, w, h, vp, dpr);
ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime);
drawNodes(ctx, graph, hovered, selected, search, matchSet);
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp);
ctx.restore();
''',
''' const simplify = graph.nodes.length > RENDER_SIMPLIFY_NODE_THRESHOLD || graph.edges.length > RENDER_SIMPLIFY_EDGE_THRESHOLD;
drawBackground(ctx, w, h, vp, dpr, simplify);
ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime, simplify);
drawNodes(ctx, graph, hovered, selected, search, matchSet);
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify);
ctx.restore();
''')
# hover RAF ref
if 'const hoverRafRef = useRef<number>(0);' not in t:
t=t.replace(' const graphRef = useRef<Graph | null>(null);\n', ' const graphRef = useRef<Graph | null>(null);\n const hoverRafRef = useRef<number>(0);\n')
# throttle hover hit test on mousemove
old_mm=''' const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
setHovered(node?.id ?? null);
}, [graph, redraw, startAnimLoop]);
'''
new_mm=''' cancelAnimationFrame(hoverRafRef.current);
const clientX = e.clientX;
const clientY = e.clientY;
hoverRafRef.current = requestAnimationFrame(() => {
const node = hitTest(graph, canvasRef.current as HTMLCanvasElement, clientX, clientY, vpRef.current);
setHovered(prev => (prev === (node?.id ?? null) ? prev : (node?.id ?? null)));
});
}, [graph, redraw, startAnimLoop]);
'''
if old_mm in t:
t=t.replace(old_mm,new_mm)
# cleanup hover raf on unmount in existing animation cleanup effect
if 'cancelAnimationFrame(hoverRafRef.current);' not in t:
t=t.replace(''' useEffect(() => {
if (graph) startAnimLoop();
return () => { cancelAnimationFrame(animFrameRef.current); isAnimatingRef.current = false; };
}, [graph, startAnimLoop]);
''',
''' useEffect(() => {
if (graph) startAnimLoop();
return () => {
cancelAnimationFrame(animFrameRef.current);
cancelAnimationFrame(hoverRafRef.current);
isAnimatingRef.current = false;
};
}, [graph, startAnimLoop]);
''')
# connectedNodes optimization map
if 'const nodeById = useMemo(() => {' not in t:
t=t.replace(''' const connectionCount = selectedNode && graph
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
: 0;
const connectedNodes = useMemo(() => {
''',
''' const connectionCount = selectedNode && graph
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
: 0;
const nodeById = useMemo(() => {
const m = new Map<string, GNode>();
if (!graph) return m;
for (const n of graph.nodes) m.set(n.id, n);
return m;
}, [graph]);
const connectedNodes = useMemo(() => {
''')
t=t.replace(''' const n = graph.nodes.find(x => x.id === e.target);
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
} else if (e.target === selectedNode.id) {
const n = graph.nodes.find(x => x.id === e.source);
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
''',
''' const n = nodeById.get(e.target);
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
} else if (e.target === selectedNode.id) {
const n = nodeById.get(e.source);
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
''')
t=t.replace(' }, [selectedNode, graph]);\n', ' }, [selectedNode, graph, nodeById]);\n')
p.write_text(t,encoding='utf-8')
print('patched NetworkMap adaptive render + hover throttle')

153
_perf_patch_networkmap2.py Normal file
View File

@@ -0,0 +1,153 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
t=p.read_text(encoding='utf-8')
if 'RENDER_SIMPLIFY_NODE_THRESHOLD' not in t:
t=t.replace('const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n', 'const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\nconst RENDER_SIMPLIFY_NODE_THRESHOLD = 220;\nconst RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;\nconst EDGE_DRAW_TARGET = 1000;\n')
t=t.replace('function drawBackground(\n ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,\n) {', 'function drawBackground(\n ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,\n simplify: boolean,\n) {')
t=t.replace(''' ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
const startX = -vp.x / vp.scale - GRID_SPACING;
const startY = -vp.y / vp.scale - GRID_SPACING;
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
ctx.fillStyle = GRID_DOT_COLOR;
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
}
}
ctx.restore();
''',''' if (!simplify) {
ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
const startX = -vp.x / vp.scale - GRID_SPACING;
const startY = -vp.y / vp.scale - GRID_SPACING;
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
ctx.fillStyle = GRID_DOT_COLOR;
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
}
}
ctx.restore();
}
''')
t=t.replace('''function drawEdges(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
nodeMap: Map<string, GNode>, animTime: number,
) {
for (const e of graph.edges) {
''','''function drawEdges(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
nodeMap: Map<string, GNode>, animTime: number,
simplify: boolean,
) {
const edgeStep = simplify ? Math.max(1, Math.ceil(graph.edges.length / EDGE_DRAW_TARGET)) : 1;
for (let ei = 0; ei < graph.edges.length; ei += edgeStep) {
const e = graph.edges[ei];
''')
t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);', 'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }')
t=t.replace('''function drawLabels(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
search: string, matchSet: Set<string>, vp: Viewport,
) {
const dimmed = search.length > 0;
''','''function drawLabels(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
search: string, matchSet: Set<string>, vp: Viewport,
simplify: boolean,
) {
const dimmed = search.length > 0;
if (simplify && !search && !hovered && !selected) {
return;
}
''')
t=t.replace(''' drawBackground(ctx, w, h, vp, dpr);
ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime);
drawNodes(ctx, graph, hovered, selected, search, matchSet);
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp);
ctx.restore();
''',''' const simplify = graph.nodes.length > RENDER_SIMPLIFY_NODE_THRESHOLD || graph.edges.length > RENDER_SIMPLIFY_EDGE_THRESHOLD;
drawBackground(ctx, w, h, vp, dpr, simplify);
ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime, simplify);
drawNodes(ctx, graph, hovered, selected, search, matchSet);
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify);
ctx.restore();
''')
if 'const hoverRafRef = useRef<number>(0);' not in t:
t=t.replace('const graphRef = useRef<Graph | null>(null);\n', 'const graphRef = useRef<Graph | null>(null);\n const hoverRafRef = useRef<number>(0);\n')
t=t.replace(''' const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
setHovered(node?.id ?? null);
}, [graph, redraw, startAnimLoop]);
''',''' cancelAnimationFrame(hoverRafRef.current);
const clientX = e.clientX;
const clientY = e.clientY;
hoverRafRef.current = requestAnimationFrame(() => {
const node = hitTest(graph, canvasRef.current as HTMLCanvasElement, clientX, clientY, vpRef.current);
setHovered(prev => (prev === (node?.id ?? null) ? prev : (node?.id ?? null)));
});
}, [graph, redraw, startAnimLoop]);
''')
t=t.replace(''' useEffect(() => {
if (graph) startAnimLoop();
return () => { cancelAnimationFrame(animFrameRef.current); isAnimatingRef.current = false; };
}, [graph, startAnimLoop]);
''',''' useEffect(() => {
if (graph) startAnimLoop();
return () => {
cancelAnimationFrame(animFrameRef.current);
cancelAnimationFrame(hoverRafRef.current);
isAnimatingRef.current = false;
};
}, [graph, startAnimLoop]);
''')
if 'const nodeById = useMemo(() => {' not in t:
t=t.replace(''' const connectionCount = selectedNode && graph
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
: 0;
const connectedNodes = useMemo(() => {
''',''' const connectionCount = selectedNode && graph
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
: 0;
const nodeById = useMemo(() => {
const m = new Map<string, GNode>();
if (!graph) return m;
for (const n of graph.nodes) m.set(n.id, n);
return m;
}, [graph]);
const connectedNodes = useMemo(() => {
''')
t=t.replace('const n = graph.nodes.find(x => x.id === e.target);','const n = nodeById.get(e.target);')
t=t.replace('const n = graph.nodes.find(x => x.id === e.source);','const n = nodeById.get(e.source);')
t=t.replace(' }, [selectedNode, graph]);',' }, [selectedNode, graph, nodeById]);')
p.write_text(t,encoding='utf-8')
print('patched NetworkMap performance')

View File

@@ -0,0 +1,227 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
t=p.read_text(encoding='utf-8')
start=t.index('async def build_host_inventory(')
# find end of function by locating '\n\n' before EOF after ' }\n'
end=t.index('\n\n', start)
# need proper end: first double newline after function may occur in docstring? compute by searching for '\n\n' after ' }\n' near end
ret_idx=t.rfind(' }')
# safer locate end as last occurrence of '\n }\n' after start, then function ends next newline
end=t.find('\n\n', ret_idx)
if end==-1:
end=len(t)
new_func='''async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
"""Build a deduplicated host inventory from all datasets in a hunt.
Returns dict with 'hosts', 'connections', and 'stats'.
Each host has: id, hostname, fqdn, client_id, ips, os, users, datasets, row_count.
"""
ds_result = await db.execute(
select(Dataset).where(Dataset.hunt_id == hunt_id)
)
all_datasets = ds_result.scalars().all()
if not all_datasets:
return {"hosts": [], "connections": [], "stats": {
"total_hosts": 0, "total_datasets_scanned": 0,
"total_rows_scanned": 0,
}}
hosts: dict[str, dict] = {} # fqdn -> host record
ip_to_host: dict[str, str] = {} # local-ip -> fqdn
connections: dict[tuple, int] = defaultdict(int)
total_rows = 0
ds_with_hosts = 0
sampled_dataset_count = 0
total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
global_budget_reached = False
dropped_connections = 0
for ds in all_datasets:
if total_row_budget and total_rows >= total_row_budget:
global_budget_reached = True
break
cols = _identify_columns(ds)
if not cols['fqdn'] and not cols['host_id']:
continue
ds_with_hosts += 1
batch_size = 5000
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
rows_scanned_this_dataset = 0
sampled_dataset = False
last_row_index = -1
while True:
if total_row_budget and total_rows >= total_row_budget:
sampled_dataset = True
global_budget_reached = True
break
rr = await db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == ds.id)
.where(DatasetRow.row_index > last_row_index)
.order_by(DatasetRow.row_index)
.limit(batch_size)
)
rows = rr.scalars().all()
if not rows:
break
for ro in rows:
if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
sampled_dataset = True
break
if total_row_budget and total_rows >= total_row_budget:
sampled_dataset = True
global_budget_reached = True
break
data = ro.data or {}
total_rows += 1
rows_scanned_this_dataset += 1
fqdn = ''
for c in cols['fqdn']:
fqdn = _clean(data.get(c))
if fqdn:
break
client_id = ''
for c in cols['host_id']:
client_id = _clean(data.get(c))
if client_id:
break
if not fqdn and not client_id:
continue
host_key = fqdn or client_id
if host_key not in hosts:
short = fqdn.split('.')[0] if fqdn and '.' in fqdn else fqdn
hosts[host_key] = {
'id': host_key,
'hostname': short or client_id,
'fqdn': fqdn,
'client_id': client_id,
'ips': set(),
'os': '',
'users': set(),
'datasets': set(),
'row_count': 0,
}
h = hosts[host_key]
h['datasets'].add(ds.name)
h['row_count'] += 1
if client_id and not h['client_id']:
h['client_id'] = client_id
for c in cols['username']:
u = _extract_username(_clean(data.get(c)))
if u:
h['users'].add(u)
for c in cols['local_ip']:
ip = _clean(data.get(c))
if _is_valid_ip(ip):
h['ips'].add(ip)
ip_to_host[ip] = host_key
for c in cols['os']:
ov = _clean(data.get(c))
if ov and not h['os']:
h['os'] = ov
for c in cols['remote_ip']:
rip = _clean(data.get(c))
if _is_valid_ip(rip):
rport = ''
for pc in cols['remote_port']:
rport = _clean(data.get(pc))
if rport:
break
conn_key = (host_key, rip, rport)
if max_connections and len(connections) >= max_connections and conn_key not in connections:
dropped_connections += 1
continue
connections[conn_key] += 1
if sampled_dataset:
sampled_dataset_count += 1
logger.info(
"Host inventory sampling for dataset %s (%d rows scanned)",
ds.id,
rows_scanned_this_dataset,
)
break
last_row_index = rows[-1].row_index
if len(rows) < batch_size:
break
if global_budget_reached:
logger.info(
"Host inventory global row budget reached for hunt %s at %d rows",
hunt_id,
total_rows,
)
break
# Post-process hosts
for h in hosts.values():
if not h['os'] and h['fqdn']:
h['os'] = _infer_os(h['fqdn'])
h['ips'] = sorted(h['ips'])
h['users'] = sorted(h['users'])
h['datasets'] = sorted(h['datasets'])
# Build connections, resolving IPs to host keys
conn_list = []
seen = set()
for (src, dst_ip, dst_port), cnt in connections.items():
if dst_ip in _IGNORE_IPS:
continue
dst_host = ip_to_host.get(dst_ip, '')
if dst_host == src:
continue
key = tuple(sorted([src, dst_host or dst_ip]))
if key in seen:
continue
seen.add(key)
conn_list.append({
'source': src,
'target': dst_host or dst_ip,
'target_ip': dst_ip,
'port': dst_port,
'count': cnt,
})
host_list = sorted(hosts.values(), key=lambda x: x['row_count'], reverse=True)
return {
"hosts": host_list,
"connections": conn_list,
"stats": {
"total_hosts": len(host_list),
"total_datasets_scanned": len(all_datasets),
"datasets_with_hosts": ds_with_hosts,
"total_rows_scanned": total_rows,
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
"hosts_with_users": sum(1 for h in host_list if h['users']),
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
"row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
"connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
"sampled_datasets": sampled_dataset_count,
"global_budget_reached": global_budget_reached,
"dropped_connections": dropped_connections,
},
}
'''
out=t[:start]+new_func+t[end:]
p.write_text(out,encoding='utf-8')
print('replaced build_host_inventory with hard-budget fast mode')

View File

@@ -0,0 +1,78 @@
"""add playbooks, playbook_steps, saved_searches tables
Revision ID: b2c3d4e5f6a7
Revises: a1b2c3d4e5f6
Create Date: 2026-02-21 10:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "b2c3d4e5f6a7"
down_revision: Union[str, Sequence[str], None] = "a1b2c3d4e5f6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add display_name to users table
with op.batch_alter_table("users") as batch_op:
batch_op.add_column(sa.Column("display_name", sa.String(128), nullable=True))
# Create playbooks table
op.create_table(
"playbooks",
sa.Column("id", sa.String(32), primary_key=True),
sa.Column("name", sa.String(256), nullable=False, index=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
sa.Column("is_template", sa.Boolean(), server_default="0"),
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
sa.Column("status", sa.String(20), server_default="active"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
# Create playbook_steps table
op.create_table(
"playbook_steps",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("playbook_id", sa.String(32), sa.ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False),
sa.Column("order_index", sa.Integer(), nullable=False),
sa.Column("title", sa.String(256), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("step_type", sa.String(32), server_default="manual"),
sa.Column("target_route", sa.String(256), nullable=True),
sa.Column("is_completed", sa.Boolean(), server_default="0"),
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("notes", sa.Text(), nullable=True),
)
op.create_index("ix_playbook_steps_playbook", "playbook_steps", ["playbook_id"])
# Create saved_searches table
op.create_table(
"saved_searches",
sa.Column("id", sa.String(32), primary_key=True),
sa.Column("name", sa.String(256), nullable=False, index=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("search_type", sa.String(32), nullable=False),
sa.Column("query_params", sa.JSON(), nullable=False),
sa.Column("threshold", sa.Float(), nullable=True),
sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("last_result_count", sa.Integer(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
op.create_index("ix_saved_searches_type", "saved_searches", ["search_type"])
def downgrade() -> None:
op.drop_table("saved_searches")
op.drop_table("playbook_steps")
op.drop_table("playbooks")
with op.batch_alter_table("users") as batch_op:
batch_op.drop_column("display_name")

View File

@@ -0,0 +1,48 @@
"""add processing_tasks table
Revision ID: c3d4e5f6a7b8
Revises: b2c3d4e5f6a7
Create Date: 2026-02-22 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "c3d4e5f6a7b8"
down_revision: Union[str, Sequence[str], None] = "b2c3d4e5f6a7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"processing_tasks",
sa.Column("id", sa.String(32), primary_key=True),
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True),
sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True),
sa.Column("job_id", sa.String(64), nullable=True),
sa.Column("stage", sa.String(64), nullable=False),
sa.Column("status", sa.String(20), nullable=False, server_default="queued"),
sa.Column("progress", sa.Float(), nullable=False, server_default="0.0"),
sa.Column("message", sa.Text(), nullable=True),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
op.create_index("ix_processing_tasks_hunt_stage", "processing_tasks", ["hunt_id", "stage"])
op.create_index("ix_processing_tasks_dataset_stage", "processing_tasks", ["dataset_id", "stage"])
op.create_index("ix_processing_tasks_job_id", "processing_tasks", ["job_id"])
op.create_index("ix_processing_tasks_status", "processing_tasks", ["status"])
def downgrade() -> None:
op.drop_index("ix_processing_tasks_status", table_name="processing_tasks")
op.drop_index("ix_processing_tasks_job_id", table_name="processing_tasks")
op.drop_index("ix_processing_tasks_dataset_stage", table_name="processing_tasks")
op.drop_index("ix_processing_tasks_hunt_stage", table_name="processing_tasks")
op.drop_table("processing_tasks")

View File

@@ -1,16 +1,18 @@
"""Analyst-assist agent module for ThreatHunt. """Analyst-assist agent module for ThreatHunt.
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses. Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
Agents are advisory only and do not execute actions or modify data. Agents are advisory only and do not execute actions or modify data.
""" """
from .core import ThreatHuntAgent from .core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
from .providers import LLMProvider, LocalProvider, NetworkedProvider, OnlineProvider from .providers_v2 import OllamaProvider, OpenWebUIProvider, EmbeddingProvider
__all__ = [ __all__ = [
"ThreatHuntAgent", "ThreatHuntAgent",
"LLMProvider", "AgentContext",
"LocalProvider", "AgentResponse",
"NetworkedProvider", "Perspective",
"OnlineProvider", "OllamaProvider",
"OpenWebUIProvider",
"EmbeddingProvider",
] ]

View File

@@ -1,208 +0,0 @@
"""Core ThreatHunt analyst-assist agent.
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
Agents are advisory only - no execution, no alerts, no data modifications.
"""
import logging
from typing import Optional
from pydantic import BaseModel, Field
from .providers import LLMProvider, get_provider
logger = logging.getLogger(__name__)
class AgentContext(BaseModel):
"""Context for agent guidance requests."""
query: str = Field(
..., description="Analyst question or request for guidance"
)
dataset_name: Optional[str] = Field(None, description="Name of CSV dataset")
artifact_type: Optional[str] = Field(None, description="Artifact type (e.g., file, process, network)")
host_identifier: Optional[str] = Field(
None, description="Host name, IP, or identifier"
)
data_summary: Optional[str] = Field(
None, description="Brief description of uploaded data"
)
conversation_history: Optional[list[dict]] = Field(
default_factory=list, description="Previous messages in conversation"
)
class AgentResponse(BaseModel):
"""Response from analyst-assist agent."""
guidance: str = Field(..., description="Advisory guidance for analyst")
confidence: float = Field(
..., ge=0.0, le=1.0, description="Confidence in guidance (0-1)"
)
suggested_pivots: list[str] = Field(
default_factory=list, description="Suggested analytical directions"
)
suggested_filters: list[str] = Field(
default_factory=list, description="Suggested data filters or queries"
)
caveats: Optional[str] = Field(
None, description="Assumptions, limitations, or caveats"
)
reasoning: Optional[str] = Field(
None, description="Explanation of how guidance was generated"
)
class ThreatHuntAgent:
"""Analyst-assist agent for ThreatHunt.
Provides guidance on:
- Interpreting CSV artifact data
- Suggesting analytical pivots and filters
- Forming and testing hypotheses
Policy:
- Advisory guidance only (no execution)
- No database or schema changes
- No alert escalation
- Transparent reasoning
"""
def __init__(self, provider: Optional[LLMProvider] = None):
"""Initialize agent with LLM provider.
Args:
provider: LLM provider instance. If None, uses get_provider() with auto mode.
"""
if provider is None:
try:
provider = get_provider("auto")
except RuntimeError as e:
logger.warning(f"Could not initialize default provider: {e}")
provider = None
self.provider = provider
self.system_prompt = self._build_system_prompt()
def _build_system_prompt(self) -> str:
"""Build the system prompt that governs agent behavior."""
return """You are an analyst-assist agent for ThreatHunt, a threat hunting platform.
Your role:
- Interpret and explain CSV artifact data from Velociraptor
- Suggest analytical pivots, filters, and hypotheses
- Highlight anomalies, patterns, or points of interest
- Guide analysts without replacing their judgment
Your constraints:
- You ONLY provide guidance and suggestions
- You do NOT execute actions or tools
- You do NOT modify data or escalate alerts
- You do NOT make autonomous decisions
- You ONLY analyze data presented to you
- You explain your reasoning transparently
- You acknowledge limitations and assumptions
- You suggest next investigative steps
When responding:
1. Start with a clear, direct answer to the query
2. Explain your reasoning based on the data context provided
3. Suggest 2-4 analytical pivots the analyst might explore
4. Suggest 2-4 data filters or queries that might be useful
5. Include relevant caveats or assumptions
6. Be honest about what you cannot determine from the data
Remember: The analyst is the decision-maker. You are an assistant."""
async def assist(self, context: AgentContext) -> AgentResponse:
"""Provide guidance on artifact data and analysis.
Args:
context: Request context including query and data context.
Returns:
Guidance response with suggestions and reasoning.
Raises:
RuntimeError: If no provider is available.
"""
if not self.provider:
raise RuntimeError(
"No LLM provider available. Configure at least one of: "
"THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
"or THREAT_HUNT_ONLINE_API_KEY"
)
# Build prompt with context
prompt = self._build_prompt(context)
try:
# Get guidance from LLM provider
guidance = await self.provider.generate(prompt, max_tokens=1024)
# Parse response into structured format
response = self._parse_response(guidance, context)
logger.info(
f"Agent assisted with query: {context.query[:50]}... "
f"(dataset: {context.dataset_name})"
)
return response
except Exception as e:
logger.error(f"Error generating guidance: {e}")
raise
def _build_prompt(self, context: AgentContext) -> str:
"""Build the prompt for the LLM."""
prompt_parts = [
f"Analyst query: {context.query}",
]
if context.dataset_name:
prompt_parts.append(f"Dataset: {context.dataset_name}")
if context.artifact_type:
prompt_parts.append(f"Artifact type: {context.artifact_type}")
if context.host_identifier:
prompt_parts.append(f"Host: {context.host_identifier}")
if context.data_summary:
prompt_parts.append(f"Data summary: {context.data_summary}")
if context.conversation_history:
prompt_parts.append("\nConversation history:")
for msg in context.conversation_history[-5:]: # Last 5 messages for context
prompt_parts.append(f" {msg.get('role', 'unknown')}: {msg.get('content', '')}")
return "\n".join(prompt_parts)
def _parse_response(self, response_text: str, context: AgentContext) -> AgentResponse:
"""Parse LLM response into structured format.
Note: This is a simplified parser. In production, use structured output
from the LLM (JSON mode, function calling, etc.) for better reliability.
"""
# For now, return a structured response based on the raw guidance
# In production, parse JSON or use structured output from LLM
return AgentResponse(
guidance=response_text,
confidence=0.8, # Placeholder
suggested_pivots=[
"Analyze temporal patterns",
"Cross-reference with known indicators",
"Examine outliers in the dataset",
"Compare with baseline behavior",
],
suggested_filters=[
"Filter by high-risk indicators",
"Sort by timestamp for timeline analysis",
"Group by host or user",
"Filter by anomaly score",
],
caveats="Guidance is based on available data context. "
"Analysts should verify findings with additional sources.",
reasoning="Analysis generated based on artifact data patterns and analyst query.",
)

View File

@@ -1,190 +0,0 @@
"""Pluggable LLM provider interface for analyst-assist agents.
Supports three provider types:
- Local: On-device or on-prem models
- Networked: Shared internal inference services
- Online: External hosted APIs
"""
import os
from abc import ABC, abstractmethod
from typing import Optional
class LLMProvider(ABC):
"""Abstract base class for LLM providers."""
@abstractmethod
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
"""Generate a response from the LLM.
Args:
prompt: The input prompt
max_tokens: Maximum tokens in response
Returns:
Generated text response
"""
pass
@abstractmethod
def is_available(self) -> bool:
"""Check if provider backend is available."""
pass
class LocalProvider(LLMProvider):
"""Local LLM provider (on-device or on-prem models)."""
def __init__(self, model_path: Optional[str] = None):
"""Initialize local provider.
Args:
model_path: Path to local model. If None, uses THREAT_HUNT_LOCAL_MODEL_PATH env var.
"""
self.model_path = model_path or os.getenv("THREAT_HUNT_LOCAL_MODEL_PATH")
self.model = None
def is_available(self) -> bool:
"""Check if local model is available."""
if not self.model_path:
return False
# In production, would verify model file exists and can be loaded
return os.path.exists(str(self.model_path))
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
"""Generate response using local model.
Note: This is a placeholder. In production, integrate with:
- llama-cpp-python for GGML models
- Ollama API
- vLLM
- Other local inference engines
"""
if not self.is_available():
raise RuntimeError("Local model not available")
# Placeholder implementation
return f"[Local model response to: {prompt[:50]}...]"
class NetworkedProvider(LLMProvider):
"""Networked LLM provider (shared internal inference services)."""
def __init__(
self,
api_endpoint: Optional[str] = None,
api_key: Optional[str] = None,
model_name: str = "default",
):
"""Initialize networked provider.
Args:
api_endpoint: URL to inference service. Defaults to env var THREAT_HUNT_NETWORKED_ENDPOINT.
api_key: API key for service. Defaults to env var THREAT_HUNT_NETWORKED_KEY.
model_name: Model name/ID on the service.
"""
self.api_endpoint = api_endpoint or os.getenv("THREAT_HUNT_NETWORKED_ENDPOINT")
self.api_key = api_key or os.getenv("THREAT_HUNT_NETWORKED_KEY")
self.model_name = model_name
def is_available(self) -> bool:
"""Check if networked service is available."""
return bool(self.api_endpoint)
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
"""Generate response using networked service.
Note: This is a placeholder. In production, integrate with:
- Internal inference service API
- LLM inference container cluster
- Enterprise inference gateway
"""
if not self.is_available():
raise RuntimeError("Networked service not available")
# Placeholder implementation
return f"[Networked response from {self.model_name}: {prompt[:50]}...]"
class OnlineProvider(LLMProvider):
"""Online LLM provider (external hosted APIs)."""
def __init__(
self,
api_provider: str = "openai",
api_key: Optional[str] = None,
model_name: Optional[str] = None,
):
"""Initialize online provider.
Args:
api_provider: Provider name (openai, anthropic, google, etc.)
api_key: API key. Defaults to env var THREAT_HUNT_ONLINE_API_KEY.
model_name: Model name. Defaults to env var THREAT_HUNT_ONLINE_MODEL.
"""
self.api_provider = api_provider
self.api_key = api_key or os.getenv("THREAT_HUNT_ONLINE_API_KEY")
self.model_name = model_name or os.getenv(
"THREAT_HUNT_ONLINE_MODEL", f"{api_provider}-default"
)
def is_available(self) -> bool:
"""Check if online API is available."""
return bool(self.api_key)
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
"""Generate response using online API.
Note: This is a placeholder. In production, integrate with:
- OpenAI API (GPT-3.5, GPT-4, etc.)
- Anthropic Claude API
- Google Gemini API
- Other hosted LLM services
"""
if not self.is_available():
raise RuntimeError("Online API not available or API key not set")
# Placeholder implementation
return f"[Online {self.api_provider} response: {prompt[:50]}...]"
def get_provider(provider_type: str = "auto") -> LLMProvider:
"""Get an LLM provider based on configuration.
Args:
provider_type: Type of provider to use: 'local', 'networked', 'online', or 'auto'.
'auto' attempts to use the first available provider in order:
local -> networked -> online.
Returns:
Configured LLM provider instance.
Raises:
RuntimeError: If no provider is available.
"""
# Explicit provider selection
if provider_type == "local":
provider = LocalProvider()
elif provider_type == "networked":
provider = NetworkedProvider()
elif provider_type == "online":
provider = OnlineProvider()
elif provider_type == "auto":
# Try providers in order of preference
for Provider in [LocalProvider, NetworkedProvider, OnlineProvider]:
provider = Provider()
if provider.is_available():
return provider
raise RuntimeError(
"No LLM provider available. Configure at least one of: "
"THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
"or THREAT_HUNT_ONLINE_API_KEY"
)
else:
raise ValueError(f"Unknown provider type: {provider_type}")
if not provider.is_available():
raise RuntimeError(f"{provider_type} provider not available")
return provider

View File

@@ -1,170 +0,0 @@
"""API routes for analyst-assist agent."""
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from app.agents.core import ThreatHuntAgent, AgentContext, AgentResponse
from app.agents.config import AgentConfig
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/agent", tags=["agent"])
# Global agent instance (lazy-loaded)
_agent: ThreatHuntAgent | None = None
def get_agent() -> ThreatHuntAgent:
"""Get or create the agent instance."""
global _agent
if _agent is None:
if not AgentConfig.is_agent_enabled():
raise HTTPException(
status_code=503,
detail="Analyst-assist agent is not configured. "
"Please configure an LLM provider.",
)
_agent = ThreatHuntAgent()
return _agent
class AssistRequest(BaseModel):
"""Request for agent assistance."""
query: str = Field(
..., description="Analyst question or request for guidance"
)
dataset_name: str | None = Field(
None, description="Name of CSV dataset being analyzed"
)
artifact_type: str | None = Field(
None, description="Type of artifact (e.g., FileList, ProcessList, NetworkConnections)"
)
host_identifier: str | None = Field(
None, description="Host name, IP address, or identifier"
)
data_summary: str | None = Field(
None, description="Brief summary or context about the uploaded data"
)
conversation_history: list[dict] | None = Field(
None, description="Previous messages for context"
)
class AssistResponse(BaseModel):
"""Response with agent guidance."""
guidance: str
confidence: float
suggested_pivots: list[str]
suggested_filters: list[str]
caveats: str | None = None
reasoning: str | None = None
@router.post(
"/assist",
response_model=AssistResponse,
summary="Get analyst-assist guidance",
description="Request guidance on CSV artifact data, analytical pivots, and hypotheses. "
"Agent provides advisory guidance only - no execution.",
)
async def agent_assist(request: AssistRequest) -> AssistResponse:
"""Provide analyst-assist guidance on artifact data.
The agent will:
- Explain and interpret the provided data context
- Suggest analytical pivots the analyst might explore
- Suggest data filters or queries that might be useful
- Highlight assumptions, limitations, and caveats
The agent will NOT:
- Execute any tools or actions
- Escalate findings to alerts
- Modify any data or schema
- Make autonomous decisions
Args:
request: Assistance request with query and context
Returns:
Guidance response with suggestions and reasoning
Raises:
HTTPException: If agent is not configured (503) or request fails
"""
try:
agent = get_agent()
# Build context
context = AgentContext(
query=request.query,
dataset_name=request.dataset_name,
artifact_type=request.artifact_type,
host_identifier=request.host_identifier,
data_summary=request.data_summary,
conversation_history=request.conversation_history or [],
)
# Get guidance
response = await agent.assist(context)
logger.info(
f"Agent assisted analyst with query: {request.query[:50]}... "
f"(host: {request.host_identifier}, artifact: {request.artifact_type})"
)
return AssistResponse(
guidance=response.guidance,
confidence=response.confidence,
suggested_pivots=response.suggested_pivots,
suggested_filters=response.suggested_filters,
caveats=response.caveats,
reasoning=response.reasoning,
)
except RuntimeError as e:
logger.error(f"Agent error: {e}")
raise HTTPException(
status_code=503,
detail=f"Agent unavailable: {str(e)}",
)
except Exception as e:
logger.exception(f"Unexpected error in agent_assist: {e}")
raise HTTPException(
status_code=500,
detail="Error generating guidance. Please try again.",
)
@router.get(
"/health",
summary="Check agent health",
description="Check if agent is configured and ready to assist.",
)
async def agent_health() -> dict:
"""Check agent availability and configuration.
Returns:
Health status with configuration details
"""
try:
agent = get_agent()
provider_type = agent.provider.__class__.__name__ if agent.provider else "None"
return {
"status": "healthy",
"provider": provider_type,
"max_tokens": AgentConfig.MAX_RESPONSE_TOKENS,
"reasoning_enabled": AgentConfig.ENABLE_REASONING,
}
except HTTPException:
return {
"status": "unavailable",
"reason": "No LLM provider configured",
"configured_providers": {
"local": bool(AgentConfig.LOCAL_MODEL_PATH),
"networked": bool(AgentConfig.NETWORKED_ENDPOINT),
"online": bool(AgentConfig.ONLINE_API_KEY),
},
}

View File

@@ -1,4 +1,4 @@
"""API routes for analyst-assist agent v2. """API routes for analyst-assist agent v2.
Supports quick, deep, and debate modes with streaming. Supports quick, deep, and debate modes with streaming.
Conversations are persisted to the database. Conversations are persisted to the database.
@@ -6,19 +6,25 @@ Conversations are persisted to the database.
import json import json
import logging import logging
import re
import time
from collections import Counter
from urllib.parse import urlparse
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings from app.config import settings
from app.db import get_db from app.db import get_db
from app.db.models import Conversation, Message from app.db.models import Conversation, Message, Dataset, KeywordTheme
from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
from app.agents.providers_v2 import check_all_nodes from app.agents.providers_v2 import check_all_nodes
from app.agents.registry import registry from app.agents.registry import registry
from app.services.sans_rag import sans_rag from app.services.sans_rag import sans_rag
from app.services.scanner import KeywordScanner
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,7 +41,7 @@ def get_agent() -> ThreatHuntAgent:
return _agent return _agent
# ── Request / Response models ───────────────────────────────────────── # Request / Response models
class AssistRequest(BaseModel): class AssistRequest(BaseModel):
@@ -52,6 +58,8 @@ class AssistRequest(BaseModel):
model_override: str | None = None model_override: str | None = None
conversation_id: str | None = Field(None, description="Persist messages to this conversation") conversation_id: str | None = Field(None, description="Persist messages to this conversation")
hunt_id: str | None = None hunt_id: str | None = None
execution_preference: str = Field(default="auto", description="auto | force | off")
learning_mode: bool = False
class AssistResponseModel(BaseModel): class AssistResponseModel(BaseModel):
@@ -66,10 +74,170 @@ class AssistResponseModel(BaseModel):
node_used: str = "" node_used: str = ""
latency_ms: int = 0 latency_ms: int = 0
perspectives: list[dict] | None = None perspectives: list[dict] | None = None
execution: dict | None = None
conversation_id: str | None = None conversation_id: str | None = None
# ── Routes ──────────────────────────────────────────────────────────── POLICY_THEME_NAMES = {"Adult Content", "Gambling", "Downloads / Piracy"}
POLICY_QUERY_TERMS = {
"policy", "violating", "violation", "browser history", "web history",
"domain", "domains", "adult", "gambling", "piracy", "aup",
}
WEB_DATASET_HINTS = {
"web", "history", "browser", "url", "visited_url", "domain", "title",
}
def _is_policy_domain_query(query: str) -> bool:
q = (query or "").lower()
if not q:
return False
score = sum(1 for t in POLICY_QUERY_TERMS if t in q)
return score >= 2 and ("domain" in q or "history" in q or "policy" in q)
def _should_execute_policy_scan(request: AssistRequest) -> bool:
pref = (request.execution_preference or "auto").strip().lower()
if pref == "off":
return False
if pref == "force":
return True
return _is_policy_domain_query(request.query)
def _extract_domain(value: str | None) -> str | None:
if not value:
return None
text = value.strip()
if not text:
return None
try:
parsed = urlparse(text)
if parsed.netloc:
return parsed.netloc.lower()
except Exception:
pass
m = re.search(r"([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}", text)
return m.group(0).lower() if m else None
def _dataset_score(ds: Dataset) -> int:
score = 0
name = (ds.name or "").lower()
cols_l = {c.lower() for c in (ds.column_schema or {}).keys()}
norm_vals_l = {str(v).lower() for v in (ds.normalized_columns or {}).values()}
for h in WEB_DATASET_HINTS:
if h in name:
score += 2
if h in cols_l:
score += 3
if h in norm_vals_l:
score += 3
if "visited_url" in cols_l or "url" in cols_l:
score += 8
if "user" in cols_l or "username" in cols_l:
score += 2
if "clientid" in cols_l or "fqdn" in cols_l:
score += 2
if (ds.row_count or 0) > 0:
score += 1
return score
async def _run_policy_domain_execution(request: AssistRequest, db: AsyncSession) -> dict:
scanner = KeywordScanner(db)
theme_result = await db.execute(
select(KeywordTheme).where(
KeywordTheme.enabled == True, # noqa: E712
KeywordTheme.name.in_(list(POLICY_THEME_NAMES)),
)
)
themes = list(theme_result.scalars().all())
theme_ids = [t.id for t in themes]
theme_names = [t.name for t in themes] or sorted(POLICY_THEME_NAMES)
ds_query = select(Dataset).where(Dataset.processing_status.in_(["completed", "ready", "processing"]))
if request.hunt_id:
ds_query = ds_query.where(Dataset.hunt_id == request.hunt_id)
ds_result = await db.execute(ds_query)
candidates = list(ds_result.scalars().all())
if request.dataset_name:
needle = request.dataset_name.lower().strip()
candidates = [d for d in candidates if needle in (d.name or "").lower()]
scored = sorted(
((d, _dataset_score(d)) for d in candidates),
key=lambda x: x[1],
reverse=True,
)
selected = [d for d, s in scored if s > 0][:8]
dataset_ids = [d.id for d in selected]
if not dataset_ids:
return {
"mode": "policy_scan",
"themes": theme_names,
"datasets_scanned": 0,
"dataset_names": [],
"total_hits": 0,
"policy_hits": 0,
"top_user_hosts": [],
"top_domains": [],
"sample_hits": [],
"note": "No suitable browser/web-history datasets found in current scope.",
}
result = await scanner.scan(
dataset_ids=dataset_ids,
theme_ids=theme_ids or None,
scan_hunts=False,
scan_annotations=False,
scan_messages=False,
)
hits = result.get("hits", [])
user_host_counter = Counter()
domain_counter = Counter()
for h in hits:
user = h.get("username") or "(unknown-user)"
host = h.get("hostname") or "(unknown-host)"
user_host_counter[f"{user}|{host}"] += 1
dom = _extract_domain(h.get("matched_value"))
if dom:
domain_counter[dom] += 1
top_user_hosts = [
{"user_host": k, "count": v}
for k, v in user_host_counter.most_common(10)
]
top_domains = [
{"domain": k, "count": v}
for k, v in domain_counter.most_common(10)
]
return {
"mode": "policy_scan",
"themes": theme_names,
"datasets_scanned": len(dataset_ids),
"dataset_names": [d.name for d in selected],
"total_hits": int(result.get("total_hits", 0)),
"policy_hits": int(result.get("total_hits", 0)),
"rows_scanned": int(result.get("rows_scanned", 0)),
"top_user_hosts": top_user_hosts,
"top_domains": top_domains,
"sample_hits": hits[:20],
}
# Routes
@router.post( @router.post(
@@ -84,6 +252,76 @@ async def agent_assist(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
) -> AssistResponseModel: ) -> AssistResponseModel:
try: try:
# Deterministic execution mode for policy-domain investigations.
if _should_execute_policy_scan(request):
t0 = time.monotonic()
exec_payload = await _run_policy_domain_execution(request, db)
latency_ms = int((time.monotonic() - t0) * 1000)
policy_hits = exec_payload.get("policy_hits", 0)
datasets_scanned = exec_payload.get("datasets_scanned", 0)
if policy_hits > 0:
guidance = (
f"Policy-violation scan complete: {policy_hits} hits across "
f"{datasets_scanned} dataset(s). Top user/host pairs and domains are included "
f"in execution results for triage."
)
confidence = 0.95
caveats = "Keyword-based matching can include false positives; validate with full URL context."
else:
guidance = (
f"No policy-violation hits found in current scope "
f"({datasets_scanned} dataset(s) scanned)."
)
confidence = 0.9
caveats = exec_payload.get("note") or "Try expanding scope to additional hunts/datasets."
response = AssistResponseModel(
guidance=guidance,
confidence=confidence,
suggested_pivots=["username", "hostname", "domain", "dataset_name"],
suggested_filters=[
"theme_name in ['Adult Content','Gambling','Downloads / Piracy']",
"username != null",
"hostname != null",
],
caveats=caveats,
reasoning=(
"Intent matched policy-domain investigation; executed local keyword scan pipeline."
if _is_policy_domain_query(request.query)
else "Execution mode was forced by user preference; ran policy-domain scan pipeline."
),
sans_references=["SANS FOR508", "SANS SEC504"],
model_used="execution:keyword_scanner",
node_used="local",
latency_ms=latency_ms,
execution=exec_payload,
)
conv_id = request.conversation_id
if conv_id or request.hunt_id:
conv_id = await _persist_conversation(
db,
conv_id,
request,
AgentResponse(
guidance=response.guidance,
confidence=response.confidence,
suggested_pivots=response.suggested_pivots,
suggested_filters=response.suggested_filters,
caveats=response.caveats,
reasoning=response.reasoning,
sans_references=response.sans_references,
model_used=response.model_used,
node_used=response.node_used,
latency_ms=response.latency_ms,
),
)
response.conversation_id = conv_id
return response
agent = get_agent() agent = get_agent()
context = AgentContext( context = AgentContext(
query=request.query, query=request.query,
@@ -97,6 +335,7 @@ async def agent_assist(
enrichment_summary=request.enrichment_summary, enrichment_summary=request.enrichment_summary,
mode=request.mode, mode=request.mode,
model_override=request.model_override, model_override=request.model_override,
learning_mode=request.learning_mode,
) )
response = await agent.assist(context) response = await agent.assist(context)
@@ -129,6 +368,7 @@ async def agent_assist(
} }
for p in response.perspectives for p in response.perspectives
] if response.perspectives else None, ] if response.perspectives else None,
execution=None,
conversation_id=conv_id, conversation_id=conv_id,
) )
@@ -208,7 +448,7 @@ async def list_models():
} }
# ── Conversation persistence ────────────────────────────────────────── # Conversation persistence
async def _persist_conversation( async def _persist_conversation(
@@ -263,3 +503,4 @@ async def _persist_conversation(
await db.flush() await db.flush()
return conv.id return conv.id

View File

@@ -290,6 +290,47 @@ async def get_knowledge_graph(
hunt_id: str | None = Query(None), hunt_id: str | None = Query(None),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
<<<<<<< HEAD
"""Submit a new job to the queue.
Job types: triage, host_profile, report, anomaly, query
Params vary by type (e.g., dataset_id, hunt_id, question, mode).
"""
from app.services.job_queue import job_queue, JobType
try:
jt = JobType(job_type)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid job_type: {job_type}. Valid: {[t.value for t in JobType]}",
)
if not job_queue.can_accept():
raise HTTPException(status_code=429, detail="Job queue is busy. Retry shortly.")
if not job_queue.can_accept():
raise HTTPException(status_code=429, detail="Job queue is busy. Retry shortly.")
job = job_queue.submit(jt, **params)
return {"job_id": job.id, "status": job.status.value, "job_type": job_type}
# --- Load balancer status ---
@router.get("/lb/status")
async def lb_status():
"""Get load balancer status for both nodes."""
from app.services.load_balancer import lb
return lb.get_status()
@router.post("/lb/check")
async def lb_health_check():
"""Force a health check of both nodes."""
from app.services.load_balancer import lb
await lb.check_health()
return lb.get_status()
=======
if not dataset_id and not hunt_id: if not dataset_id and not hunt_id:
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id") raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
return await build_knowledge_graph(db, dataset_id=dataset_id, hunt_id=hunt_id) return await build_knowledge_graph(db, dataset_id=dataset_id, hunt_id=hunt_id)
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2

View File

@@ -1,4 +1,4 @@
"""API routes for authentication register, login, refresh, profile.""" """API routes for authentication — register, login, refresh, profile."""
import logging import logging
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/auth", tags=["auth"]) router = APIRouter(prefix="/api/auth", tags=["auth"])
# ── Request / Response models ───────────────────────────────────────── # ── Request / Response models ─────────────────────────────────────────
class RegisterRequest(BaseModel): class RegisterRequest(BaseModel):
@@ -57,7 +57,7 @@ class AuthResponse(BaseModel):
tokens: TokenPair tokens: TokenPair
# ── Routes ──────────────────────────────────────────────────────────── # ── Routes ────────────────────────────────────────────────────────────
@router.post( @router.post(
@@ -86,7 +86,7 @@ async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
user = User( user = User(
username=body.username, username=body.username,
email=body.email, email=body.email,
password_hash=hash_password(body.password), hashed_password=hash_password(body.password),
display_name=body.display_name or body.username, display_name=body.display_name or body.username,
role="analyst", # Default role role="analyst", # Default role
) )
@@ -120,13 +120,13 @@ async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(User).where(User.username == body.username)) result = await db.execute(select(User).where(User.username == body.username))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user or not user.password_hash: if not user or not user.hashed_password:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password", detail="Invalid username or password",
) )
if not verify_password(body.password, user.password_hash): if not verify_password(body.password, user.hashed_password):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password", detail="Invalid username or password",
@@ -165,7 +165,7 @@ async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db)
if token_data.type != "refresh": if token_data.type != "refresh":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type use refresh token", detail="Invalid token type — use refresh token",
) )
result = await db.execute(select(User).where(User.id == token_data.sub)) result = await db.execute(select(User).where(User.id == token_data.sub))
@@ -195,3 +195,4 @@ async def get_profile(user: User = Depends(get_current_user)):
is_active=user.is_active, is_active=user.is_active,
created_at=user.created_at.isoformat() if hasattr(user.created_at, 'isoformat') else str(user.created_at), created_at=user.created_at.isoformat() if hasattr(user.created_at, 'isoformat') else str(user.created_at),
) )

View File

@@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings from app.config import settings
from app.db import get_db from app.db import get_db
from app.db.models import ProcessingTask
from app.db.repositories.datasets import DatasetRepository from app.db.repositories.datasets import DatasetRepository
from app.services.csv_parser import parse_csv_bytes, infer_column_types from app.services.csv_parser import parse_csv_bytes, infer_column_types
from app.services.normalizer import ( from app.services.normalizer import (
@@ -18,15 +19,20 @@ from app.services.normalizer import (
detect_ioc_columns, detect_ioc_columns,
detect_time_range, detect_time_range,
) )
from app.services.artifact_classifier import classify_artifact, get_artifact_category
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from app.services.job_queue import job_queue, JobType
from app.services.host_inventory import inventory_cache
from app.services.scanner import keyword_scan_cache
router = APIRouter(prefix="/api/datasets", tags=["datasets"]) router = APIRouter(prefix="/api/datasets", tags=["datasets"])
ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"} ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"}
# ── Response models ─────────────────────────────────────────────────── # -- Response models --
class DatasetSummary(BaseModel): class DatasetSummary(BaseModel):
@@ -43,6 +49,8 @@ class DatasetSummary(BaseModel):
delimiter: str | None = None delimiter: str | None = None
time_range_start: str | None = None time_range_start: str | None = None
time_range_end: str | None = None time_range_end: str | None = None
artifact_type: str | None = None
processing_status: str | None = None
hunt_id: str | None = None hunt_id: str | None = None
created_at: str created_at: str
@@ -67,10 +75,13 @@ class UploadResponse(BaseModel):
column_types: dict column_types: dict
normalized_columns: dict normalized_columns: dict
ioc_columns: dict ioc_columns: dict
artifact_type: str | None = None
processing_status: str
jobs_queued: list[str]
message: str message: str
# ── Routes ──────────────────────────────────────────────────────────── # -- Routes --
@router.post( @router.post(
@@ -78,7 +89,7 @@ class UploadResponse(BaseModel):
response_model=UploadResponse, response_model=UploadResponse,
summary="Upload a CSV dataset", summary="Upload a CSV dataset",
description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, " description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, "
"IOCs auto-detected, and rows stored in the database.", "IOCs auto-detected, artifact type classified, and all processing jobs queued automatically.",
) )
async def upload_dataset( async def upload_dataset(
file: UploadFile = File(...), file: UploadFile = File(...),
@@ -87,7 +98,7 @@ async def upload_dataset(
hunt_id: str | None = Query(None, description="Hunt ID to associate with"), hunt_id: str | None = Query(None, description="Hunt ID to associate with"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""Upload and parse a CSV dataset.""" """Upload and parse a CSV dataset, then trigger full processing pipeline."""
# Validate file # Validate file
if not file.filename: if not file.filename:
raise HTTPException(status_code=400, detail="No filename provided") raise HTTPException(status_code=400, detail="No filename provided")
@@ -136,7 +147,12 @@ async def upload_dataset(
# Detect time range # Detect time range
time_start, time_end = detect_time_range(rows, column_mapping) time_start, time_end = detect_time_range(rows, column_mapping)
# Store in DB # Classify artifact type from column headers
artifact_type = classify_artifact(columns)
artifact_category = get_artifact_category(artifact_type)
logger.info(f"Artifact classification: {artifact_type} (category: {artifact_category})")
# Store in DB with processing_status = "processing"
repo = DatasetRepository(db) repo = DatasetRepository(db)
dataset = await repo.create_dataset( dataset = await repo.create_dataset(
name=name or Path(file.filename).stem, name=name or Path(file.filename).stem,
@@ -152,6 +168,8 @@ async def upload_dataset(
time_range_start=time_start, time_range_start=time_start,
time_range_end=time_end, time_range_end=time_end,
hunt_id=hunt_id, hunt_id=hunt_id,
artifact_type=artifact_type,
processing_status="processing",
) )
await repo.bulk_insert_rows( await repo.bulk_insert_rows(
@@ -162,9 +180,88 @@ async def upload_dataset(
logger.info( logger.info(
f"Uploaded dataset '{dataset.name}': {len(rows)} rows, " f"Uploaded dataset '{dataset.name}': {len(rows)} rows, "
f"{len(columns)} columns, {len(ioc_columns)} IOC columns detected" f"{len(columns)} columns, {len(ioc_columns)} IOC columns, "
f"artifact={artifact_type}"
) )
# -- Queue full processing pipeline --
jobs_queued = []
task_rows: list[ProcessingTask] = []
# 1. AI Triage (chains to HOST_PROFILE automatically on completion)
triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
jobs_queued.append("triage")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=triage_job.id,
stage="triage",
status="queued",
progress=0.0,
message="Queued",
))
# 2. Anomaly detection (embedding-based outlier detection)
anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
jobs_queued.append("anomaly")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=anomaly_job.id,
stage="anomaly",
status="queued",
progress=0.0,
message="Queued",
))
# 3. AUP keyword scan
kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
jobs_queued.append("keyword_scan")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=kw_job.id,
stage="keyword_scan",
status="queued",
progress=0.0,
message="Queued",
))
# 4. IOC extraction
ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
jobs_queued.append("ioc_extract")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=ioc_job.id,
stage="ioc_extract",
status="queued",
progress=0.0,
message="Queued",
))
# 5. Host inventory (network map) - requires hunt_id
if hunt_id:
inventory_cache.invalidate(hunt_id)
inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
jobs_queued.append("host_inventory")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=inv_job.id,
stage="host_inventory",
status="queued",
progress=0.0,
message="Queued",
))
if task_rows:
db.add_all(task_rows)
await db.flush()
logger.info(f"Queued {len(jobs_queued)} processing jobs for dataset {dataset.id}: {jobs_queued}")
return UploadResponse( return UploadResponse(
id=dataset.id, id=dataset.id,
name=dataset.name, name=dataset.name,
@@ -173,7 +270,10 @@ async def upload_dataset(
column_types=column_types, column_types=column_types,
normalized_columns=column_mapping, normalized_columns=column_mapping,
ioc_columns=ioc_columns, ioc_columns=ioc_columns,
message=f"Successfully uploaded {len(rows)} rows with {len(ioc_columns)} IOC columns detected", artifact_type=artifact_type,
processing_status="processing",
jobs_queued=jobs_queued,
message=f"Successfully uploaded {len(rows)} rows. {len(jobs_queued)} processing jobs queued.",
) )
@@ -208,6 +308,8 @@ async def list_datasets(
delimiter=ds.delimiter, delimiter=ds.delimiter,
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None, time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None, time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
artifact_type=ds.artifact_type,
processing_status=ds.processing_status,
hunt_id=ds.hunt_id, hunt_id=ds.hunt_id,
created_at=ds.created_at.isoformat(), created_at=ds.created_at.isoformat(),
) )
@@ -244,6 +346,8 @@ async def get_dataset(
delimiter=ds.delimiter, delimiter=ds.delimiter,
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None, time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None, time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
artifact_type=ds.artifact_type,
processing_status=ds.processing_status,
hunt_id=ds.hunt_id, hunt_id=ds.hunt_id,
created_at=ds.created_at.isoformat(), created_at=ds.created_at.isoformat(),
) )
@@ -292,6 +396,7 @@ async def delete_dataset(
deleted = await repo.delete_dataset(dataset_id) deleted = await repo.delete_dataset(dataset_id)
if not deleted: if not deleted:
raise HTTPException(status_code=404, detail="Dataset not found") raise HTTPException(status_code=404, detail="Dataset not found")
keyword_scan_cache.invalidate_dataset(dataset_id)
return {"message": "Dataset deleted", "id": dataset_id} return {"message": "Dataset deleted", "id": dataset_id}

View File

@@ -8,16 +8,15 @@ from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db from app.db import get_db
from app.db.models import Hunt, Conversation, Message from app.db.models import Hunt, Dataset, ProcessingTask
from app.services.job_queue import job_queue
from app.services.host_inventory import inventory_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/hunts", tags=["hunts"]) router = APIRouter(prefix="/api/hunts", tags=["hunts"])
# ── Models ────────────────────────────────────────────────────────────
class HuntCreate(BaseModel): class HuntCreate(BaseModel):
name: str = Field(..., max_length=256) name: str = Field(..., max_length=256)
description: str | None = None description: str | None = None
@@ -26,7 +25,7 @@ class HuntCreate(BaseModel):
class HuntUpdate(BaseModel): class HuntUpdate(BaseModel):
name: str | None = None name: str | None = None
description: str | None = None description: str | None = None
status: str | None = None # active | closed | archived status: str | None = None
class HuntResponse(BaseModel): class HuntResponse(BaseModel):
@@ -46,7 +45,18 @@ class HuntListResponse(BaseModel):
total: int total: int
# ── Routes ──────────────────────────────────────────────────────────── class HuntProgressResponse(BaseModel):
hunt_id: str
status: str
progress_percent: float
dataset_total: int
dataset_completed: int
dataset_processing: int
dataset_errors: int
active_jobs: int
queued_jobs: int
network_status: str
stages: dict
@router.post("", response_model=HuntResponse, summary="Create a new hunt") @router.post("", response_model=HuntResponse, summary="Create a new hunt")
@@ -122,6 +132,125 @@ async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
) )
@router.get("/{hunt_id}/progress", response_model=HuntProgressResponse, summary="Get hunt processing progress")
async def get_hunt_progress(hunt_id: str, db: AsyncSession = Depends(get_db)):
hunt = await db.get(Hunt, hunt_id)
if not hunt:
raise HTTPException(status_code=404, detail="Hunt not found")
ds_rows = await db.execute(
select(Dataset.id, Dataset.processing_status)
.where(Dataset.hunt_id == hunt_id)
)
datasets = ds_rows.all()
dataset_ids = {row[0] for row in datasets}
dataset_total = len(datasets)
dataset_completed = sum(1 for _, st in datasets if st == "completed")
dataset_errors = sum(1 for _, st in datasets if st == "completed_with_errors")
dataset_processing = max(0, dataset_total - dataset_completed - dataset_errors)
jobs = job_queue.list_jobs(limit=5000)
relevant_jobs = [
j for j in jobs
if j.get("params", {}).get("hunt_id") == hunt_id
or j.get("params", {}).get("dataset_id") in dataset_ids
]
active_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "running")
queued_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "queued")
task_rows = await db.execute(
select(ProcessingTask.stage, ProcessingTask.status, ProcessingTask.progress)
.where(ProcessingTask.hunt_id == hunt_id)
)
tasks = task_rows.all()
task_total = len(tasks)
task_done = sum(1 for _, st, _ in tasks if st in ("completed", "failed", "cancelled"))
task_running = sum(1 for _, st, _ in tasks if st == "running")
task_queued = sum(1 for _, st, _ in tasks if st == "queued")
task_ratio = (task_done / task_total) if task_total > 0 else None
active_jobs = max(active_jobs_mem, task_running)
queued_jobs = max(queued_jobs_mem, task_queued)
stage_rollup: dict[str, dict] = {}
for stage, status, progress in tasks:
bucket = stage_rollup.setdefault(stage, {"total": 0, "done": 0, "running": 0, "queued": 0, "progress_sum": 0.0})
bucket["total"] += 1
if status in ("completed", "failed", "cancelled"):
bucket["done"] += 1
elif status == "running":
bucket["running"] += 1
elif status == "queued":
bucket["queued"] += 1
bucket["progress_sum"] += float(progress or 0.0)
for stage_name, bucket in stage_rollup.items():
total = max(1, bucket["total"])
bucket["percent"] = round(bucket["progress_sum"] / total, 1)
if inventory_cache.get(hunt_id) is not None:
network_status = "ready"
network_ratio = 1.0
elif inventory_cache.is_building(hunt_id):
network_status = "building"
network_ratio = 0.5
else:
network_status = "none"
network_ratio = 0.0
dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
if task_ratio is None:
overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
else:
overall_ratio = min(1.0, (dataset_ratio * 0.50) + (task_ratio * 0.35) + (network_ratio * 0.15))
progress_percent = round(overall_ratio * 100.0, 1)
status = "ready"
if dataset_total == 0:
status = "idle"
elif progress_percent < 100:
status = "processing"
stages = {
"datasets": {
"total": dataset_total,
"completed": dataset_completed,
"processing": dataset_processing,
"errors": dataset_errors,
"percent": round(dataset_ratio * 100.0, 1),
},
"network": {
"status": network_status,
"percent": round(network_ratio * 100.0, 1),
},
"jobs": {
"active": active_jobs,
"queued": queued_jobs,
"total_seen": len(relevant_jobs),
"task_total": task_total,
"task_done": task_done,
"task_percent": round((task_ratio or 0.0) * 100.0, 1) if task_total else None,
},
"task_stages": stage_rollup,
}
return HuntProgressResponse(
hunt_id=hunt_id,
status=status,
progress_percent=progress_percent,
dataset_total=dataset_total,
dataset_completed=dataset_completed,
dataset_processing=dataset_processing,
dataset_errors=dataset_errors,
active_jobs=active_jobs,
queued_jobs=queued_jobs,
network_status=network_status,
stages=stages,
)
@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt") @router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
async def update_hunt( async def update_hunt(
hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db) hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)

View File

@@ -1,25 +1,21 @@
"""API routes for AUP keyword themes, keyword CRUD, and scanning.""" """API routes for AUP keyword themes, keyword CRUD, and scanning."""
import logging import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select, func, delete from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db from app.db import get_db
from app.db.models import KeywordTheme, Keyword from app.db.models import KeywordTheme, Keyword
from app.services.scanner import KeywordScanner from app.services.scanner import KeywordScanner, keyword_scan_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/keywords", tags=["keywords"]) router = APIRouter(prefix="/api/keywords", tags=["keywords"])
# ── Pydantic schemas ──────────────────────────────────────────────────
class ThemeCreate(BaseModel): class ThemeCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=128) name: str = Field(..., min_length=1, max_length=128)
color: str = Field(default="#9e9e9e", max_length=16) color: str = Field(default="#9e9e9e", max_length=16)
@@ -67,23 +63,27 @@ class KeywordBulkCreate(BaseModel):
class ScanRequest(BaseModel): class ScanRequest(BaseModel):
dataset_ids: list[str] | None = None # None → all datasets dataset_ids: list[str] | None = None
theme_ids: list[str] | None = None # None → all enabled themes theme_ids: list[str] | None = None
scan_hunts: bool = True scan_hunts: bool = False
scan_annotations: bool = True scan_annotations: bool = False
scan_messages: bool = True scan_messages: bool = False
prefer_cache: bool = True
force_rescan: bool = False
class ScanHit(BaseModel): class ScanHit(BaseModel):
theme_name: str theme_name: str
theme_color: str theme_color: str
keyword: str keyword: str
source_type: str # dataset_row | hunt | annotation | message source_type: str
source_id: str | int source_id: str | int
field: str field: str
matched_value: str matched_value: str
row_index: int | None = None row_index: int | None = None
dataset_name: str | None = None dataset_name: str | None = None
hostname: str | None = None
username: str | None = None
class ScanResponse(BaseModel): class ScanResponse(BaseModel):
@@ -92,9 +92,9 @@ class ScanResponse(BaseModel):
themes_scanned: int themes_scanned: int
keywords_scanned: int keywords_scanned: int
rows_scanned: int rows_scanned: int
cache_used: bool = False
cache_status: str = "miss"
# ── Helpers ─────────────────────────────────────────────────────────── cached_at: str | None = None
def _theme_to_out(t: KeywordTheme) -> ThemeOut: def _theme_to_out(t: KeywordTheme) -> ThemeOut:
@@ -119,49 +119,58 @@ def _theme_to_out(t: KeywordTheme) -> ThemeOut:
) )
# ── Theme CRUD ──────────────────────────────────────────────────────── def _merge_cached_results(entries: list[dict], allowed_theme_names: set[str] | None = None) -> dict:
hits: list[dict] = []
total_rows = 0
cached_at: str | None = None
for entry in entries:
result = entry["result"]
total_rows += int(result.get("rows_scanned", 0) or 0)
if entry.get("built_at"):
if not cached_at or entry["built_at"] > cached_at:
cached_at = entry["built_at"]
for h in result.get("hits", []):
if allowed_theme_names is not None and h.get("theme_name") not in allowed_theme_names:
continue
hits.append(h)
return {
"total_hits": len(hits),
"hits": hits,
"rows_scanned": total_rows,
"cached_at": cached_at,
}
@router.get("/themes", response_model=ThemeListResponse) @router.get("/themes", response_model=ThemeListResponse)
async def list_themes(db: AsyncSession = Depends(get_db)): async def list_themes(db: AsyncSession = Depends(get_db)):
"""List all keyword themes with their keywords.""" result = await db.execute(select(KeywordTheme).order_by(KeywordTheme.name))
result = await db.execute(
select(KeywordTheme).order_by(KeywordTheme.name)
)
themes = result.scalars().all() themes = result.scalars().all()
return ThemeListResponse( return ThemeListResponse(themes=[_theme_to_out(t) for t in themes], total=len(themes))
themes=[_theme_to_out(t) for t in themes],
total=len(themes),
)
@router.post("/themes", response_model=ThemeOut, status_code=201) @router.post("/themes", response_model=ThemeOut, status_code=201)
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)): async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
"""Create a new keyword theme.""" exists = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == body.name))
exists = await db.scalar(
select(KeywordTheme.id).where(KeywordTheme.name == body.name)
)
if exists: if exists:
raise HTTPException(409, f"Theme '{body.name}' already exists") raise HTTPException(409, f"Theme '{body.name}' already exists")
theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled) theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
db.add(theme) db.add(theme)
await db.flush() await db.flush()
await db.refresh(theme) await db.refresh(theme)
keyword_scan_cache.clear()
return _theme_to_out(theme) return _theme_to_out(theme)
@router.put("/themes/{theme_id}", response_model=ThemeOut) @router.put("/themes/{theme_id}", response_model=ThemeOut)
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)): async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
"""Update theme name, color, or enabled status."""
theme = await db.get(KeywordTheme, theme_id) theme = await db.get(KeywordTheme, theme_id)
if not theme: if not theme:
raise HTTPException(404, "Theme not found") raise HTTPException(404, "Theme not found")
if body.name is not None: if body.name is not None:
# check uniqueness
dup = await db.scalar( dup = await db.scalar(
select(KeywordTheme.id).where( select(KeywordTheme.id).where(KeywordTheme.name == body.name, KeywordTheme.id != theme_id)
KeywordTheme.name == body.name, KeywordTheme.id != theme_id
)
) )
if dup: if dup:
raise HTTPException(409, f"Theme '{body.name}' already exists") raise HTTPException(409, f"Theme '{body.name}' already exists")
@@ -172,24 +181,21 @@ async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depe
theme.enabled = body.enabled theme.enabled = body.enabled
await db.flush() await db.flush()
await db.refresh(theme) await db.refresh(theme)
keyword_scan_cache.clear()
return _theme_to_out(theme) return _theme_to_out(theme)
@router.delete("/themes/{theme_id}", status_code=204) @router.delete("/themes/{theme_id}", status_code=204)
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)): async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
"""Delete a theme and all its keywords."""
theme = await db.get(KeywordTheme, theme_id) theme = await db.get(KeywordTheme, theme_id)
if not theme: if not theme:
raise HTTPException(404, "Theme not found") raise HTTPException(404, "Theme not found")
await db.delete(theme) await db.delete(theme)
keyword_scan_cache.clear()
# ── Keyword CRUD ──────────────────────────────────────────────────────
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201) @router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)): async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
"""Add a single keyword to a theme."""
theme = await db.get(KeywordTheme, theme_id) theme = await db.get(KeywordTheme, theme_id)
if not theme: if not theme:
raise HTTPException(404, "Theme not found") raise HTTPException(404, "Theme not found")
@@ -197,6 +203,7 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
db.add(kw) db.add(kw)
await db.flush() await db.flush()
await db.refresh(kw) await db.refresh(kw)
keyword_scan_cache.clear()
return KeywordOut( return KeywordOut(
id=kw.id, theme_id=kw.theme_id, value=kw.value, id=kw.id, theme_id=kw.theme_id, value=kw.value,
is_regex=kw.is_regex, created_at=kw.created_at.isoformat(), is_regex=kw.is_regex, created_at=kw.created_at.isoformat(),
@@ -205,7 +212,6 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201) @router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)): async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
"""Add multiple keywords to a theme at once."""
theme = await db.get(KeywordTheme, theme_id) theme = await db.get(KeywordTheme, theme_id)
if not theme: if not theme:
raise HTTPException(404, "Theme not found") raise HTTPException(404, "Theme not found")
@@ -217,25 +223,88 @@ async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSes
db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex)) db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex))
added += 1 added += 1
await db.flush() await db.flush()
keyword_scan_cache.clear()
return {"added": added, "theme_id": theme_id} return {"added": added, "theme_id": theme_id}
@router.delete("/keywords/{keyword_id}", status_code=204) @router.delete("/keywords/{keyword_id}", status_code=204)
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)): async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
"""Delete a single keyword."""
kw = await db.get(Keyword, keyword_id) kw = await db.get(Keyword, keyword_id)
if not kw: if not kw:
raise HTTPException(404, "Keyword not found") raise HTTPException(404, "Keyword not found")
await db.delete(kw) await db.delete(kw)
keyword_scan_cache.clear()
# ── Scan endpoints ────────────────────────────────────────────────────
@router.post("/scan", response_model=ScanResponse) @router.post("/scan", response_model=ScanResponse)
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)): async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
"""Run AUP keyword scan across selected data sources."""
scanner = KeywordScanner(db) scanner = KeywordScanner(db)
if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
return {
"total_hits": 0,
"hits": [],
"themes_scanned": 0,
"keywords_scanned": 0,
"rows_scanned": 0,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}
can_use_cache = (
body.prefer_cache
and not body.force_rescan
and bool(body.dataset_ids)
and not body.scan_hunts
and not body.scan_annotations
and not body.scan_messages
)
if can_use_cache:
themes = await scanner._load_themes(body.theme_ids)
allowed_theme_names = {t.name for t in themes}
keywords_scanned = sum(len(theme.keywords) for theme in themes)
cached_entries: list[dict] = []
missing: list[str] = []
for dataset_id in (body.dataset_ids or []):
entry = keyword_scan_cache.get(dataset_id)
if not entry:
missing.append(dataset_id)
continue
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
if not missing and cached_entries:
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": True,
"cache_status": "hit",
"cached_at": merged["cached_at"],
}
if missing:
partial = await scanner.scan(dataset_ids=missing, theme_ids=body.theme_ids)
merged = _merge_cached_results(
cached_entries + [{"result": partial, "built_at": None}],
allowed_theme_names if body.theme_ids else None,
)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": len(cached_entries) > 0,
"cache_status": "partial" if cached_entries else "miss",
"cached_at": merged["cached_at"],
}
result = await scanner.scan( result = await scanner.scan(
dataset_ids=body.dataset_ids, dataset_ids=body.dataset_ids,
theme_ids=body.theme_ids, theme_ids=body.theme_ids,
@@ -243,7 +312,13 @@ async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
scan_annotations=body.scan_annotations, scan_annotations=body.scan_annotations,
scan_messages=body.scan_messages, scan_messages=body.scan_messages,
) )
return result
return {
**result,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}
@router.get("/scan/quick", response_model=ScanResponse) @router.get("/scan/quick", response_model=ScanResponse)
@@ -251,7 +326,22 @@ async def quick_scan(
dataset_id: str = Query(..., description="Dataset to scan"), dataset_id: str = Query(..., description="Dataset to scan"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""Quick scan a single dataset with all enabled themes.""" entry = keyword_scan_cache.get(dataset_id)
if entry is not None:
result = entry.result
return {
**result,
"cache_used": True,
"cache_status": "hit",
"cached_at": entry.built_at,
}
scanner = KeywordScanner(db) scanner = KeywordScanner(db)
result = await scanner.scan(dataset_ids=[dataset_id]) result = await scanner.scan(dataset_ids=[dataset_id])
return result keyword_scan_cache.put(dataset_id, result)
return {
**result,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}

View File

@@ -0,0 +1,146 @@
"""API routes for MITRE ATT&CK coverage visualization."""
import logging
from collections import defaultdict
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import (
TriageResult, HostProfile, Hypothesis, HuntReport, Dataset, Hunt
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/mitre", tags=["mitre"])
# Canonical MITRE ATT&CK tactics in kill-chain order
TACTICS = [
"Reconnaissance", "Resource Development", "Initial Access",
"Execution", "Persistence", "Privilege Escalation",
"Defense Evasion", "Credential Access", "Discovery",
"Lateral Movement", "Collection", "Command and Control",
"Exfiltration", "Impact",
]
# Simplified technique-to-tactic mapping (top techniques)
TECHNIQUE_TACTIC: dict[str, str] = {
"T1059": "Execution", "T1059.001": "Execution", "T1059.003": "Execution",
"T1059.005": "Execution", "T1059.006": "Execution", "T1059.007": "Execution",
"T1053": "Persistence", "T1053.005": "Persistence",
"T1547": "Persistence", "T1547.001": "Persistence",
"T1543": "Persistence", "T1543.003": "Persistence",
"T1078": "Privilege Escalation", "T1078.001": "Privilege Escalation",
"T1078.002": "Privilege Escalation", "T1078.003": "Privilege Escalation",
"T1055": "Privilege Escalation", "T1055.001": "Privilege Escalation",
"T1548": "Privilege Escalation", "T1548.002": "Privilege Escalation",
"T1070": "Defense Evasion", "T1070.001": "Defense Evasion",
"T1070.004": "Defense Evasion",
"T1036": "Defense Evasion", "T1036.005": "Defense Evasion",
"T1027": "Defense Evasion", "T1140": "Defense Evasion",
"T1218": "Defense Evasion", "T1218.011": "Defense Evasion",
"T1003": "Credential Access", "T1003.001": "Credential Access",
"T1110": "Credential Access", "T1558": "Credential Access",
"T1087": "Discovery", "T1087.001": "Discovery", "T1087.002": "Discovery",
"T1082": "Discovery", "T1083": "Discovery", "T1057": "Discovery",
"T1018": "Discovery", "T1049": "Discovery", "T1016": "Discovery",
"T1021": "Lateral Movement", "T1021.001": "Lateral Movement",
"T1021.002": "Lateral Movement", "T1021.006": "Lateral Movement",
"T1570": "Lateral Movement",
"T1560": "Collection", "T1074": "Collection", "T1005": "Collection",
"T1071": "Command and Control", "T1071.001": "Command and Control",
"T1105": "Command and Control", "T1572": "Command and Control",
"T1095": "Command and Control",
"T1048": "Exfiltration", "T1041": "Exfiltration",
"T1486": "Impact", "T1490": "Impact", "T1489": "Impact",
"T1566": "Initial Access", "T1566.001": "Initial Access",
"T1566.002": "Initial Access",
"T1190": "Initial Access", "T1133": "Initial Access",
"T1195": "Initial Access", "T1195.002": "Initial Access",
}
def _get_tactic(technique_id: str) -> str:
"""Map a technique ID to its tactic."""
tech = technique_id.strip().upper()
if tech in TECHNIQUE_TACTIC:
return TECHNIQUE_TACTIC[tech]
# Try parent technique
if "." in tech:
parent = tech.split(".")[0]
if parent in TECHNIQUE_TACTIC:
return TECHNIQUE_TACTIC[parent]
return "Unknown"
@router.get("/coverage")
async def get_mitre_coverage(
hunt_id: str | None = None,
db: AsyncSession = Depends(get_db),
):
"""Aggregate all MITRE techniques from triage, host profiles, hypotheses, and reports."""
techniques: dict[str, dict] = {}
# Collect from triage results
triage_q = select(TriageResult)
if hunt_id:
triage_q = triage_q.join(Dataset).where(Dataset.hunt_id == hunt_id)
result = await db.execute(triage_q.limit(500))
for t in result.scalars().all():
for tech in (t.mitre_techniques or []):
if tech not in techniques:
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
techniques[tech]["count"] += 1
techniques[tech]["sources"].append({"type": "triage", "risk_score": t.risk_score})
# Collect from host profiles
profile_q = select(HostProfile)
if hunt_id:
profile_q = profile_q.where(HostProfile.hunt_id == hunt_id)
result = await db.execute(profile_q.limit(200))
for p in result.scalars().all():
for tech in (p.mitre_techniques or []):
if tech not in techniques:
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
techniques[tech]["count"] += 1
techniques[tech]["sources"].append({"type": "host_profile", "hostname": p.hostname})
# Collect from hypotheses
hyp_q = select(Hypothesis)
if hunt_id:
hyp_q = hyp_q.where(Hypothesis.hunt_id == hunt_id)
result = await db.execute(hyp_q.limit(200))
for h in result.scalars().all():
tech = h.mitre_technique
if tech:
if tech not in techniques:
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
techniques[tech]["count"] += 1
techniques[tech]["sources"].append({"type": "hypothesis", "title": h.title})
# Build tactic-grouped response
tactic_groups: dict[str, list] = {t: [] for t in TACTICS}
tactic_groups["Unknown"] = []
for tech in techniques.values():
tactic = tech["tactic"]
if tactic not in tactic_groups:
tactic_groups[tactic] = []
tactic_groups[tactic].append(tech)
total_techniques = len(techniques)
total_detections = sum(t["count"] for t in techniques.values())
return {
"tactics": TACTICS,
"technique_count": total_techniques,
"detection_count": total_detections,
"tactic_coverage": {
t: {"techniques": techs, "count": len(techs)}
for t, techs in tactic_groups.items()
if techs
},
"all_techniques": list(techniques.values()),
}

View File

@@ -1,19 +1,193 @@
<<<<<<< HEAD
"""Network topology API - host inventory endpoint with background caching."""
=======
"""API routes for Network Picture — deduplicated host inventory.""" """API routes for Network Picture — deduplicated host inventory."""
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
import logging import logging
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
<<<<<<< HEAD
from fastapi.responses import JSONResponse
=======
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.db import get_db from app.db import get_db
<<<<<<< HEAD
from app.services.host_inventory import build_host_inventory, inventory_cache
from app.services.job_queue import job_queue, JobType
=======
from app.services.network_inventory import build_network_picture from app.services.network_inventory import build_network_picture
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/network", tags=["network"]) router = APIRouter(prefix="/api/network", tags=["network"])
<<<<<<< HEAD
@router.get("/host-inventory")
async def get_host_inventory(
hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
force: bool = Query(False, description="Force rebuild, ignoring cache"),
db: AsyncSession = Depends(get_db),
):
"""Return a deduplicated host inventory for the hunt.
Returns instantly from cache if available (pre-built after upload or on startup).
If cache is cold, triggers a background build and returns 202 so the
frontend can poll /inventory-status and re-request when ready.
"""
# Force rebuild: invalidate cache, queue background job, return 202
if force:
inventory_cache.invalidate(hunt_id)
if not inventory_cache.is_building(hunt_id):
if job_queue.is_backlogged():
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(
status_code=202,
content={"status": "building", "message": "Rebuild queued"},
)
# Try cache first
cached = inventory_cache.get(hunt_id)
if cached is not None:
logger.info(f"Serving cached host inventory for {hunt_id}")
return cached
# Cache miss: trigger background build instead of blocking for 90+ seconds
if not inventory_cache.is_building(hunt_id):
logger.info(f"Cache miss for {hunt_id}, triggering background build")
if job_queue.is_backlogged():
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(
status_code=202,
content={"status": "building", "message": "Inventory is being built in the background"},
)
def _build_summary(inv: dict, top_n: int = 20) -> dict:
hosts = inv.get("hosts", [])
conns = inv.get("connections", [])
top_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:top_n]
top_edges = sorted(conns, key=lambda c: c.get("count", 0), reverse=True)[:top_n]
return {
"stats": inv.get("stats", {}),
"top_hosts": [
{
"id": h.get("id"),
"hostname": h.get("hostname"),
"row_count": h.get("row_count", 0),
"ip_count": len(h.get("ips", [])),
"user_count": len(h.get("users", [])),
}
for h in top_hosts
],
"top_edges": top_edges,
}
def _build_subgraph(inv: dict, node_id: str | None, max_hosts: int, max_edges: int) -> dict:
hosts = inv.get("hosts", [])
conns = inv.get("connections", [])
max_hosts = max(1, min(max_hosts, settings.NETWORK_SUBGRAPH_MAX_HOSTS))
max_edges = max(1, min(max_edges, settings.NETWORK_SUBGRAPH_MAX_EDGES))
if node_id:
rel_edges = [c for c in conns if c.get("source") == node_id or c.get("target") == node_id]
rel_edges = sorted(rel_edges, key=lambda c: c.get("count", 0), reverse=True)[:max_edges]
ids = {node_id}
for c in rel_edges:
ids.add(c.get("source"))
ids.add(c.get("target"))
rel_hosts = [h for h in hosts if h.get("id") in ids][:max_hosts]
else:
rel_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:max_hosts]
allowed = {h.get("id") for h in rel_hosts}
rel_edges = [
c for c in sorted(conns, key=lambda c: c.get("count", 0), reverse=True)
if c.get("source") in allowed and c.get("target") in allowed
][:max_edges]
return {
"hosts": rel_hosts,
"connections": rel_edges,
"stats": {
**inv.get("stats", {}),
"subgraph_hosts": len(rel_hosts),
"subgraph_connections": len(rel_edges),
"truncated": len(rel_hosts) < len(hosts) or len(rel_edges) < len(conns),
},
}
@router.get("/summary")
async def get_inventory_summary(
hunt_id: str = Query(..., description="Hunt ID"),
top_n: int = Query(20, ge=1, le=200),
):
"""Return a lightweight summary view for large hunts."""
cached = inventory_cache.get(hunt_id)
if cached is None:
if not inventory_cache.is_building(hunt_id):
if job_queue.is_backlogged():
return JSONResponse(
status_code=202,
content={"status": "deferred", "message": "Queue busy, retry shortly"},
)
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(status_code=202, content={"status": "building"})
return _build_summary(cached, top_n=top_n)
@router.get("/subgraph")
async def get_inventory_subgraph(
hunt_id: str = Query(..., description="Hunt ID"),
node_id: str | None = Query(None, description="Optional focal node"),
max_hosts: int = Query(200, ge=1, le=5000),
max_edges: int = Query(1500, ge=1, le=20000),
):
"""Return a bounded subgraph for scale-safe rendering."""
cached = inventory_cache.get(hunt_id)
if cached is None:
if not inventory_cache.is_building(hunt_id):
if job_queue.is_backlogged():
return JSONResponse(
status_code=202,
content={"status": "deferred", "message": "Queue busy, retry shortly"},
)
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(status_code=202, content={"status": "building"})
return _build_subgraph(cached, node_id=node_id, max_hosts=max_hosts, max_edges=max_edges)
@router.get("/inventory-status")
async def get_inventory_status(
hunt_id: str = Query(..., description="Hunt ID to check"),
):
"""Check whether pre-computed host inventory is ready for a hunt.
Returns: { status: "ready" | "building" | "none" }
"""
return {"hunt_id": hunt_id, "status": inventory_cache.status(hunt_id)}
@router.post("/rebuild-inventory")
async def trigger_rebuild(
hunt_id: str = Query(..., description="Hunt to rebuild inventory for"),
):
"""Trigger a background rebuild of the host inventory cache."""
inventory_cache.invalidate(hunt_id)
job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return {"job_id": job.id, "status": "queued"}
=======
# ── Response models ─────────────────────────────────────────────────── # ── Response models ───────────────────────────────────────────────────
@@ -67,3 +241,4 @@ async def get_network_picture(
result = await build_network_picture(db, hunt_id) result = await build_network_picture(db, hunt_id)
return result return result
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2

View File

@@ -0,0 +1,217 @@
"""API routes for investigation playbooks."""
import logging
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Playbook, PlaybookStep
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/playbooks", tags=["playbooks"])
# -- Request / Response schemas ---
class StepCreate(BaseModel):
title: str
description: str | None = None
step_type: str = "manual"
target_route: str | None = None
class PlaybookCreate(BaseModel):
name: str
description: str | None = None
hunt_id: str | None = None
is_template: bool = False
steps: list[StepCreate] = []
class PlaybookUpdate(BaseModel):
name: str | None = None
description: str | None = None
status: str | None = None
class StepUpdate(BaseModel):
is_completed: bool | None = None
notes: str | None = None
# -- Default investigation templates ---
DEFAULT_TEMPLATES = [
{
"name": "Standard Threat Hunt",
"description": "Step-by-step investigation workflow for a typical threat hunting engagement.",
"steps": [
{"title": "Upload Artifacts", "description": "Import CSV exports from Velociraptor or other tools", "step_type": "upload", "target_route": "/upload"},
{"title": "Create Hunt", "description": "Create a new hunt and associate uploaded datasets", "step_type": "action", "target_route": "/hunts"},
{"title": "AUP Keyword Scan", "description": "Run AUP keyword scanner for policy violations", "step_type": "analysis", "target_route": "/aup"},
{"title": "Auto-Triage", "description": "Trigger AI triage on all datasets", "step_type": "analysis", "target_route": "/analysis"},
{"title": "Review Triage Results", "description": "Review flagged rows and risk scores", "step_type": "review", "target_route": "/analysis"},
{"title": "Enrich IOCs", "description": "Enrich flagged IPs, hashes, and domains via external sources", "step_type": "analysis", "target_route": "/enrichment"},
{"title": "Host Profiling", "description": "Generate deep host profiles for suspicious hosts", "step_type": "analysis", "target_route": "/analysis"},
{"title": "Cross-Hunt Correlation", "description": "Identify shared IOCs and patterns across hunts", "step_type": "analysis", "target_route": "/correlation"},
{"title": "Document Hypotheses", "description": "Record investigation hypotheses with MITRE mappings", "step_type": "action", "target_route": "/hypotheses"},
{"title": "Generate Report", "description": "Generate final AI-assisted hunt report", "step_type": "action", "target_route": "/analysis"},
],
},
{
"name": "Incident Response Triage",
"description": "Fast-track workflow for active incident response.",
"steps": [
{"title": "Upload Artifacts", "description": "Import forensic data from affected hosts", "step_type": "upload", "target_route": "/upload"},
{"title": "Auto-Triage", "description": "Immediate AI triage for threat indicators", "step_type": "analysis", "target_route": "/analysis"},
{"title": "IOC Extraction", "description": "Extract all IOCs from flagged data", "step_type": "analysis", "target_route": "/analysis"},
{"title": "Enrich Critical IOCs", "description": "Priority enrichment of high-risk indicators", "step_type": "analysis", "target_route": "/enrichment"},
{"title": "Network Map", "description": "Visualize host connections and lateral movement", "step_type": "review", "target_route": "/network"},
{"title": "Generate Situation Report", "description": "Create executive summary for incident command", "step_type": "action", "target_route": "/analysis"},
],
},
]
# -- Routes ---
@router.get("")
async def list_playbooks(
include_templates: bool = True,
hunt_id: str | None = None,
db: AsyncSession = Depends(get_db),
):
q = select(Playbook)
if hunt_id:
q = q.where(Playbook.hunt_id == hunt_id)
if not include_templates:
q = q.where(Playbook.is_template == False)
q = q.order_by(Playbook.created_at.desc())
result = await db.execute(q.limit(100))
playbooks = result.scalars().all()
return {"playbooks": [
{
"id": p.id, "name": p.name, "description": p.description,
"is_template": p.is_template, "hunt_id": p.hunt_id,
"status": p.status,
"total_steps": len(p.steps),
"completed_steps": sum(1 for s in p.steps if s.is_completed),
"created_at": p.created_at.isoformat() if p.created_at else None,
}
for p in playbooks
]}
@router.get("/templates")
async def get_templates():
"""Return built-in investigation templates."""
return {"templates": DEFAULT_TEMPLATES}
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_playbook(body: PlaybookCreate, db: AsyncSession = Depends(get_db)):
pb = Playbook(
name=body.name,
description=body.description,
hunt_id=body.hunt_id,
is_template=body.is_template,
)
db.add(pb)
await db.flush()
created_steps = []
for i, step in enumerate(body.steps):
s = PlaybookStep(
playbook_id=pb.id,
order_index=i,
title=step.title,
description=step.description,
step_type=step.step_type,
target_route=step.target_route,
)
db.add(s)
created_steps.append(s)
await db.flush()
return {
"id": pb.id, "name": pb.name, "description": pb.description,
"hunt_id": pb.hunt_id, "is_template": pb.is_template,
"steps": [
{"id": s.id, "order_index": s.order_index, "title": s.title,
"description": s.description, "step_type": s.step_type,
"target_route": s.target_route, "is_completed": False}
for s in created_steps
],
}
@router.get("/{playbook_id}")
async def get_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
pb = result.scalar_one_or_none()
if not pb:
raise HTTPException(status_code=404, detail="Playbook not found")
return {
"id": pb.id, "name": pb.name, "description": pb.description,
"is_template": pb.is_template, "hunt_id": pb.hunt_id,
"status": pb.status,
"created_at": pb.created_at.isoformat() if pb.created_at else None,
"steps": [
{
"id": s.id, "order_index": s.order_index, "title": s.title,
"description": s.description, "step_type": s.step_type,
"target_route": s.target_route,
"is_completed": s.is_completed,
"completed_at": s.completed_at.isoformat() if s.completed_at else None,
"notes": s.notes,
}
for s in sorted(pb.steps, key=lambda x: x.order_index)
],
}
@router.put("/{playbook_id}")
async def update_playbook(playbook_id: str, body: PlaybookUpdate, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
pb = result.scalar_one_or_none()
if not pb:
raise HTTPException(status_code=404, detail="Playbook not found")
if body.name is not None:
pb.name = body.name
if body.description is not None:
pb.description = body.description
if body.status is not None:
pb.status = body.status
return {"status": "updated"}
@router.delete("/{playbook_id}")
async def delete_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
pb = result.scalar_one_or_none()
if not pb:
raise HTTPException(status_code=404, detail="Playbook not found")
await db.delete(pb)
return {"status": "deleted"}
@router.put("/steps/{step_id}")
async def update_step(step_id: int, body: StepUpdate, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(PlaybookStep).where(PlaybookStep.id == step_id))
step = result.scalar_one_or_none()
if not step:
raise HTTPException(status_code=404, detail="Step not found")
if body.is_completed is not None:
step.is_completed = body.is_completed
step.completed_at = datetime.now(timezone.utc) if body.is_completed else None
if body.notes is not None:
step.notes = body.notes
return {"status": "updated", "is_completed": step.is_completed}

View File

@@ -0,0 +1,164 @@
"""API routes for saved searches and bookmarked queries."""
import logging
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import SavedSearch
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/searches", tags=["saved-searches"])
class SearchCreate(BaseModel):
name: str
description: str | None = None
search_type: str # "nlp_query", "ioc_search", "keyword_scan", "correlation"
query_params: dict
threshold: float | None = None
class SearchUpdate(BaseModel):
name: str | None = None
description: str | None = None
query_params: dict | None = None
threshold: float | None = None
@router.get("")
async def list_searches(
search_type: str | None = None,
db: AsyncSession = Depends(get_db),
):
q = select(SavedSearch).order_by(SavedSearch.created_at.desc())
if search_type:
q = q.where(SavedSearch.search_type == search_type)
result = await db.execute(q.limit(100))
searches = result.scalars().all()
return {"searches": [
{
"id": s.id, "name": s.name, "description": s.description,
"search_type": s.search_type, "query_params": s.query_params,
"threshold": s.threshold,
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
"last_result_count": s.last_result_count,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in searches
]}
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_search(body: SearchCreate, db: AsyncSession = Depends(get_db)):
s = SavedSearch(
name=body.name,
description=body.description,
search_type=body.search_type,
query_params=body.query_params,
threshold=body.threshold,
)
db.add(s)
await db.flush()
return {
"id": s.id, "name": s.name, "search_type": s.search_type,
"query_params": s.query_params, "threshold": s.threshold,
}
@router.get("/{search_id}")
async def get_search(search_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
s = result.scalar_one_or_none()
if not s:
raise HTTPException(status_code=404, detail="Saved search not found")
return {
"id": s.id, "name": s.name, "description": s.description,
"search_type": s.search_type, "query_params": s.query_params,
"threshold": s.threshold,
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
"last_result_count": s.last_result_count,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
@router.put("/{search_id}")
async def update_search(search_id: str, body: SearchUpdate, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
s = result.scalar_one_or_none()
if not s:
raise HTTPException(status_code=404, detail="Saved search not found")
if body.name is not None:
s.name = body.name
if body.description is not None:
s.description = body.description
if body.query_params is not None:
s.query_params = body.query_params
if body.threshold is not None:
s.threshold = body.threshold
return {"status": "updated"}
@router.delete("/{search_id}")
async def delete_search(search_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
s = result.scalar_one_or_none()
if not s:
raise HTTPException(status_code=404, detail="Saved search not found")
await db.delete(s)
return {"status": "deleted"}
@router.post("/{search_id}/run")
async def run_saved_search(search_id: str, db: AsyncSession = Depends(get_db)):
"""Execute a saved search and return results with delta from last run."""
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
s = result.scalar_one_or_none()
if not s:
raise HTTPException(status_code=404, detail="Saved search not found")
previous_count = s.last_result_count or 0
results = []
count = 0
if s.search_type == "ioc_search":
from app.db.models import EnrichmentResult
ioc_value = s.query_params.get("ioc_value", "")
if ioc_value:
q = select(EnrichmentResult).where(
EnrichmentResult.ioc_value.contains(ioc_value)
)
res = await db.execute(q.limit(100))
for er in res.scalars().all():
results.append({
"ioc_value": er.ioc_value, "ioc_type": er.ioc_type,
"source": er.source, "verdict": er.verdict,
})
count = len(results)
elif s.search_type == "keyword_scan":
from app.db.models import KeywordTheme
res = await db.execute(select(KeywordTheme).where(KeywordTheme.enabled == True))
themes = res.scalars().all()
count = sum(len(t.keywords) for t in themes)
results = [{"theme": t.name, "keyword_count": len(t.keywords)} for t in themes]
# Update last run metadata
s.last_run_at = datetime.now(timezone.utc)
s.last_result_count = count
delta = count - previous_count
return {
"search_id": s.id, "search_name": s.name,
"search_type": s.search_type,
"result_count": count,
"previous_count": previous_count,
"delta": delta,
"results": results[:50],
}

View File

@@ -0,0 +1,184 @@
"""STIX 2.1 export endpoint.
Aggregates hunt data (IOCs, techniques, host profiles, hypotheses) into a
STIX 2.1 Bundle JSON download. No external dependencies required we
build the JSON directly following the OASIS STIX 2.1 spec.
"""
import json
import uuid
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import Response
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import (
Hunt, Dataset, Hypothesis, TriageResult, HostProfile,
EnrichmentResult, HuntReport,
)
router = APIRouter(prefix="/api/export", tags=["export"])
STIX_SPEC_VERSION = "2.1"
def _stix_id(stype: str) -> str:
return f"{stype}--{uuid.uuid4()}"
def _now_iso() -> str:
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
def _build_identity(hunt_name: str) -> dict:
return {
"type": "identity",
"spec_version": STIX_SPEC_VERSION,
"id": _stix_id("identity"),
"created": _now_iso(),
"modified": _now_iso(),
"name": f"ThreatHunt - {hunt_name}",
"identity_class": "system",
}
def _ioc_to_indicator(ioc_value: str, ioc_type: str, identity_id: str, verdict: str = None) -> dict:
pattern_map = {
"ipv4": f"[ipv4-addr:value = '{ioc_value}']",
"ipv6": f"[ipv6-addr:value = '{ioc_value}']",
"domain": f"[domain-name:value = '{ioc_value}']",
"url": f"[url:value = '{ioc_value}']",
"hash_md5": f"[file:hashes.'MD5' = '{ioc_value}']",
"hash_sha1": f"[file:hashes.'SHA-1' = '{ioc_value}']",
"hash_sha256": f"[file:hashes.'SHA-256' = '{ioc_value}']",
"email": f"[email-addr:value = '{ioc_value}']",
}
pattern = pattern_map.get(ioc_type, f"[artifact:payload_bin = '{ioc_value}']")
now = _now_iso()
return {
"type": "indicator",
"spec_version": STIX_SPEC_VERSION,
"id": _stix_id("indicator"),
"created": now,
"modified": now,
"name": f"{ioc_type}: {ioc_value}",
"pattern": pattern,
"pattern_type": "stix",
"valid_from": now,
"created_by_ref": identity_id,
"labels": [verdict or "suspicious"],
}
def _technique_to_attack_pattern(technique_id: str, identity_id: str) -> dict:
now = _now_iso()
return {
"type": "attack-pattern",
"spec_version": STIX_SPEC_VERSION,
"id": _stix_id("attack-pattern"),
"created": now,
"modified": now,
"name": technique_id,
"created_by_ref": identity_id,
"external_references": [
{
"source_name": "mitre-attack",
"external_id": technique_id,
"url": f"https://attack.mitre.org/techniques/{technique_id.replace('.', '/')}/",
}
],
}
def _hypothesis_to_report(hyp, identity_id: str) -> dict:
now = _now_iso()
return {
"type": "report",
"spec_version": STIX_SPEC_VERSION,
"id": _stix_id("report"),
"created": now,
"modified": now,
"name": hyp.title,
"description": hyp.description or "",
"published": now,
"created_by_ref": identity_id,
"labels": ["threat-hunt-hypothesis"],
"object_refs": [],
}
@router.get("/stix/{hunt_id}")
async def export_stix(hunt_id: str, db: AsyncSession = Depends(get_db)):
"""Export hunt data as a STIX 2.1 Bundle JSON file."""
# Fetch hunt
hunt = (await db.execute(select(Hunt).where(Hunt.id == hunt_id))).scalar_one_or_none()
if not hunt:
raise HTTPException(404, "Hunt not found")
identity = _build_identity(hunt.name)
objects: list[dict] = [identity]
seen_techniques: set[str] = set()
seen_iocs: set[str] = set()
# Gather IOCs from enrichment results for hunt's datasets
datasets_q = await db.execute(select(Dataset.id).where(Dataset.hunt_id == hunt_id))
ds_ids = [r[0] for r in datasets_q.all()]
if ds_ids:
enrichments = (await db.execute(
select(EnrichmentResult).where(EnrichmentResult.dataset_id.in_(ds_ids))
)).scalars().all()
for e in enrichments:
key = f"{e.ioc_type}:{e.ioc_value}"
if key not in seen_iocs:
seen_iocs.add(key)
objects.append(_ioc_to_indicator(e.ioc_value, e.ioc_type, identity["id"], e.verdict))
# Gather techniques from triage results
triages = (await db.execute(
select(TriageResult).where(TriageResult.dataset_id.in_(ds_ids))
)).scalars().all()
for t in triages:
for tech in (t.mitre_techniques or []):
tid = tech if isinstance(tech, str) else tech.get("technique_id", str(tech))
if tid not in seen_techniques:
seen_techniques.add(tid)
objects.append(_technique_to_attack_pattern(tid, identity["id"]))
# Gather techniques from host profiles
profiles = (await db.execute(
select(HostProfile).where(HostProfile.hunt_id == hunt_id)
)).scalars().all()
for p in profiles:
for tech in (p.mitre_techniques or []):
tid = tech if isinstance(tech, str) else tech.get("technique_id", str(tech))
if tid not in seen_techniques:
seen_techniques.add(tid)
objects.append(_technique_to_attack_pattern(tid, identity["id"]))
# Gather hypotheses
hypos = (await db.execute(
select(Hypothesis).where(Hypothesis.hunt_id == hunt_id)
)).scalars().all()
for h in hypos:
objects.append(_hypothesis_to_report(h, identity["id"]))
if h.mitre_technique and h.mitre_technique not in seen_techniques:
seen_techniques.add(h.mitre_technique)
objects.append(_technique_to_attack_pattern(h.mitre_technique, identity["id"]))
bundle = {
"type": "bundle",
"id": _stix_id("bundle"),
"objects": objects,
}
filename = f"threathunt-{hunt.name.replace(' ', '_')}-stix.json"
return Response(
content=json.dumps(bundle, indent=2),
media_type="application/json",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)

View File

@@ -0,0 +1,128 @@
"""API routes for forensic timeline visualization."""
import logging
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Dataset, DatasetRow, Hunt
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/timeline", tags=["timeline"])
def _parse_timestamp(val: str | None) -> str | None:
"""Try to parse a timestamp string, return ISO format or None."""
if not val:
return None
val = str(val).strip()
if not val:
return None
# Try common formats
for fmt in [
"%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ",
"%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S",
"%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S",
"%Y/%m/%d %H:%M:%S", "%m/%d/%Y %H:%M:%S",
]:
try:
return datetime.strptime(val, fmt).isoformat() + "Z"
except ValueError:
continue
return None
# Columns likely to contain timestamps
TIME_COLUMNS = {
"timestamp", "time", "datetime", "date", "created", "modified",
"eventtime", "event_time", "start_time", "end_time",
"lastmodified", "last_modified", "created_at", "updated_at",
"mtime", "atime", "ctime", "btime",
"timecreated", "timegenerated", "sourcetime",
}
@router.get("/hunt/{hunt_id}")
async def get_hunt_timeline(
hunt_id: str,
limit: int = 2000,
db: AsyncSession = Depends(get_db),
):
"""Build a timeline of events across all datasets in a hunt."""
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
hunt = result.scalar_one_or_none()
if not hunt:
raise HTTPException(status_code=404, detail="Hunt not found")
result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
datasets = result.scalars().all()
if not datasets:
return {"hunt_id": hunt_id, "events": [], "datasets": []}
events = []
dataset_info = []
for ds in datasets:
artifact_type = getattr(ds, "artifact_type", None) or "Unknown"
dataset_info.append({
"id": ds.id, "name": ds.name, "artifact_type": artifact_type,
"row_count": ds.row_count,
})
# Find time columns for this dataset
schema = ds.column_schema or {}
time_cols = []
for col in (ds.normalized_columns or {}).values():
if col.lower() in TIME_COLUMNS:
time_cols.append(col)
if not time_cols:
for col in schema:
if col.lower() in TIME_COLUMNS or "time" in col.lower() or "date" in col.lower():
time_cols.append(col)
if not time_cols:
continue
# Fetch rows
rows_result = await db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == ds.id)
.order_by(DatasetRow.row_index)
.limit(limit // max(len(datasets), 1))
)
for r in rows_result.scalars().all():
data = r.normalized_data or r.data
ts = None
for tc in time_cols:
ts = _parse_timestamp(data.get(tc))
if ts:
break
if ts:
hostname = data.get("hostname") or data.get("Hostname") or data.get("Fqdn") or ""
process = data.get("process_name") or data.get("Name") or data.get("ProcessName") or ""
summary = data.get("command_line") or data.get("CommandLine") or data.get("Details") or ""
events.append({
"timestamp": ts,
"dataset_id": ds.id,
"dataset_name": ds.name,
"artifact_type": artifact_type,
"row_index": r.row_index,
"hostname": str(hostname)[:128],
"process": str(process)[:128],
"summary": str(summary)[:256],
"data": {k: str(v)[:100] for k, v in list(data.items())[:8]},
})
# Sort by timestamp
events.sort(key=lambda e: e["timestamp"])
return {
"hunt_id": hunt_id,
"hunt_name": hunt.name,
"event_count": len(events),
"datasets": dataset_info,
"events": events[:limit],
}

View File

@@ -1,4 +1,4 @@
"""Application configuration single source of truth for all settings. """Application configuration - single source of truth for all settings.
Loads from environment variables with sensible defaults for local dev. Loads from environment variables with sensible defaults for local dev.
""" """
@@ -13,12 +13,12 @@ from pydantic import Field
class AppConfig(BaseSettings): class AppConfig(BaseSettings):
"""Central configuration for the entire ThreatHunt application.""" """Central configuration for the entire ThreatHunt application."""
# ── General ──────────────────────────────────────────────────────── # -- General --------------------------------------------------------
APP_NAME: str = "ThreatHunt" APP_NAME: str = "ThreatHunt"
APP_VERSION: str = "0.4.0" APP_VERSION: str = "0.4.0"
DEBUG: bool = Field(default=False, description="Enable debug mode") DEBUG: bool = Field(default=False, description="Enable debug mode")
# ── Database ─────────────────────────────────────────────────────── # -- Database -------------------------------------------------------
DATABASE_URL: str = Field( DATABASE_URL: str = Field(
default="sqlite+aiosqlite:///./threathunt.db", default="sqlite+aiosqlite:///./threathunt.db",
description="Async SQLAlchemy database URL. " description="Async SQLAlchemy database URL. "
@@ -26,17 +26,17 @@ class AppConfig(BaseSettings):
"postgresql+asyncpg://user:pass@host/db for production.", "postgresql+asyncpg://user:pass@host/db for production.",
) )
# ── CORS ─────────────────────────────────────────────────────────── # -- CORS -----------------------------------------------------------
ALLOWED_ORIGINS: str = Field( ALLOWED_ORIGINS: str = Field(
default="http://localhost:3000,http://localhost:8000", default="http://localhost:3000,http://localhost:8000",
description="Comma-separated list of allowed CORS origins", description="Comma-separated list of allowed CORS origins",
) )
# ── File uploads ─────────────────────────────────────────────────── # -- File uploads ---------------------------------------------------
MAX_UPLOAD_SIZE_MB: int = Field(default=500, description="Max CSV upload in MB") MAX_UPLOAD_SIZE_MB: int = Field(default=500, description="Max CSV upload in MB")
UPLOAD_DIR: str = Field(default="./uploads", description="Directory for uploaded files") UPLOAD_DIR: str = Field(default="./uploads", description="Directory for uploaded files")
# ── LLM Cluster Wile & Roadrunner ──────────────────────────────── # -- LLM Cluster - Wile & Roadrunner --------------------------------
OPENWEBUI_URL: str = Field( OPENWEBUI_URL: str = Field(
default="https://ai.guapo613.beer", default="https://ai.guapo613.beer",
description="Open WebUI cluster endpoint (OpenAI-compatible API)", description="Open WebUI cluster endpoint (OpenAI-compatible API)",
@@ -58,7 +58,7 @@ class AppConfig(BaseSettings):
default=11434, description="Ollama port on Roadrunner" default=11434, description="Ollama port on Roadrunner"
) )
# ── LLM Routing defaults ────────────────────────────────────────── # -- LLM Routing defaults ------------------------------------------
DEFAULT_FAST_MODEL: str = Field( DEFAULT_FAST_MODEL: str = Field(
default="llama3.1:latest", default="llama3.1:latest",
description="Default model for quick chat / simple queries", description="Default model for quick chat / simple queries",
@@ -80,18 +80,18 @@ class AppConfig(BaseSettings):
description="Default embedding model", description="Default embedding model",
) )
# ── Agent behaviour ─────────────────────────────────────────────── # -- Agent behaviour ------------------------------------------------
AGENT_MAX_TOKENS: int = Field(default=2048, description="Max tokens per agent response") AGENT_MAX_TOKENS: int = Field(default=2048, description="Max tokens per agent response")
AGENT_TEMPERATURE: float = Field(default=0.3, description="LLM temperature for guidance") AGENT_TEMPERATURE: float = Field(default=0.3, description="LLM temperature for guidance")
AGENT_HISTORY_LENGTH: int = Field(default=10, description="Messages to keep in context") AGENT_HISTORY_LENGTH: int = Field(default=10, description="Messages to keep in context")
FILTER_SENSITIVE_DATA: bool = Field(default=True, description="Redact sensitive patterns") FILTER_SENSITIVE_DATA: bool = Field(default=True, description="Redact sensitive patterns")
# ── Enrichment API keys ─────────────────────────────────────────── # -- Enrichment API keys --------------------------------------------
VIRUSTOTAL_API_KEY: str = Field(default="", description="VirusTotal API key") VIRUSTOTAL_API_KEY: str = Field(default="", description="VirusTotal API key")
ABUSEIPDB_API_KEY: str = Field(default="", description="AbuseIPDB API key") ABUSEIPDB_API_KEY: str = Field(default="", description="AbuseIPDB API key")
SHODAN_API_KEY: str = Field(default="", description="Shodan API key") SHODAN_API_KEY: str = Field(default="", description="Shodan API key")
# ── Auth ────────────────────────────────────────────────────────── # -- Auth -----------------------------------------------------------
JWT_SECRET: str = Field( JWT_SECRET: str = Field(
default="CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET", default="CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET",
description="Secret for JWT signing", description="Secret for JWT signing",
@@ -99,6 +99,73 @@ class AppConfig(BaseSettings):
JWT_ACCESS_TOKEN_MINUTES: int = Field(default=60, description="Access token lifetime") JWT_ACCESS_TOKEN_MINUTES: int = Field(default=60, description="Access token lifetime")
JWT_REFRESH_TOKEN_DAYS: int = Field(default=7, description="Refresh token lifetime") JWT_REFRESH_TOKEN_DAYS: int = Field(default=7, description="Refresh token lifetime")
# -- Triage settings ------------------------------------------------
TRIAGE_BATCH_SIZE: int = Field(default=25, description="Rows per triage LLM batch")
TRIAGE_MAX_SUSPICIOUS_ROWS: int = Field(
default=200, description="Stop triage after this many suspicious rows"
)
TRIAGE_ESCALATION_THRESHOLD: float = Field(
default=5.0, description="Risk score threshold for escalation counting"
)
# -- Host profiler settings -----------------------------------------
HOST_PROFILE_CONCURRENCY: int = Field(
default=3, description="Max concurrent host profile LLM calls"
)
# -- Scanner settings -----------------------------------------------
SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch")
SCANNER_MAX_ROWS_PER_SCAN: int = Field(
default=120000,
description="Global row budget for a single AUP scan request (0 = unlimited)",
)
# -- Job queue settings ----------------------------------------------
JOB_QUEUE_MAX_BACKLOG: int = Field(
default=2000, description="Soft cap for queued background jobs"
)
JOB_QUEUE_RETAIN_COMPLETED: int = Field(
default=3000, description="Maximum completed/failed jobs to retain in memory"
)
JOB_QUEUE_CLEANUP_INTERVAL_SECONDS: int = Field(
default=60, description="How often to run in-memory job cleanup"
)
JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(
default=3600, description="Age threshold for in-memory completed job cleanup"
)
# -- Startup throttling ------------------------------------------------
STARTUP_WARMUP_MAX_HUNTS: int = Field(
default=5, description="Max hunts to warm inventory cache for at startup"
)
STARTUP_REPROCESS_MAX_DATASETS: int = Field(
default=25, description="Max unprocessed datasets to enqueue at startup"
)
STARTUP_RECONCILE_STALE_TASKS: bool = Field(
default=True,
description="Mark stale queued/running processing tasks as failed on startup",
)
# -- Network API scale guards -----------------------------------------
NETWORK_SUBGRAPH_MAX_HOSTS: int = Field(
default=400, description="Hard cap for hosts returned by network subgraph endpoint"
)
NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
)
NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
default=5000,
description="Row budget per dataset when building host inventory (0 = unlimited)",
)
NETWORK_INVENTORY_MAX_TOTAL_ROWS: int = Field(
default=120000,
description="Global row budget across all datasets for host inventory build (0 = unlimited)",
)
NETWORK_INVENTORY_MAX_CONNECTIONS: int = Field(
default=120000,
description="Max unique connection tuples retained during host inventory build",
)
model_config = {"env_prefix": "TH_", "env_file": ".env", "extra": "ignore"} model_config = {"env_prefix": "TH_", "env_file": ".env", "extra": "ignore"}
@property @property
@@ -119,3 +186,4 @@ class AppConfig(BaseSettings):
settings = AppConfig() settings = AppConfig()

View File

@@ -21,9 +21,14 @@ _engine_kwargs: dict = dict(
) )
if _is_sqlite: if _is_sqlite:
_engine_kwargs["connect_args"] = {"timeout": 30} _engine_kwargs["connect_args"] = {"timeout": 60, "check_same_thread": False}
_engine_kwargs["pool_size"] = 1 # NullPool: each session gets its own connection.
_engine_kwargs["max_overflow"] = 0 # Combined with WAL mode, this allows concurrent reads while a write is in progress.
from sqlalchemy.pool import NullPool
_engine_kwargs["poolclass"] = NullPool
else:
_engine_kwargs["pool_size"] = 5
_engine_kwargs["max_overflow"] = 10
engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs) engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)
@@ -34,7 +39,7 @@ def _set_sqlite_pragmas(dbapi_conn, connection_record):
if _is_sqlite: if _is_sqlite:
cursor = dbapi_conn.cursor() cursor = dbapi_conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL") cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA busy_timeout=5000") cursor.execute("PRAGMA busy_timeout=30000")
cursor.execute("PRAGMA synchronous=NORMAL") cursor.execute("PRAGMA synchronous=NORMAL")
cursor.close() cursor.close()
@@ -46,6 +51,10 @@ async_session_factory = async_sessionmaker(
) )
# Alias expected by other modules
async_session = async_session_factory
class Base(DeclarativeBase): class Base(DeclarativeBase):
"""Base class for all ORM models.""" """Base class for all ORM models."""
pass pass
@@ -83,5 +92,5 @@ async def init_db() -> None:
async def dispose_db() -> None: async def dispose_db() -> None:
"""Dispose of the engine connection pool.""" """Dispose of the engine on shutdown."""
await engine.dispose() await engine.dispose()

View File

@@ -1,4 +1,4 @@
"""SQLAlchemy ORM models for ThreatHunt. """SQLAlchemy ORM models for ThreatHunt.
All persistent entities: datasets, hunts, conversations, annotations, All persistent entities: datasets, hunts, conversations, annotations,
hypotheses, enrichment results, and users. hypotheses, enrichment results, and users.
@@ -44,6 +44,7 @@ class User(Base):
hashed_password: Mapped[str] = mapped_column(String(256), nullable=False) hashed_password: Mapped[str] = mapped_column(String(256), nullable=False)
role: Mapped[str] = mapped_column(String(16), default="analyst") # analyst | admin | viewer role: Mapped[str] = mapped_column(String(16), default="analyst") # analyst | admin | viewer
is_active: Mapped[bool] = mapped_column(Boolean, default=True) is_active: Mapped[bool] = mapped_column(Boolean, default=True)
display_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
# relationships # relationships
@@ -544,3 +545,116 @@ class PlaybookRun(Base):
Index("ix_playbook_runs_hunt", "hunt_id"), Index("ix_playbook_runs_hunt", "hunt_id"),
Index("ix_playbook_runs_status", "status"), Index("ix_playbook_runs_status", "status"),
) )
<<<<<<< HEAD
anomaly_score: Mapped[float] = mapped_column(Float, default=0.0)
distance_from_centroid: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
cluster_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
is_outlier: Mapped[bool] = mapped_column(Boolean, default=False)
explanation: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
# -- Persistent Processing Tasks (Phase 2) ---
class ProcessingTask(Base):
__tablename__ = "processing_tasks"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
hunt_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True, index=True
)
dataset_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True, index=True
)
job_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
status: Mapped[str] = mapped_column(String(20), default="queued", index=True)
progress: Mapped[float] = mapped_column(Float, default=0.0)
message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
__table_args__ = (
Index("ix_processing_tasks_hunt_stage", "hunt_id", "stage"),
Index("ix_processing_tasks_dataset_stage", "dataset_id", "stage"),
)
# -- Playbook / Investigation Templates (Feature 3) ---
class Playbook(Base):
__tablename__ = "playbooks"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_by: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("users.id"), nullable=True
)
is_template: Mapped[bool] = mapped_column(Boolean, default=False)
hunt_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("hunts.id"), nullable=True
)
status: Mapped[str] = mapped_column(String(20), default="active")
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
steps: Mapped[list["PlaybookStep"]] = relationship(
back_populates="playbook", lazy="selectin", cascade="all, delete-orphan",
order_by="PlaybookStep.order_index",
)
class PlaybookStep(Base):
__tablename__ = "playbook_steps"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
playbook_id: Mapped[str] = mapped_column(
String(32), ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False
)
order_index: Mapped[int] = mapped_column(Integer, nullable=False)
title: Mapped[str] = mapped_column(String(256), nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
step_type: Mapped[str] = mapped_column(String(32), default="manual")
target_route: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
is_completed: Mapped[bool] = mapped_column(Boolean, default=False)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
playbook: Mapped["Playbook"] = relationship(back_populates="steps")
__table_args__ = (
Index("ix_playbook_steps_playbook", "playbook_id"),
)
# -- Saved Searches (Feature 5) ---
class SavedSearch(Base):
__tablename__ = "saved_searches"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
search_type: Mapped[str] = mapped_column(String(32), nullable=False)
query_params: Mapped[dict] = mapped_column(JSON, nullable=False)
threshold: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
created_by: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("users.id"), nullable=True
)
last_run_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
last_result_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
__table_args__ = (
Index("ix_saved_searches_type", "search_type"),
)
=======
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2

View File

@@ -22,10 +22,18 @@ from app.api.routes.reports import router as reports_router
from app.api.routes.auth import router as auth_router from app.api.routes.auth import router as auth_router
from app.api.routes.keywords import router as keywords_router from app.api.routes.keywords import router as keywords_router
from app.api.routes.network import router as network_router from app.api.routes.network import router as network_router
<<<<<<< HEAD
from app.api.routes.mitre import router as mitre_router
from app.api.routes.timeline import router as timeline_router
from app.api.routes.playbooks import router as playbooks_router
from app.api.routes.saved_searches import router as searches_router
from app.api.routes.stix_export import router as stix_router
=======
from app.api.routes.analysis import router as analysis_router from app.api.routes.analysis import router as analysis_router
from app.api.routes.cases import router as cases_router from app.api.routes.cases import router as cases_router
from app.api.routes.alerts import router as alerts_router from app.api.routes.alerts import router as alerts_router
from app.api.routes.notebooks import router as notebooks_router from app.api.routes.notebooks import router as notebooks_router
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -42,8 +50,101 @@ async def lifespan(app: FastAPI):
async with async_session_factory() as seed_db: async with async_session_factory() as seed_db:
await seed_defaults(seed_db) await seed_defaults(seed_db)
logger.info("AUP keyword defaults checked") logger.info("AUP keyword defaults checked")
<<<<<<< HEAD
# Start job queue
from app.services.job_queue import (
job_queue,
register_all_handlers,
reconcile_stale_processing_tasks,
JobType,
)
if settings.STARTUP_RECONCILE_STALE_TASKS:
reconciled = await reconcile_stale_processing_tasks()
if reconciled:
logger.info("Startup reconciliation marked %d stale tasks", reconciled)
register_all_handlers()
await job_queue.start()
logger.info("Job queue started (%d workers)", job_queue._max_workers)
# Pre-warm host inventory cache for existing hunts
from app.services.host_inventory import inventory_cache
async with async_session_factory() as warm_db:
from sqlalchemy import select, func
from app.db.models import Hunt, Dataset
stmt = (
select(Hunt.id)
.join(Dataset, Dataset.hunt_id == Hunt.id)
.group_by(Hunt.id)
.having(func.count(Dataset.id) > 0)
)
result = await warm_db.execute(stmt)
hunt_ids = [row[0] for row in result.all()]
warm_hunts = hunt_ids[: settings.STARTUP_WARMUP_MAX_HUNTS]
for hid in warm_hunts:
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)
if warm_hunts:
logger.info(f"Queued host inventory warm-up for {len(warm_hunts)} hunts (total hunts with data: {len(hunt_ids)})")
# Check which datasets still need processing
# (no anomaly results = never fully processed)
async with async_session_factory() as reprocess_db:
from sqlalchemy import select, exists
from app.db.models import Dataset, AnomalyResult
# Find datasets that have zero anomaly results (pipeline never ran or failed)
has_anomaly = (
select(AnomalyResult.id)
.where(AnomalyResult.dataset_id == Dataset.id)
.limit(1)
.correlate(Dataset)
.exists()
)
stmt = select(Dataset.id).where(~has_anomaly)
result = await reprocess_db.execute(stmt)
unprocessed_ids = [row[0] for row in result.all()]
if unprocessed_ids:
to_reprocess = unprocessed_ids[: settings.STARTUP_REPROCESS_MAX_DATASETS]
for ds_id in to_reprocess:
job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)
job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)
logger.info(f"Queued processing pipeline for {len(to_reprocess)} datasets at startup (unprocessed total: {len(unprocessed_ids)})")
async with async_session_factory() as update_db:
from sqlalchemy import update
from app.db.models import Dataset
await update_db.execute(
update(Dataset)
.where(Dataset.id.in_(to_reprocess))
.values(processing_status="processing")
)
await update_db.commit()
else:
logger.info("All datasets already processed - skipping startup pipeline")
# Start load balancer health loop
from app.services.load_balancer import lb
await lb.start_health_loop(interval=30.0)
logger.info("Load balancer health loop started")
yield
logger.info("Shutting down ...")
from app.services.job_queue import job_queue as jq
await jq.stop()
logger.info("Job queue stopped")
from app.services.load_balancer import lb as _lb
await _lb.stop_health_loop()
logger.info("Load balancer stopped")
=======
yield yield
logger.info("Shutting down …") logger.info("Shutting down …")
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
from app.agents.providers_v2 import cleanup_client from app.agents.providers_v2 import cleanup_client
from app.services.enrichment import enrichment_engine from app.services.enrichment import enrichment_engine
await cleanup_client() await cleanup_client()
@@ -80,10 +181,18 @@ app.include_router(correlation_router)
app.include_router(reports_router) app.include_router(reports_router)
app.include_router(keywords_router) app.include_router(keywords_router)
app.include_router(network_router) app.include_router(network_router)
<<<<<<< HEAD
app.include_router(mitre_router)
app.include_router(timeline_router)
app.include_router(playbooks_router)
app.include_router(searches_router)
app.include_router(stix_router)
=======
app.include_router(analysis_router) app.include_router(analysis_router)
app.include_router(cases_router) app.include_router(cases_router)
app.include_router(alerts_router) app.include_router(alerts_router)
app.include_router(notebooks_router) app.include_router(notebooks_router)
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
@app.get("/", tags=["health"]) @app.get("/", tags=["health"])
@@ -100,3 +209,15 @@ async def root():
"openwebui": settings.OPENWEBUI_URL, "openwebui": settings.OPENWEBUI_URL,
}, },
} }
<<<<<<< HEAD
@app.get("/health", tags=["health"])
async def health():
return {
"service": "ThreatHunt API",
"version": settings.APP_VERSION,
"status": "ok",
}
=======
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2

View File

@@ -13,6 +13,7 @@ from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import Dataset, DatasetRow from app.db.models import Dataset, DatasetRow
from app.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -79,6 +80,55 @@ def _extract_username(raw: str) -> str:
return name or '' return name or ''
# In-memory host inventory cache
# Pre-computed results stored per hunt_id, built in background after upload.
import time as _time
class _InventoryCache:
"""Simple in-memory cache for pre-computed host inventories."""
def __init__(self):
self._data: dict[str, dict] = {} # hunt_id -> result dict
self._timestamps: dict[str, float] = {} # hunt_id -> epoch
self._building: set[str] = set() # hunt_ids currently being built
def get(self, hunt_id: str) -> dict | None:
"""Return cached result if present. Never expires; only invalidated on new upload."""
return self._data.get(hunt_id)
def put(self, hunt_id: str, result: dict):
self._data[hunt_id] = result
self._timestamps[hunt_id] = _time.time()
self._building.discard(hunt_id)
logger.info(f"Cached host inventory for hunt {hunt_id} "
f"({result['stats']['total_hosts']} hosts)")
def invalidate(self, hunt_id: str):
self._data.pop(hunt_id, None)
self._timestamps.pop(hunt_id, None)
def is_building(self, hunt_id: str) -> bool:
return hunt_id in self._building
def set_building(self, hunt_id: str):
self._building.add(hunt_id)
def clear_building(self, hunt_id: str):
self._building.discard(hunt_id)
def status(self, hunt_id: str) -> str:
if hunt_id in self._building:
return "building"
if hunt_id in self._data:
return "ready"
return "none"
inventory_cache = _InventoryCache()
def _infer_os(fqdn: str) -> str: def _infer_os(fqdn: str) -> str:
u = fqdn.upper() u = fqdn.upper()
if 'W10-' in u or 'WIN10' in u: if 'W10-' in u or 'WIN10' in u:
@@ -151,33 +201,61 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
}} }}
hosts: dict[str, dict] = {} # fqdn -> host record hosts: dict[str, dict] = {} # fqdn -> host record
ip_to_host: dict[str, str] = {} # local-ip -> fqdn ip_to_host: dict[str, str] = {} # local-ip -> fqdn
connections: dict[tuple, int] = defaultdict(int) connections: dict[tuple, int] = defaultdict(int)
total_rows = 0 total_rows = 0
ds_with_hosts = 0 ds_with_hosts = 0
sampled_dataset_count = 0
total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
global_budget_reached = False
dropped_connections = 0
for ds in all_datasets: for ds in all_datasets:
if total_row_budget and total_rows >= total_row_budget:
global_budget_reached = True
break
cols = _identify_columns(ds) cols = _identify_columns(ds)
if not cols['fqdn'] and not cols['host_id']: if not cols['fqdn'] and not cols['host_id']:
continue continue
ds_with_hosts += 1 ds_with_hosts += 1
batch_size = 5000 batch_size = 5000
offset = 0 max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
rows_scanned_this_dataset = 0
sampled_dataset = False
last_row_index = -1
while True: while True:
if total_row_budget and total_rows >= total_row_budget:
sampled_dataset = True
global_budget_reached = True
break
rr = await db.execute( rr = await db.execute(
select(DatasetRow) select(DatasetRow)
.where(DatasetRow.dataset_id == ds.id) .where(DatasetRow.dataset_id == ds.id)
.where(DatasetRow.row_index > last_row_index)
.order_by(DatasetRow.row_index) .order_by(DatasetRow.row_index)
.offset(offset).limit(batch_size) .limit(batch_size)
) )
rows = rr.scalars().all() rows = rr.scalars().all()
if not rows: if not rows:
break break
for ro in rows: for ro in rows:
if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
sampled_dataset = True
break
if total_row_budget and total_rows >= total_row_budget:
sampled_dataset = True
global_budget_reached = True
break
data = ro.data or {} data = ro.data or {}
total_rows += 1 total_rows += 1
rows_scanned_this_dataset += 1
fqdn = '' fqdn = ''
for c in cols['fqdn']: for c in cols['fqdn']:
@@ -239,12 +317,33 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
rport = _clean(data.get(pc)) rport = _clean(data.get(pc))
if rport: if rport:
break break
connections[(host_key, rip, rport)] += 1 conn_key = (host_key, rip, rport)
if max_connections and len(connections) >= max_connections and conn_key not in connections:
dropped_connections += 1
continue
connections[conn_key] += 1
offset += batch_size if sampled_dataset:
sampled_dataset_count += 1
logger.info(
"Host inventory sampling for dataset %s (%d rows scanned)",
ds.id,
rows_scanned_this_dataset,
)
break
last_row_index = rows[-1].row_index
if len(rows) < batch_size: if len(rows) < batch_size:
break break
if global_budget_reached:
logger.info(
"Host inventory global row budget reached for hunt %s at %d rows",
hunt_id,
total_rows,
)
break
# Post-process hosts # Post-process hosts
for h in hosts.values(): for h in hosts.values():
if not h['os'] and h['fqdn']: if not h['os'] and h['fqdn']:
@@ -286,5 +385,12 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
"total_rows_scanned": total_rows, "total_rows_scanned": total_rows,
"hosts_with_ips": sum(1 for h in host_list if h['ips']), "hosts_with_ips": sum(1 for h in host_list if h['ips']),
"hosts_with_users": sum(1 for h in host_list if h['users']), "hosts_with_users": sum(1 for h in host_list if h['users']),
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
"row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
"connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
"sampled_datasets": sampled_dataset_count,
"global_budget_reached": global_budget_reached,
"dropped_connections": dropped_connections,
}, },
} }

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import re
import json import json
import logging import logging
@@ -18,6 +19,9 @@ logger = logging.getLogger(__name__)
HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
WILE_URL = f"{settings.wile_url}/api/generate" WILE_URL = f"{settings.wile_url}/api/generate"
# Velociraptor client IDs (C.hex) are not real hostnames
CLIENTID_RE = re.compile(r"^C\.[0-9a-fA-F]{8,}$")
async def _get_triage_summary(db, dataset_id: str) -> str: async def _get_triage_summary(db, dataset_id: str) -> str:
result = await db.execute( result = await db.execute(
@@ -154,7 +158,7 @@ async def profile_host(
logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level) logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level)
except Exception as e: except Exception as e:
logger.error("Failed to profile host %s: %s", hostname, e) logger.error("Failed to profile host %s: %r", hostname, e)
profile = HostProfile( profile = HostProfile(
hunt_id=hunt_id, hostname=hostname, fqdn=fqdn, hunt_id=hunt_id, hostname=hostname, fqdn=fqdn,
risk_score=0.0, risk_level="unknown", risk_score=0.0, risk_level="unknown",
@@ -185,6 +189,13 @@ async def profile_all_hosts(hunt_id: str) -> None:
if h not in hostnames: if h not in hostnames:
hostnames[h] = data.get("fqdn") or data.get("Fqdn") hostnames[h] = data.get("fqdn") or data.get("Fqdn")
# Filter out Velociraptor client IDs - not real hostnames
real_hosts = {h: f for h, f in hostnames.items() if not CLIENTID_RE.match(h)}
skipped = len(hostnames) - len(real_hosts)
if skipped:
logger.info("Skipped %d Velociraptor client IDs", skipped)
hostnames = real_hosts
logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id) logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id)
semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY) semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY)

View File

@@ -1,8 +1,8 @@
"""Async job queue for background AI tasks. """Async job queue for background AI tasks.
Manages triage, profiling, report generation, anomaly detection, Manages triage, profiling, report generation, anomaly detection,
and data queries as trackable jobs with status, progress, and keyword scanning, IOC extraction, and data queries as trackable
cancellation support. jobs with status, progress, and cancellation support.
""" """
from __future__ import annotations from __future__ import annotations
@@ -15,6 +15,8 @@ from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Any, Callable, Coroutine, Optional from typing import Any, Callable, Coroutine, Optional
from app.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -32,6 +34,18 @@ class JobType(str, Enum):
REPORT = "report" REPORT = "report"
ANOMALY = "anomaly" ANOMALY = "anomaly"
QUERY = "query" QUERY = "query"
HOST_INVENTORY = "host_inventory"
KEYWORD_SCAN = "keyword_scan"
IOC_EXTRACT = "ioc_extract"
# Job types that form the automatic upload pipeline
PIPELINE_JOB_TYPES = frozenset({
JobType.TRIAGE,
JobType.ANOMALY,
JobType.KEYWORD_SCAN,
JobType.IOC_EXTRACT,
})
@dataclass @dataclass
@@ -82,11 +96,7 @@ class Job:
class JobQueue: class JobQueue:
"""In-memory async job queue with concurrency control. """In-memory async job queue with concurrency control."""
Jobs are tracked by ID and can be listed, polled, or cancelled.
A configurable number of workers process jobs from the queue.
"""
def __init__(self, max_workers: int = 3): def __init__(self, max_workers: int = 3):
self._jobs: dict[str, Job] = {} self._jobs: dict[str, Job] = {}
@@ -95,47 +105,56 @@ class JobQueue:
self._workers: list[asyncio.Task] = [] self._workers: list[asyncio.Task] = []
self._handlers: dict[JobType, Callable] = {} self._handlers: dict[JobType, Callable] = {}
self._started = False self._started = False
self._completion_callbacks: list[Callable[[Job], Coroutine]] = []
self._cleanup_task: asyncio.Task | None = None
def register_handler( def register_handler(self, job_type: JobType, handler: Callable[[Job], Coroutine]):
self,
job_type: JobType,
handler: Callable[[Job], Coroutine],
):
"""Register an async handler for a job type.
Handler signature: async def handler(job: Job) -> Any
The handler can update job.progress and job.message during execution.
It should check job.is_cancelled periodically and return early.
"""
self._handlers[job_type] = handler self._handlers[job_type] = handler
logger.info(f"Registered handler for {job_type.value}") logger.info(f"Registered handler for {job_type.value}")
def on_completion(self, callback: Callable[[Job], Coroutine]):
"""Register a callback invoked after any job completes or fails."""
self._completion_callbacks.append(callback)
async def start(self): async def start(self):
"""Start worker tasks."""
if self._started: if self._started:
return return
self._started = True self._started = True
for i in range(self._max_workers): for i in range(self._max_workers):
task = asyncio.create_task(self._worker(i)) task = asyncio.create_task(self._worker(i))
self._workers.append(task) self._workers.append(task)
if not self._cleanup_task or self._cleanup_task.done():
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info(f"Job queue started with {self._max_workers} workers") logger.info(f"Job queue started with {self._max_workers} workers")
async def stop(self): async def stop(self):
"""Stop all workers."""
self._started = False self._started = False
for w in self._workers: for w in self._workers:
w.cancel() w.cancel()
await asyncio.gather(*self._workers, return_exceptions=True) await asyncio.gather(*self._workers, return_exceptions=True)
self._workers.clear() self._workers.clear()
if self._cleanup_task:
self._cleanup_task.cancel()
await asyncio.gather(self._cleanup_task, return_exceptions=True)
self._cleanup_task = None
logger.info("Job queue stopped") logger.info("Job queue stopped")
def submit(self, job_type: JobType, **params) -> Job: def submit(self, job_type: JobType, **params) -> Job:
"""Submit a new job. Returns the Job object immediately.""" # Soft backpressure: prefer dedupe over queue amplification
job = Job( dedupe_job = self._find_active_duplicate(job_type, params)
id=str(uuid.uuid4()), if dedupe_job is not None:
job_type=job_type, logger.info(
params=params, f"Job deduped: reusing {dedupe_job.id} ({job_type.value}) params={params}"
) )
return dedupe_job
if self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG:
logger.warning(
"Job queue backlog high (%d >= %d). Accepting job but system may be degraded.",
self._queue.qsize(), settings.JOB_QUEUE_MAX_BACKLOG,
)
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
self._jobs[job.id] = job self._jobs[job.id] = job
self._queue.put_nowait(job.id) self._queue.put_nowait(job.id)
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}") logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
@@ -144,6 +163,22 @@ class JobQueue:
def get_job(self, job_id: str) -> Job | None: def get_job(self, job_id: str) -> Job | None:
return self._jobs.get(job_id) return self._jobs.get(job_id)
def _find_active_duplicate(self, job_type: JobType, params: dict) -> Job | None:
"""Return queued/running job with same key workload to prevent duplicate storms."""
key_fields = ["dataset_id", "hunt_id", "hostname", "question", "mode"]
sig = tuple((k, params.get(k)) for k in key_fields if params.get(k) is not None)
if not sig:
return None
for j in self._jobs.values():
if j.job_type != job_type:
continue
if j.status not in (JobStatus.QUEUED, JobStatus.RUNNING):
continue
other_sig = tuple((k, j.params.get(k)) for k in key_fields if j.params.get(k) is not None)
if sig == other_sig:
return j
return None
def cancel_job(self, job_id: str) -> bool: def cancel_job(self, job_id: str) -> bool:
job = self._jobs.get(job_id) job = self._jobs.get(job_id)
if not job: if not job:
@@ -153,13 +188,7 @@ class JobQueue:
job.cancel() job.cancel()
return True return True
def list_jobs( def list_jobs(self, status=None, job_type=None, limit=50) -> list[dict]:
self,
status: JobStatus | None = None,
job_type: JobType | None = None,
limit: int = 50,
) -> list[dict]:
"""List jobs, newest first."""
jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True) jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
if status: if status:
jobs = [j for j in jobs if j.status == status] jobs = [j for j in jobs if j.status == status]
@@ -168,7 +197,6 @@ class JobQueue:
return [j.to_dict() for j in jobs[:limit]] return [j.to_dict() for j in jobs[:limit]]
def get_stats(self) -> dict: def get_stats(self) -> dict:
"""Get queue statistics."""
by_status = {} by_status = {}
for j in self._jobs.values(): for j in self._jobs.values():
by_status[j.status.value] = by_status.get(j.status.value, 0) + 1 by_status[j.status.value] = by_status.get(j.status.value, 0) + 1
@@ -177,26 +205,58 @@ class JobQueue:
"queued": self._queue.qsize(), "queued": self._queue.qsize(),
"by_status": by_status, "by_status": by_status,
"workers": self._max_workers, "workers": self._max_workers,
"active_workers": sum( "active_workers": sum(1 for j in self._jobs.values() if j.status == JobStatus.RUNNING),
1 for j in self._jobs.values() if j.status == JobStatus.RUNNING
),
} }
def is_backlogged(self) -> bool:
return self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG
def can_accept(self, reserve: int = 0) -> bool:
return (self._queue.qsize() + max(0, reserve)) < settings.JOB_QUEUE_MAX_BACKLOG
def cleanup(self, max_age_seconds: float = 3600): def cleanup(self, max_age_seconds: float = 3600):
"""Remove old completed/failed/cancelled jobs."""
now = time.time() now = time.time()
terminal_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
to_remove = [ to_remove = [
jid for jid, j in self._jobs.items() jid for jid, j in self._jobs.items()
if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED) if j.status in terminal_states and (now - j.created_at) > max_age_seconds
and (now - j.created_at) > max_age_seconds ]
# Also cap retained terminal jobs to avoid unbounded memory growth
terminal_jobs = sorted(
[j for j in self._jobs.values() if j.status in terminal_states],
key=lambda j: j.created_at,
reverse=True,
)
overflow = terminal_jobs[settings.JOB_QUEUE_RETAIN_COMPLETED :]
to_remove.extend([j.id for j in overflow])
removed = 0
for jid in set(to_remove):
if jid in self._jobs:
del self._jobs[jid]
removed += 1
if removed:
logger.info(f"Cleaned up {removed} old jobs")
async def _cleanup_loop(self):
interval = max(10, settings.JOB_QUEUE_CLEANUP_INTERVAL_SECONDS)
while self._started:
try:
self.cleanup(max_age_seconds=settings.JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS)
except Exception as e:
logger.warning(f"Job queue cleanup loop error: {e}")
await asyncio.sleep(interval)
def find_pipeline_jobs(self, dataset_id: str) -> list[Job]:
"""Find all pipeline jobs for a given dataset_id."""
return [
j for j in self._jobs.values()
if j.job_type in PIPELINE_JOB_TYPES
and j.params.get("dataset_id") == dataset_id
] ]
for jid in to_remove:
del self._jobs[jid]
if to_remove:
logger.info(f"Cleaned up {len(to_remove)} old jobs")
async def _worker(self, worker_id: int): async def _worker(self, worker_id: int):
"""Worker loop: pull jobs from queue and execute handlers."""
logger.info(f"Worker {worker_id} started") logger.info(f"Worker {worker_id} started")
while self._started: while self._started:
try: try:
@@ -220,7 +280,10 @@ class JobQueue:
job.status = JobStatus.RUNNING job.status = JobStatus.RUNNING
job.started_at = time.time() job.started_at = time.time()
if job.progress <= 0:
job.progress = 5.0
job.message = "Running..." job.message = "Running..."
await _sync_processing_task(job)
logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})") logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
try: try:
@@ -231,38 +294,111 @@ class JobQueue:
job.result = result job.result = result
job.message = "Completed" job.message = "Completed"
job.completed_at = time.time() job.completed_at = time.time()
logger.info( logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
f"Worker {worker_id}: completed {job.id} "
f"in {job.elapsed_ms}ms"
)
except Exception as e: except Exception as e:
if not job.is_cancelled: if not job.is_cancelled:
job.status = JobStatus.FAILED job.status = JobStatus.FAILED
job.error = str(e) job.error = str(e)
job.message = f"Failed: {e}" job.message = f"Failed: {e}"
job.completed_at = time.time() job.completed_at = time.time()
logger.error( logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
f"Worker {worker_id}: failed {job.id}: {e}",
exc_info=True, if job.is_cancelled and not job.completed_at:
) job.completed_at = time.time()
await _sync_processing_task(job)
# Fire completion callbacks
for cb in self._completion_callbacks:
try:
await cb(job)
except Exception as cb_err:
logger.error(f"Completion callback error: {cb_err}", exc_info=True)
# Singleton + job handlers async def _sync_processing_task(job: Job):
"""Persist latest job state into processing_tasks (if linked by job_id)."""
from datetime import datetime, timezone
from sqlalchemy import update
job_queue = JobQueue(max_workers=3) try:
from app.db import async_session_factory
from app.db.models import ProcessingTask
values = {
"status": job.status.value,
"progress": float(job.progress),
"message": job.message,
"error": job.error,
}
if job.started_at:
values["started_at"] = datetime.fromtimestamp(job.started_at, tz=timezone.utc)
if job.completed_at:
values["completed_at"] = datetime.fromtimestamp(job.completed_at, tz=timezone.utc)
async with async_session_factory() as db:
await db.execute(
update(ProcessingTask)
.where(ProcessingTask.job_id == job.id)
.values(**values)
)
await db.commit()
except Exception as e:
logger.warning(f"Failed to sync processing task for job {job.id}: {e}")
# -- Singleton + job handlers --
job_queue = JobQueue(max_workers=5)
async def _handle_triage(job: Job): async def _handle_triage(job: Job):
"""Triage handler.""" """Triage handler - chains HOST_PROFILE after completion."""
from app.services.triage import triage_dataset from app.services.triage import triage_dataset
dataset_id = job.params.get("dataset_id") dataset_id = job.params.get("dataset_id")
job.message = f"Triaging dataset {dataset_id}" job.message = f"Triaging dataset {dataset_id}"
results = await triage_dataset(dataset_id) await triage_dataset(dataset_id)
return {"count": len(results) if results else 0}
# Chain: trigger host profiling now that triage results exist
from app.db import async_session_factory
from app.db.models import Dataset
from sqlalchemy import select
try:
async with async_session_factory() as db:
ds = await db.execute(select(Dataset.hunt_id).where(Dataset.id == dataset_id))
row = ds.first()
hunt_id = row[0] if row else None
if hunt_id:
hp_job = job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
try:
from sqlalchemy import select
from app.db.models import ProcessingTask
async with async_session_factory() as db:
existing = await db.execute(
select(ProcessingTask.id).where(ProcessingTask.job_id == hp_job.id)
)
if existing.first() is None:
db.add(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset_id,
job_id=hp_job.id,
stage="host_profile",
status="queued",
progress=0.0,
message="Queued",
))
await db.commit()
except Exception as persist_err:
logger.warning(f"Failed to persist chained HOST_PROFILE task: {persist_err}")
logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
except Exception as e:
logger.warning(f"Failed to chain host profile after triage: {e}")
return {"dataset_id": dataset_id}
async def _handle_host_profile(job: Job): async def _handle_host_profile(job: Job):
"""Host profiling handler."""
from app.services.host_profiler import profile_all_hosts, profile_host from app.services.host_profiler import profile_all_hosts, profile_host
hunt_id = job.params.get("hunt_id") hunt_id = job.params.get("hunt_id")
hostname = job.params.get("hostname") hostname = job.params.get("hostname")
@@ -277,7 +413,6 @@ async def _handle_host_profile(job: Job):
async def _handle_report(job: Job): async def _handle_report(job: Job):
"""Report generation handler."""
from app.services.report_generator import generate_report from app.services.report_generator import generate_report
hunt_id = job.params.get("hunt_id") hunt_id = job.params.get("hunt_id")
job.message = f"Generating report for hunt {hunt_id}" job.message = f"Generating report for hunt {hunt_id}"
@@ -286,7 +421,6 @@ async def _handle_report(job: Job):
async def _handle_anomaly(job: Job): async def _handle_anomaly(job: Job):
"""Anomaly detection handler."""
from app.services.anomaly_detector import detect_anomalies from app.services.anomaly_detector import detect_anomalies
dataset_id = job.params.get("dataset_id") dataset_id = job.params.get("dataset_id")
k = job.params.get("k", 3) k = job.params.get("k", 3)
@@ -297,7 +431,6 @@ async def _handle_anomaly(job: Job):
async def _handle_query(job: Job): async def _handle_query(job: Job):
"""Data query handler (non-streaming)."""
from app.services.data_query import query_dataset from app.services.data_query import query_dataset
dataset_id = job.params.get("dataset_id") dataset_id = job.params.get("dataset_id")
question = job.params.get("question", "") question = job.params.get("question", "")
@@ -307,10 +440,152 @@ async def _handle_query(job: Job):
return {"answer": answer} return {"answer": answer}
async def _handle_host_inventory(job: Job):
from app.db import async_session_factory
from app.services.host_inventory import build_host_inventory, inventory_cache
hunt_id = job.params.get("hunt_id")
if not hunt_id:
raise ValueError("hunt_id required")
inventory_cache.set_building(hunt_id)
job.message = f"Building host inventory for hunt {hunt_id}"
try:
async with async_session_factory() as db:
result = await build_host_inventory(hunt_id, db)
inventory_cache.put(hunt_id, result)
job.message = f"Built inventory: {result['stats']['total_hosts']} hosts"
return {"hunt_id": hunt_id, "total_hosts": result["stats"]["total_hosts"]}
except Exception:
inventory_cache.clear_building(hunt_id)
raise
async def _handle_keyword_scan(job: Job):
"""AUP keyword scan handler."""
from app.db import async_session_factory
from app.services.scanner import KeywordScanner, keyword_scan_cache
dataset_id = job.params.get("dataset_id")
job.message = f"Running AUP keyword scan on dataset {dataset_id}"
async with async_session_factory() as db:
scanner = KeywordScanner(db)
result = await scanner.scan(dataset_ids=[dataset_id])
# Cache dataset-only result for fast API reuse
if dataset_id:
keyword_scan_cache.put(dataset_id, result)
hits = result.get("total_hits", 0)
job.message = f"Keyword scan complete: {hits} hits"
logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
async def _handle_ioc_extract(job: Job):
"""IOC extraction handler."""
from app.db import async_session_factory
from app.services.ioc_extractor import extract_iocs_from_dataset
dataset_id = job.params.get("dataset_id")
job.message = f"Extracting IOCs from dataset {dataset_id}"
async with async_session_factory() as db:
iocs = await extract_iocs_from_dataset(dataset_id, db)
total = sum(len(v) for v in iocs.values())
job.message = f"IOC extraction complete: {total} IOCs found"
logger.info(f"IOC extract for {dataset_id}: {total} IOCs")
return {"dataset_id": dataset_id, "total_iocs": total, "breakdown": {k: len(v) for k, v in iocs.items()}}
async def _on_pipeline_job_complete(job: Job):
"""Update Dataset.processing_status when all pipeline jobs finish."""
if job.job_type not in PIPELINE_JOB_TYPES:
return
dataset_id = job.params.get("dataset_id")
if not dataset_id:
return
pipeline_jobs = job_queue.find_pipeline_jobs(dataset_id)
if not pipeline_jobs:
return
all_done = all(
j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
for j in pipeline_jobs
)
if not all_done:
return
any_failed = any(j.status == JobStatus.FAILED for j in pipeline_jobs)
new_status = "completed_with_errors" if any_failed else "completed"
try:
from app.db import async_session_factory
from app.db.models import Dataset
from sqlalchemy import update
async with async_session_factory() as db:
await db.execute(
update(Dataset)
.where(Dataset.id == dataset_id)
.values(processing_status=new_status)
)
await db.commit()
logger.info(f"Dataset {dataset_id} processing_status -> {new_status}")
except Exception as e:
logger.error(f"Failed to update processing_status for {dataset_id}: {e}")
async def reconcile_stale_processing_tasks() -> int:
"""Mark queued/running processing tasks from prior runs as failed."""
from datetime import datetime, timezone
from sqlalchemy import update
try:
from app.db import async_session_factory
from app.db.models import ProcessingTask
now = datetime.now(timezone.utc)
async with async_session_factory() as db:
result = await db.execute(
update(ProcessingTask)
.where(ProcessingTask.status.in_(["queued", "running"]))
.values(
status="failed",
error="Recovered after service restart before task completion",
message="Recovered stale task after restart",
completed_at=now,
)
)
await db.commit()
updated = int(result.rowcount or 0)
if updated:
logger.warning(
"Reconciled %d stale processing tasks (queued/running -> failed) during startup",
updated,
)
return updated
except Exception as e:
logger.warning(f"Failed to reconcile stale processing tasks: {e}")
return 0
def register_all_handlers(): def register_all_handlers():
"""Register all job handlers.""" """Register all job handlers and completion callbacks."""
job_queue.register_handler(JobType.TRIAGE, _handle_triage) job_queue.register_handler(JobType.TRIAGE, _handle_triage)
job_queue.register_handler(JobType.HOST_PROFILE, _handle_host_profile) job_queue.register_handler(JobType.HOST_PROFILE, _handle_host_profile)
job_queue.register_handler(JobType.REPORT, _handle_report) job_queue.register_handler(JobType.REPORT, _handle_report)
job_queue.register_handler(JobType.ANOMALY, _handle_anomaly) job_queue.register_handler(JobType.ANOMALY, _handle_anomaly)
job_queue.register_handler(JobType.QUERY, _handle_query) job_queue.register_handler(JobType.QUERY, _handle_query)
job_queue.register_handler(JobType.HOST_INVENTORY, _handle_host_inventory)
job_queue.register_handler(JobType.KEYWORD_SCAN, _handle_keyword_scan)
job_queue.register_handler(JobType.IOC_EXTRACT, _handle_ioc_extract)
job_queue.on_completion(_on_pipeline_job_complete)

View File

@@ -1,4 +1,4 @@
"""AUP Keyword Scanner searches dataset rows, hunts, annotations, and """AUP Keyword Scanner searches dataset rows, hunts, annotations, and
messages for keyword matches. messages for keyword matches.
Scanning is done in Python (not SQL LIKE on JSON columns) for portability Scanning is done in Python (not SQL LIKE on JSON columns) for portability
@@ -8,24 +8,49 @@ across SQLite / PostgreSQL and to provide per-cell match context.
import logging import logging
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone
from sqlalchemy import select, func from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.db.models import ( from app.db.models import (
KeywordTheme, KeywordTheme,
Keyword,
DatasetRow, DatasetRow,
Dataset, Dataset,
Hunt, Hunt,
Annotation, Annotation,
Message, Message,
Conversation,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
BATCH_SIZE = 500 BATCH_SIZE = 200
def _infer_hostname_and_user(data: dict) -> tuple[str | None, str | None]:
"""Best-effort extraction of hostname and user from a dataset row."""
if not data:
return None, None
host_keys = (
'hostname', 'host_name', 'host', 'computer_name', 'computer',
'fqdn', 'client_id', 'agent_id', 'endpoint_id',
)
user_keys = (
'username', 'user_name', 'user', 'account_name',
'logged_in_user', 'samaccountname', 'sam_account_name',
)
def pick(keys):
for k in keys:
for actual_key, v in data.items():
if actual_key.lower() == k and v not in (None, ''):
return str(v)
return None
return pick(host_keys), pick(user_keys)
@dataclass @dataclass
@@ -39,6 +64,8 @@ class ScanHit:
matched_value: str matched_value: str
row_index: int | None = None row_index: int | None = None
dataset_name: str | None = None dataset_name: str | None = None
hostname: str | None = None
username: str | None = None
@dataclass @dataclass
@@ -50,21 +77,54 @@ class ScanResult:
rows_scanned: int = 0 rows_scanned: int = 0
@dataclass
class KeywordScanCacheEntry:
dataset_id: str
result: dict
built_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
class KeywordScanCache:
"""In-memory per-dataset cache for dataset-only keyword scans.
This enables fast-path reads when users run AUP scans against datasets that
were already scanned during upload pipeline processing.
"""
def __init__(self):
self._entries: dict[str, KeywordScanCacheEntry] = {}
def put(self, dataset_id: str, result: dict):
self._entries[dataset_id] = KeywordScanCacheEntry(dataset_id=dataset_id, result=result)
def get(self, dataset_id: str) -> KeywordScanCacheEntry | None:
return self._entries.get(dataset_id)
def invalidate_dataset(self, dataset_id: str):
self._entries.pop(dataset_id, None)
def clear(self):
self._entries.clear()
keyword_scan_cache = KeywordScanCache()
class KeywordScanner: class KeywordScanner:
"""Scans multiple data sources for keyword/regex matches.""" """Scans multiple data sources for keyword/regex matches."""
def __init__(self, db: AsyncSession): def __init__(self, db: AsyncSession):
self.db = db self.db = db
# ── Public API ──────────────────────────────────────────────────── # Public API
async def scan( async def scan(
self, self,
dataset_ids: list[str] | None = None, dataset_ids: list[str] | None = None,
theme_ids: list[str] | None = None, theme_ids: list[str] | None = None,
scan_hunts: bool = True, scan_hunts: bool = False,
scan_annotations: bool = True, scan_annotations: bool = False,
scan_messages: bool = True, scan_messages: bool = False,
) -> dict: ) -> dict:
"""Run a full AUP scan and return dict matching ScanResponse.""" """Run a full AUP scan and return dict matching ScanResponse."""
# Load themes + keywords # Load themes + keywords
@@ -103,7 +163,7 @@ class KeywordScanner:
"rows_scanned": result.rows_scanned, "rows_scanned": result.rows_scanned,
} }
# ── Internal ────────────────────────────────────────────────────── # Internal
async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]: async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712 q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712
@@ -143,6 +203,8 @@ class KeywordScanner:
hits: list[ScanHit], hits: list[ScanHit],
row_index: int | None = None, row_index: int | None = None,
dataset_name: str | None = None, dataset_name: str | None = None,
hostname: str | None = None,
username: str | None = None,
) -> None: ) -> None:
"""Check text against all compiled patterns, append hits.""" """Check text against all compiled patterns, append hits."""
if not text: if not text:
@@ -150,8 +212,7 @@ class KeywordScanner:
for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items(): for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
for kw_value, pat in keyword_patterns: for kw_value, pat in keyword_patterns:
if pat.search(text): if pat.search(text):
# Truncate matched_value for display matched_preview = text[:200] + ("" if len(text) > 200 else "")
matched_preview = text[:200] + ("" if len(text) > 200 else "")
hits.append(ScanHit( hits.append(ScanHit(
theme_name=theme_name, theme_name=theme_name,
theme_color=theme_color, theme_color=theme_color,
@@ -162,13 +223,14 @@ class KeywordScanner:
matched_value=matched_preview, matched_value=matched_preview,
row_index=row_index, row_index=row_index,
dataset_name=dataset_name, dataset_name=dataset_name,
hostname=hostname,
username=username,
)) ))
async def _scan_datasets( async def _scan_datasets(
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
) -> None: ) -> None:
"""Scan dataset rows in batches.""" """Scan dataset rows in batches using keyset pagination (no OFFSET)."""
# Build dataset name lookup
ds_q = select(Dataset.id, Dataset.name) ds_q = select(Dataset.id, Dataset.name)
if dataset_ids: if dataset_ids:
ds_q = ds_q.where(Dataset.id.in_(dataset_ids)) ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
@@ -178,37 +240,66 @@ class KeywordScanner:
if not ds_map: if not ds_map:
return return
# Iterate rows in batches import asyncio
offset = 0
row_q_base = select(DatasetRow).where(
DatasetRow.dataset_id.in_(list(ds_map.keys()))
).order_by(DatasetRow.id)
while True: max_rows = max(0, int(settings.SCANNER_MAX_ROWS_PER_SCAN))
rows_result = await self.db.execute( budget_reached = False
row_q_base.offset(offset).limit(BATCH_SIZE)
for ds_id, ds_name in ds_map.items():
if max_rows and result.rows_scanned >= max_rows:
budget_reached = True
break
last_id = 0
while True:
if max_rows and result.rows_scanned >= max_rows:
budget_reached = True
break
rows_result = await self.db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == ds_id)
.where(DatasetRow.id > last_id)
.order_by(DatasetRow.id)
.limit(BATCH_SIZE)
)
rows = rows_result.scalars().all()
if not rows:
break
for row in rows:
result.rows_scanned += 1
data = row.data or {}
hostname, username = _infer_hostname_and_user(data)
for col_name, cell_value in data.items():
if cell_value is None:
continue
text = str(cell_value)
self._match_text(
text,
patterns,
"dataset_row",
row.id,
col_name,
result.hits,
row_index=row.row_index,
dataset_name=ds_name,
hostname=hostname,
username=username,
)
last_id = rows[-1].id
await asyncio.sleep(0)
if len(rows) < BATCH_SIZE:
break
if budget_reached:
break
if budget_reached:
logger.warning(
"AUP scan row budget reached (%d rows). Returning partial results.",
result.rows_scanned,
) )
rows = rows_result.scalars().all()
if not rows:
break
for row in rows:
result.rows_scanned += 1
data = row.data or {}
for col_name, cell_value in data.items():
if cell_value is None:
continue
text = str(cell_value)
self._match_text(
text, patterns, "dataset_row", row.id,
col_name, result.hits,
row_index=row.row_index,
dataset_name=ds_map.get(row.dataset_id),
)
offset += BATCH_SIZE
if len(rows) < BATCH_SIZE:
break
async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None: async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
"""Scan hunt names and descriptions.""" """Scan hunt names and descriptions."""

View File

@@ -1,4 +1,4 @@
"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner.""" """Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""
from __future__ import annotations from __future__ import annotations
@@ -15,7 +15,7 @@ from app.db.models import Dataset, DatasetRow, TriageResult
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M" DEFAULT_FAST_MODEL = settings.DEFAULT_FAST_MODEL
ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate" ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate"
ARTIFACT_FOCUS = { ARTIFACT_FOCUS = {
@@ -80,7 +80,7 @@ async def triage_dataset(dataset_id: str) -> None:
rows_result = await db.execute( rows_result = await db.execute(
select(DatasetRow) select(DatasetRow)
.where(DatasetRow.dataset_id == dataset_id) .where(DatasetRow.dataset_id == dataset_id)
.order_by(DatasetRow.row_number) .order_by(DatasetRow.row_index)
.offset(offset) .offset(offset)
.limit(batch_size) .limit(batch_size)
) )
@@ -167,4 +167,4 @@ Be precise. Only flag genuinely suspicious items. Respond with valid JSON only."
offset += batch_size offset += batch_size
logger.info("Triage complete for dataset %s", dataset_id) logger.info("Triage complete for dataset %s", dataset_id)

View File

@@ -0,0 +1,124 @@
"""Tests for execution-mode behavior in /api/agent/assist."""
import io
import pytest
@pytest.mark.asyncio
async def test_agent_assist_policy_query_executes_scan(client):
# 1) Create hunt
h = await client.post("/api/hunts", json={"name": "Policy Hunt"})
assert h.status_code == 200
hunt_id = h.json()["id"]
# 2) Upload browser-history-like CSV
csv_bytes = (
b"User,visited_url,title,ClientId,Fqdn\n"
b"Alice,https://www.pornhub.com/view_video.php,site,HOST-A,host-a.local\n"
b"Bob,https://news.example.org/article,news,HOST-B,host-b.local\n"
)
files = {"file": ("web_history.csv", io.BytesIO(csv_bytes), "text/csv")}
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
assert up.status_code == 200
# 3) Ensure policy theme/keyword exists
t = await client.post(
"/api/keywords/themes",
json={
"name": "Adult Content",
"color": "#e91e63",
"enabled": True,
},
)
assert t.status_code in (201, 409)
themes = await client.get("/api/keywords/themes")
assert themes.status_code == 200
adult = next(x for x in themes.json()["themes"] if x["name"] == "Adult Content")
k = await client.post(
f"/api/keywords/themes/{adult['id']}/keywords",
json={"value": "pornhub", "is_regex": False},
)
assert k.status_code in (201, 409)
# 4) Execution-mode query
q = await client.post(
"/api/agent/assist",
json={
"query": "Analyze browser history for policy-violating domains and summarize by user and host.",
"hunt_id": hunt_id,
},
)
assert q.status_code == 200
body = q.json()
assert body["model_used"] == "execution:keyword_scanner"
assert body["execution"] is not None
assert body["execution"]["policy_hits"] >= 1
assert len(body["execution"]["top_user_hosts"]) >= 1
@pytest.mark.asyncio
async def test_agent_assist_execution_preference_off_stays_advisory(client):
h = await client.post("/api/hunts", json={"name": "No Exec Hunt"})
assert h.status_code == 200
hunt_id = h.json()["id"]
q = await client.post(
"/api/agent/assist",
json={
"query": "Analyze browser history for policy-violating domains and summarize by user and host.",
"hunt_id": hunt_id,
"execution_preference": "off",
},
)
assert q.status_code == 200
body = q.json()
assert body["model_used"] != "execution:keyword_scanner"
assert body["execution"] is None
@pytest.mark.asyncio
async def test_agent_assist_execution_preference_force_executes(client):
# Create hunt + dataset even when the query text is not policy-specific
h = await client.post("/api/hunts", json={"name": "Force Exec Hunt"})
assert h.status_code == 200
hunt_id = h.json()["id"]
csv_bytes = (
b"User,visited_url,title,ClientId,Fqdn\n"
b"Alice,https://www.pornhub.com/view_video.php,site,HOST-A,host-a.local\n"
)
files = {"file": ("web_history.csv", io.BytesIO(csv_bytes), "text/csv")}
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
assert up.status_code == 200
t = await client.post(
"/api/keywords/themes",
json={"name": "Adult Content", "color": "#e91e63", "enabled": True},
)
assert t.status_code in (201, 409)
themes = await client.get("/api/keywords/themes")
assert themes.status_code == 200
adult = next(x for x in themes.json()["themes"] if x["name"] == "Adult Content")
k = await client.post(
f"/api/keywords/themes/{adult['id']}/keywords",
json={"value": "pornhub", "is_regex": False},
)
assert k.status_code in (201, 409)
q = await client.post(
"/api/agent/assist",
json={
"query": "Summarize notable activity in this hunt.",
"hunt_id": hunt_id,
"execution_preference": "force",
},
)
assert q.status_code == 200
body = q.json()
assert body["model_used"] == "execution:keyword_scanner"
assert body["execution"] is not None

View File

@@ -77,6 +77,26 @@ class TestHuntEndpoints:
assert resp.status_code == 404 assert resp.status_code == 404
async def test_hunt_progress(self, client):
create = await client.post("/api/hunts", json={"name": "Progress Hunt"})
hunt_id = create.json()["id"]
# attach one dataset so progress has scope
from tests.conftest import SAMPLE_CSV
import io
files = {"file": ("progress.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
assert up.status_code == 200
res = await client.get(f"/api/hunts/{hunt_id}/progress")
assert res.status_code == 200
body = res.json()
assert body["hunt_id"] == hunt_id
assert "progress_percent" in body
assert "dataset_total" in body
assert "network_status" in body
@pytest.mark.asyncio @pytest.mark.asyncio
class TestDatasetEndpoints: class TestDatasetEndpoints:
"""Test dataset upload and retrieval.""" """Test dataset upload and retrieval."""

View File

@@ -1,4 +1,4 @@
"""Tests for CSV parser and normalizer services.""" """Tests for CSV parser and normalizer services."""
import pytest import pytest
from app.services.csv_parser import parse_csv_bytes, detect_encoding, detect_delimiter, infer_column_types from app.services.csv_parser import parse_csv_bytes, detect_encoding, detect_delimiter, infer_column_types
@@ -43,8 +43,9 @@ class TestCSVParser:
assert len(rows) == 2 assert len(rows) == 2
def test_parse_empty_file(self): def test_parse_empty_file(self):
with pytest.raises(Exception): rows, meta = parse_csv_bytes(b"")
parse_csv_bytes(b"") assert len(rows) == 0
assert meta["row_count"] == 0
def test_detect_encoding_utf8(self): def test_detect_encoding_utf8(self):
enc = detect_encoding(SAMPLE_CSV) enc = detect_encoding(SAMPLE_CSV)
@@ -53,17 +54,15 @@ class TestCSVParser:
def test_infer_column_types(self): def test_infer_column_types(self):
types = infer_column_types( types = infer_column_types(
["192.168.1.1", "10.0.0.1", "8.8.8.8"], [{"src_ip": "192.168.1.1"}, {"src_ip": "10.0.0.1"}, {"src_ip": "8.8.8.8"}],
"src_ip",
) )
assert types == "ip" assert types["src_ip"] == "ip"
def test_infer_column_types_hash(self): def test_infer_column_types_hash(self):
types = infer_column_types( types = infer_column_types(
["d41d8cd98f00b204e9800998ecf8427e"], [{"hash": "d41d8cd98f00b204e9800998ecf8427e"}],
"hash",
) )
assert types == "hash_md5" assert types["hash"] == "hash_md5"
class TestNormalizer: class TestNormalizer:
@@ -94,7 +93,7 @@ class TestNormalizer:
start, end = detect_time_range(rows, column_mapping) start, end = detect_time_range(rows, column_mapping)
# Should detect time range from timestamp column # Should detect time range from timestamp column
if start: if start:
assert "2025" in start assert "2025" in str(start)
def test_normalize_rows(self): def test_normalize_rows(self):
rows = [{"SourceAddr": "10.0.0.1", "ProcessName": "cmd.exe"}] rows = [{"SourceAddr": "10.0.0.1", "ProcessName": "cmd.exe"}]
@@ -102,3 +101,6 @@ class TestNormalizer:
normalized = normalize_rows(rows, mapping) normalized = normalize_rows(rows, mapping)
assert len(normalized) == 1 assert len(normalized) == 1
assert normalized[0].get("src_ip") == "10.0.0.1" assert normalized[0].get("src_ip") == "10.0.0.1"

View File

@@ -197,3 +197,27 @@ async def test_quick_scan(client: AsyncClient):
assert "total_hits" in data assert "total_hits" in data
# powershell should match at least one row # powershell should match at least one row
assert data["total_hits"] > 0 assert data["total_hits"] > 0
@pytest.mark.asyncio
async def test_quick_scan_cache_hit(client: AsyncClient):
"""Second quick scan should return cache hit metadata."""
theme_res = await client.post("/api/keywords/themes", json={"name": "Quick Cache Theme", "color": "#00aa00"})
tid = theme_res.json()["id"]
await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "chrome.exe"})
from tests.conftest import SAMPLE_CSV
import io
files = {"file": ("cache_quick.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
upload = await client.post("/api/datasets/upload", files=files)
ds_id = upload.json()["id"]
first = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
assert first.status_code == 200
assert first.json().get("cache_status") in ("miss", "hit")
second = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
assert second.status_code == 200
body = second.json()
assert body.get("cache_used") is True
assert body.get("cache_status") == "hit"

View File

@@ -0,0 +1,84 @@
"""Tests for network inventory endpoints and cache/polling behavior."""
import io
import pytest
from app.services.host_inventory import inventory_cache
from tests.conftest import SAMPLE_CSV
@pytest.mark.asyncio
async def test_inventory_status_none_for_unknown_hunt(client):
hunt_id = "hunt-does-not-exist"
inventory_cache.invalidate(hunt_id)
inventory_cache.clear_building(hunt_id)
res = await client.get(f"/api/network/inventory-status?hunt_id={hunt_id}")
assert res.status_code == 200
body = res.json()
assert body["hunt_id"] == hunt_id
assert body["status"] == "none"
@pytest.mark.asyncio
async def test_host_inventory_cold_cache_returns_202(client):
# Create hunt and upload dataset linked to that hunt
hunt = await client.post("/api/hunts", json={"name": "Net Hunt"})
hunt_id = hunt.json()["id"]
files = {"file": ("network.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
up = await client.post("/api/datasets/upload", files=files, params={"hunt_id": hunt_id})
assert up.status_code == 200
# Ensure cache is cold for this hunt
inventory_cache.invalidate(hunt_id)
inventory_cache.clear_building(hunt_id)
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
assert res.status_code == 202
body = res.json()
assert body["status"] == "building"
@pytest.mark.asyncio
async def test_host_inventory_ready_cache_returns_200(client):
hunt = await client.post("/api/hunts", json={"name": "Ready Hunt"})
hunt_id = hunt.json()["id"]
mock_inventory = {
"hosts": [
{
"id": "host-1",
"hostname": "HOST-1",
"fqdn": "HOST-1.local",
"client_id": "C.1234abcd",
"ips": ["10.0.0.10"],
"os": "Windows 10",
"users": ["alice"],
"datasets": ["test"],
"row_count": 5,
}
],
"connections": [],
"stats": {
"total_hosts": 1,
"hosts_with_ips": 1,
"hosts_with_users": 1,
"total_datasets_scanned": 1,
"total_rows_scanned": 5,
},
}
inventory_cache.put(hunt_id, mock_inventory)
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
assert res.status_code == 200
body = res.json()
assert body["stats"]["total_hosts"] == 1
assert len(body["hosts"]) == 1
assert body["hosts"][0]["hostname"] == "HOST-1"
status_res = await client.get(f"/api/network/inventory-status?hunt_id={hunt_id}")
assert status_res.status_code == 200
assert status_res.json()["status"] == "ready"

View File

@@ -0,0 +1,82 @@
"""Scale-oriented network endpoint tests (summary/subgraph/backpressure)."""
import pytest
from app.config import settings
from app.services.host_inventory import inventory_cache
@pytest.mark.asyncio
async def test_network_summary_from_cache(client):
hunt_id = "scale-hunt-summary"
inv = {
"hosts": [
{"id": "h1", "hostname": "H1", "ips": ["10.0.0.1"], "users": ["a"], "row_count": 50},
{"id": "h2", "hostname": "H2", "ips": [], "users": [], "row_count": 10},
],
"connections": [
{"source": "h1", "target": "8.8.8.8", "count": 7},
{"source": "h1", "target": "h2", "count": 3},
],
"stats": {"total_hosts": 2, "total_rows_scanned": 60},
}
inventory_cache.put(hunt_id, inv)
res = await client.get(f"/api/network/summary?hunt_id={hunt_id}&top_n=1")
assert res.status_code == 200
body = res.json()
assert body["stats"]["total_hosts"] == 2
assert len(body["top_hosts"]) == 1
assert body["top_hosts"][0]["id"] == "h1"
@pytest.mark.asyncio
async def test_network_subgraph_truncates(client):
hunt_id = "scale-hunt-subgraph"
inv = {
"hosts": [
{"id": f"h{i}", "hostname": f"H{i}", "ips": [], "users": [], "row_count": 100 - i}
for i in range(1, 8)
],
"connections": [
{"source": "h1", "target": "h2", "count": 20},
{"source": "h1", "target": "h3", "count": 15},
{"source": "h2", "target": "h4", "count": 5},
{"source": "h3", "target": "h5", "count": 4},
],
"stats": {"total_hosts": 7, "total_rows_scanned": 999},
}
inventory_cache.put(hunt_id, inv)
res = await client.get(f"/api/network/subgraph?hunt_id={hunt_id}&max_hosts=3&max_edges=2")
assert res.status_code == 200
body = res.json()
assert len(body["hosts"]) <= 3
assert len(body["connections"]) <= 2
assert body["stats"]["truncated"] is True
@pytest.mark.asyncio
async def test_manual_job_submit_backpressure_returns_429(client):
old = settings.JOB_QUEUE_MAX_BACKLOG
settings.JOB_QUEUE_MAX_BACKLOG = 0
try:
res = await client.post("/api/analysis/jobs/submit/triage", json={"params": {"dataset_id": "abc"}})
assert res.status_code == 429
finally:
settings.JOB_QUEUE_MAX_BACKLOG = old
@pytest.mark.asyncio
async def test_network_host_inventory_deferred_when_queue_backlogged(client):
hunt_id = "deferred-hunt"
inventory_cache.invalidate(hunt_id)
inventory_cache.clear_building(hunt_id)
old = settings.JOB_QUEUE_MAX_BACKLOG
settings.JOB_QUEUE_MAX_BACKLOG = 0
try:
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
assert res.status_code == 202
body = res.json()
assert body["status"] == "deferred"
finally:
settings.JOB_QUEUE_MAX_BACKLOG = old

View File

@@ -0,0 +1,203 @@
"""Tests for new feature API routes: MITRE, Timeline, Playbooks, Saved Searches."""
import pytest
import pytest_asyncio
class TestMitreRoutes:
"""Tests for /api/mitre endpoints."""
@pytest.mark.asyncio
async def test_mitre_coverage_empty(self, client):
resp = await client.get("/api/mitre/coverage")
assert resp.status_code == 200
data = resp.json()
assert "tactics" in data
assert "technique_count" in data
assert data["technique_count"] == 0
assert len(data["tactics"]) == 14 # 14 MITRE tactics
@pytest.mark.asyncio
async def test_mitre_coverage_with_hunt_filter(self, client):
resp = await client.get("/api/mitre/coverage?hunt_id=nonexistent")
assert resp.status_code == 200
assert resp.json()["technique_count"] == 0
class TestTimelineRoutes:
"""Tests for /api/timeline endpoints."""
@pytest.mark.asyncio
async def test_timeline_hunt_not_found(self, client):
resp = await client.get("/api/timeline/hunt/nonexistent")
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_timeline_with_hunt(self, client):
# Create a hunt first
hunt_resp = await client.post("/api/hunts", json={"name": "Timeline Test"})
assert hunt_resp.status_code in (200, 201)
hunt_id = hunt_resp.json()["id"]
resp = await client.get(f"/api/timeline/hunt/{hunt_id}")
assert resp.status_code == 200
data = resp.json()
assert data["hunt_id"] == hunt_id
assert "events" in data
assert "datasets" in data
class TestPlaybookRoutes:
"""Tests for /api/playbooks endpoints."""
@pytest.mark.asyncio
async def test_list_playbooks_empty(self, client):
resp = await client.get("/api/playbooks")
assert resp.status_code == 200
assert resp.json()["playbooks"] == []
@pytest.mark.asyncio
async def test_get_templates(self, client):
resp = await client.get("/api/playbooks/templates")
assert resp.status_code == 200
templates = resp.json()["templates"]
assert len(templates) >= 2
assert templates[0]["name"] == "Standard Threat Hunt"
@pytest.mark.asyncio
async def test_create_playbook(self, client):
resp = await client.post("/api/playbooks", json={
"name": "My Investigation",
"description": "Test playbook",
"steps": [
{"title": "Step 1", "description": "Upload data", "step_type": "upload", "target_route": "/upload"},
{"title": "Step 2", "description": "Triage", "step_type": "analysis", "target_route": "/analysis"},
],
})
assert resp.status_code == 201
data = resp.json()
assert data["name"] == "My Investigation"
assert len(data["steps"]) == 2
@pytest.mark.asyncio
async def test_playbook_crud(self, client):
# Create
resp = await client.post("/api/playbooks", json={
"name": "CRUD Test",
"steps": [{"title": "Do something"}],
})
assert resp.status_code == 201
pb_id = resp.json()["id"]
# Get
resp = await client.get(f"/api/playbooks/{pb_id}")
assert resp.status_code == 200
assert resp.json()["name"] == "CRUD Test"
assert len(resp.json()["steps"]) == 1
# Update
resp = await client.put(f"/api/playbooks/{pb_id}", json={"name": "Updated"})
assert resp.status_code == 200
# Delete
resp = await client.delete(f"/api/playbooks/{pb_id}")
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_playbook_step_completion(self, client):
# Create with step
resp = await client.post("/api/playbooks", json={
"name": "Step Test",
"steps": [{"title": "Task 1"}],
})
pb_id = resp.json()["id"]
# Get to find step ID
resp = await client.get(f"/api/playbooks/{pb_id}")
steps = resp.json()["steps"]
step_id = steps[0]["id"]
assert steps[0]["is_completed"] is False
# Mark complete
resp = await client.put(f"/api/playbooks/steps/{step_id}", json={"is_completed": True, "notes": "Done!"})
assert resp.status_code == 200
assert resp.json()["is_completed"] is True
class TestSavedSearchRoutes:
"""Tests for /api/searches endpoints."""
@pytest.mark.asyncio
async def test_list_empty(self, client):
resp = await client.get("/api/searches")
assert resp.status_code == 200
assert resp.json()["searches"] == []
@pytest.mark.asyncio
async def test_create_saved_search(self, client):
resp = await client.post("/api/searches", json={
"name": "Suspicious IPs",
"search_type": "ioc_search",
"query_params": {"ioc_value": "203.0.113"},
})
assert resp.status_code == 201
data = resp.json()
assert data["name"] == "Suspicious IPs"
assert data["search_type"] == "ioc_search"
@pytest.mark.asyncio
async def test_search_crud(self, client):
# Create
resp = await client.post("/api/searches", json={
"name": "Test Query",
"search_type": "keyword_scan",
"query_params": {"theme": "malware"},
})
s_id = resp.json()["id"]
# Get
resp = await client.get(f"/api/searches/{s_id}")
assert resp.status_code == 200
# Update
resp = await client.put(f"/api/searches/{s_id}", json={"name": "Updated Query"})
assert resp.status_code == 200
# Run
resp = await client.post(f"/api/searches/{s_id}/run")
assert resp.status_code == 200
data = resp.json()
assert "result_count" in data
assert "delta" in data
# Delete
resp = await client.delete(f"/api/searches/{s_id}")
assert resp.status_code == 200
class TestStixExport:
"""Tests for /api/export/stix endpoints."""
@pytest.mark.asyncio
async def test_stix_export_hunt_not_found(self, client):
resp = await client.get("/api/export/stix/nonexistent-id")
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_stix_export_empty_hunt(self, client):
"""Export from a real hunt with no data returns valid but minimal bundle."""
hunt_resp = await client.post("/api/hunts", json={"name": "STIX Test Hunt"})
assert hunt_resp.status_code in (200, 201)
hunt_id = hunt_resp.json()["id"]
resp = await client.get(f"/api/export/stix/{hunt_id}")
assert resp.status_code == 200
data = resp.json()
assert data["type"] == "bundle"
assert data["objects"][0]["spec_version"] == "2.1" # spec_version is on objects, not bundle
assert "objects" in data
# At minimum should have the identity object
types = [o["type"] for o in data["objects"]]
assert "identity" in types

Binary file not shown.

View File

@@ -7,24 +7,24 @@ services:
ports: ports:
- "8000:8000" - "8000:8000"
environment: environment:
# ── LLM Cluster (Wile / Roadrunner via Tailscale) ── # ── LLM Cluster (Wile / Roadrunner via Tailscale) ──
TH_WILE_HOST: "100.110.190.12" TH_WILE_HOST: "100.110.190.12"
TH_ROADRUNNER_HOST: "100.110.190.11" TH_ROADRUNNER_HOST: "100.110.190.11"
TH_OLLAMA_PORT: "11434" TH_OLLAMA_PORT: "11434"
TH_OPEN_WEBUI_URL: "https://ai.guapo613.beer" TH_OPEN_WEBUI_URL: "https://ai.guapo613.beer"
# ── Database ── # ── Database ──
TH_DATABASE_URL: "sqlite+aiosqlite:///./threathunt.db" TH_DATABASE_URL: "sqlite+aiosqlite:///./threathunt.db"
# ── Auth ── # ── Auth ──
TH_JWT_SECRET: "change-me-in-production" TH_JWT_SECRET: "change-me-in-production"
# ── Enrichment API keys (set your own) ── # ── Enrichment API keys (set your own) ──
# TH_VIRUSTOTAL_API_KEY: "" # TH_VIRUSTOTAL_API_KEY: ""
# TH_ABUSEIPDB_API_KEY: "" # TH_ABUSEIPDB_API_KEY: ""
# TH_SHODAN_API_KEY: "" # TH_SHODAN_API_KEY: ""
# ── Agent behaviour ── # ── Agent behaviour ──
TH_AGENT_MAX_TOKENS: "4096" TH_AGENT_MAX_TOKENS: "4096"
TH_AGENT_TEMPERATURE: "0.3" TH_AGENT_TEMPERATURE: "0.3"
volumes: volumes:
@@ -51,7 +51,7 @@ services:
networks: networks:
- threathunt - threathunt
healthcheck: healthcheck:
test: ["CMD", "wget", "--quiet", "--tries=1", "--spider", "http://localhost:3000/"] test: ["CMD", "curl", "-f", "http://127.0.0.1:3000/"]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 3 retries: 3

350
fix_all.py Normal file
View File

@@ -0,0 +1,350 @@
"""Fix all critical issues: DB locking, keyword scan, network map."""
import os, re
ROOT = r"D:\Projects\Dev\ThreatHunt"
def fix_file(filepath, replacements):
"""Apply text replacements to a file."""
path = os.path.join(ROOT, filepath)
with open(path, "r", encoding="utf-8") as f:
content = f.read()
for old, new, desc in replacements:
if old in content:
content = content.replace(old, new, 1)
print(f" OK: {desc}")
else:
print(f" SKIP: {desc} (pattern not found)")
with open(path, "w", encoding="utf-8") as f:
f.write(content)
return content
# ================================================================
# FIX 1: Database engine - NullPool instead of StaticPool
# ================================================================
print("\n=== FIX 1: Database engine (NullPool + higher timeouts) ===")
engine_path = os.path.join(ROOT, "backend", "app", "db", "engine.py")
with open(engine_path, "r", encoding="utf-8") as f:
engine_content = f.read()
new_engine = '''"""Database engine, session factory, and base model.
Uses async SQLAlchemy with aiosqlite for local dev and asyncpg for production PostgreSQL.
"""
from sqlalchemy import event
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase
from app.config import settings
_is_sqlite = settings.DATABASE_URL.startswith("sqlite")
_engine_kwargs: dict = dict(
echo=settings.DEBUG,
future=True,
)
if _is_sqlite:
_engine_kwargs["connect_args"] = {"timeout": 60, "check_same_thread": False}
# NullPool: each session gets its own connection.
# Combined with WAL mode, this allows concurrent reads while a write is in progress.
from sqlalchemy.pool import NullPool
_engine_kwargs["poolclass"] = NullPool
else:
_engine_kwargs["pool_size"] = 5
_engine_kwargs["max_overflow"] = 10
engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)
@event.listens_for(engine.sync_engine, "connect")
def _set_sqlite_pragmas(dbapi_conn, connection_record):
"""Enable WAL mode and tune busy-timeout for SQLite connections."""
if _is_sqlite:
cursor = dbapi_conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA busy_timeout=30000")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.close()
async_session_factory = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
)
# Alias expected by other modules
async_session = async_session_factory
class Base(DeclarativeBase):
"""Base class for all ORM models."""
pass
async def get_db() -> AsyncSession: # type: ignore[misc]
"""FastAPI dependency that yields an async DB session."""
async with async_session_factory() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def init_db() -> None:
"""Create all tables (for dev / first-run). In production use Alembic."""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def dispose_db() -> None:
"""Dispose of the engine on shutdown."""
await engine.dispose()
'''
with open(engine_path, "w", encoding="utf-8") as f:
f.write(new_engine)
print(" OK: Replaced StaticPool with NullPool")
print(" OK: Increased busy_timeout 5000 -> 30000ms")
print(" OK: Added check_same_thread=False")
print(" OK: Connection timeout 30 -> 60s")
# ================================================================
# FIX 2: Keyword scan endpoint - make POST non-blocking (background job)
# ================================================================
print("\n=== FIX 2: Keyword scan endpoint -> background job ===")
kw_path = os.path.join(ROOT, "backend", "app", "api", "routes", "keywords.py")
with open(kw_path, "r", encoding="utf-8") as f:
kw_content = f.read()
# Replace the scan endpoint to be non-blocking
old_scan = '''@router.post("/scan", response_model=ScanResponse)
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
"""Run AUP keyword scan across selected data sources."""
scanner = KeywordScanner(db)
result = await scanner.scan(
dataset_ids=body.dataset_ids,
theme_ids=body.theme_ids,
scan_hunts=body.scan_hunts,
scan_annotations=body.scan_annotations,
scan_messages=body.scan_messages,
)
return result'''
new_scan = '''@router.post("/scan", response_model=ScanResponse)
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
"""Run AUP keyword scan across selected data sources.
Uses a dedicated DB session separate from the request session
to avoid blocking other API requests on SQLite.
"""
from app.db import async_session_factory
async with async_session_factory() as scan_db:
scanner = KeywordScanner(scan_db)
result = await scanner.scan(
dataset_ids=body.dataset_ids,
theme_ids=body.theme_ids,
scan_hunts=body.scan_hunts,
scan_annotations=body.scan_annotations,
scan_messages=body.scan_messages,
)
return result'''
if old_scan in kw_content:
kw_content = kw_content.replace(old_scan, new_scan, 1)
print(" OK: Scan endpoint uses dedicated DB session")
else:
print(" SKIP: Scan endpoint pattern not found")
# Also fix quick_scan
old_quick = '''@router.get("/scan/quick", response_model=ScanResponse)
async def quick_scan(
dataset_id: str = Query(..., description="Dataset to scan"),
db: AsyncSession = Depends(get_db),
):
"""Quick scan a single dataset with all enabled themes."""
scanner = KeywordScanner(db)
result = await scanner.scan(dataset_ids=[dataset_id])
return result'''
new_quick = '''@router.get("/scan/quick", response_model=ScanResponse)
async def quick_scan(
dataset_id: str = Query(..., description="Dataset to scan"),
db: AsyncSession = Depends(get_db),
):
"""Quick scan a single dataset with all enabled themes."""
from app.db import async_session_factory
async with async_session_factory() as scan_db:
scanner = KeywordScanner(scan_db)
result = await scanner.scan(dataset_ids=[dataset_id])
return result'''
if old_quick in kw_content:
kw_content = kw_content.replace(old_quick, new_quick, 1)
print(" OK: Quick scan uses dedicated DB session")
else:
print(" SKIP: Quick scan pattern not found")
with open(kw_path, "w", encoding="utf-8") as f:
f.write(kw_content)
# ================================================================
# FIX 3: Scanner service - smaller batches, yield between batches
# ================================================================
print("\n=== FIX 3: Scanner service - smaller batches + async yield ===")
scanner_path = os.path.join(ROOT, "backend", "app", "services", "scanner.py")
with open(scanner_path, "r", encoding="utf-8") as f:
scanner_content = f.read()
# Change batch size and add yield between batches
old_batch = "BATCH_SIZE = 500"
new_batch = "BATCH_SIZE = 200"
if old_batch in scanner_content:
scanner_content = scanner_content.replace(old_batch, new_batch, 1)
print(" OK: Reduced batch size 500 -> 200")
# Add asyncio.sleep(0) between batches to yield to other tasks
old_batch_loop = ''' offset += BATCH_SIZE
if len(rows) < BATCH_SIZE:
break'''
new_batch_loop = ''' offset += BATCH_SIZE
# Yield to event loop between batches so other requests aren't starved
import asyncio
await asyncio.sleep(0)
if len(rows) < BATCH_SIZE:
break'''
if old_batch_loop in scanner_content:
scanner_content = scanner_content.replace(old_batch_loop, new_batch_loop, 1)
print(" OK: Added async yield between scan batches")
else:
print(" SKIP: Batch loop pattern not found")
with open(scanner_path, "w", encoding="utf-8") as f:
f.write(scanner_content)
# ================================================================
# FIX 4: Job queue workers - increase from 3 to 5
# ================================================================
print("\n=== FIX 4: Job queue - more workers ===")
jq_path = os.path.join(ROOT, "backend", "app", "services", "job_queue.py")
with open(jq_path, "r", encoding="utf-8") as f:
jq_content = f.read()
old_workers = "job_queue = JobQueue(max_workers=3)"
new_workers = "job_queue = JobQueue(max_workers=5)"
if old_workers in jq_content:
jq_content = jq_content.replace(old_workers, new_workers, 1)
print(" OK: Workers 3 -> 5")
with open(jq_path, "w", encoding="utf-8") as f:
f.write(jq_content)
# ================================================================
# FIX 5: main.py - always re-run pipeline on startup for ALL datasets
# ================================================================
print("\n=== FIX 5: Startup reprocessing - all datasets, not just 'ready' ===")
main_path = os.path.join(ROOT, "backend", "app", "main.py")
with open(main_path, "r", encoding="utf-8") as f:
main_content = f.read()
# The current startup only reprocesses datasets with status="ready"
# But after previous runs, they're all "completed" - so nothing happens
# Fix: reprocess datasets that have NO triage/anomaly results in DB
old_reprocess = ''' # Reprocess datasets that were never fully processed (status still "ready")
async with async_session_factory() as reprocess_db:
from sqlalchemy import select
from app.db.models import Dataset
stmt = select(Dataset.id).where(Dataset.processing_status == "ready")
result = await reprocess_db.execute(stmt)
unprocessed_ids = [row[0] for row in result.all()]
for ds_id in unprocessed_ids:
job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)
job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)
if unprocessed_ids:
logger.info(f"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets")
# Mark them as processing
async with async_session_factory() as update_db:
from sqlalchemy import update
from app.db.models import Dataset
await update_db.execute(
update(Dataset)
.where(Dataset.id.in_(unprocessed_ids))
.values(processing_status="processing")
)
await update_db.commit()'''
new_reprocess = ''' # Check which datasets still need processing
# (no anomaly results = never fully processed)
async with async_session_factory() as reprocess_db:
from sqlalchemy import select, exists
from app.db.models import Dataset, AnomalyResult
# Find datasets that have zero anomaly results (pipeline never ran or failed)
has_anomaly = (
select(AnomalyResult.id)
.where(AnomalyResult.dataset_id == Dataset.id)
.limit(1)
.correlate(Dataset)
.exists()
)
stmt = select(Dataset.id).where(~has_anomaly)
result = await reprocess_db.execute(stmt)
unprocessed_ids = [row[0] for row in result.all()]
if unprocessed_ids:
for ds_id in unprocessed_ids:
job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)
job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)
logger.info(f"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets")
async with async_session_factory() as update_db:
from sqlalchemy import update
from app.db.models import Dataset
await update_db.execute(
update(Dataset)
.where(Dataset.id.in_(unprocessed_ids))
.values(processing_status="processing")
)
await update_db.commit()
else:
logger.info("All datasets already processed - skipping startup pipeline")'''
if old_reprocess in main_content:
main_content = main_content.replace(old_reprocess, new_reprocess, 1)
print(" OK: Startup checks for actual results, not just status field")
else:
print(" SKIP: Reprocess block not found")
with open(main_path, "w", encoding="utf-8") as f:
f.write(main_content)
print("\n=== ALL FIXES APPLIED ===")

30
fix_keywords.py Normal file
View File

@@ -0,0 +1,30 @@
import re
path = "backend/app/api/routes/keywords.py"
with open(path, "r", encoding="utf-8") as f:
c = f.read()
# Fix POST /scan - remove dedicated session, use injected db
old1 = ' """Run AUP keyword scan across selected data sources.\n \n Uses a dedicated DB session separate from the request session\n to avoid blocking other API requests on SQLite.\n """\n from app.db import async_session_factory\n async with async_session_factory() as scan_db:\n scanner = KeywordScanner(scan_db)\n result = await scanner.scan(\n dataset_ids=body.dataset_ids,\n theme_ids=body.theme_ids,\n scan_hunts=body.scan_hunts,\n scan_annotations=body.scan_annotations,\n scan_messages=body.scan_messages,\n )\n return result'
new1 = ' """Run AUP keyword scan across selected data sources."""\n scanner = KeywordScanner(db)\n result = await scanner.scan(\n dataset_ids=body.dataset_ids,\n theme_ids=body.theme_ids,\n scan_hunts=body.scan_hunts,\n scan_annotations=body.scan_annotations,\n scan_messages=body.scan_messages,\n )\n return result'
if old1 in c:
c = c.replace(old1, new1, 1)
print("OK: reverted POST /scan")
else:
print("SKIP: POST /scan not found")
# Fix GET /scan/quick
old2 = ' """Quick scan a single dataset with all enabled themes."""\n from app.db import async_session_factory\n async with async_session_factory() as scan_db:\n scanner = KeywordScanner(scan_db)\n result = await scanner.scan(dataset_ids=[dataset_id])\n return result'
new2 = ' """Quick scan a single dataset with all enabled themes."""\n scanner = KeywordScanner(db)\n result = await scanner.scan(dataset_ids=[dataset_id])\n return result'
if old2 in c:
c = c.replace(old2, new2, 1)
print("OK: reverted GET /scan/quick")
else:
print("SKIP: GET /scan/quick not found")
with open(path, "w", encoding="utf-8") as f:
f.write(c)

View File

@@ -16,6 +16,12 @@ server {
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme; proxy_set_header X-Forwarded-Proto $scheme;
proxy_read_timeout 300s; proxy_read_timeout 300s;
# SSE streaming support for agent assist
proxy_buffering off;
proxy_cache off;
proxy_set_header Connection '';
chunked_transfer_encoding off;
} }
# SPA fallback serve index.html for all non-file routes # SPA fallback serve index.html for all non-file routes

View File

@@ -24,9 +24,13 @@
"react-markdown": "^10.1.0", "react-markdown": "^10.1.0",
"react-router-dom": "^7.13.0", "react-router-dom": "^7.13.0",
"react-scripts": "5.0.1", "react-scripts": "5.0.1",
<<<<<<< HEAD
"recharts": "^3.7.0"
=======
"recharts": "^3.7.0", "recharts": "^3.7.0",
"remark-gfm": "^4.0.1", "remark-gfm": "^4.0.1",
"yaml": "^2.8.2" "yaml": "^2.8.2"
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
}, },
"devDependencies": { "devDependencies": {
"@types/cytoscape": "^3.21.9", "@types/cytoscape": "^3.21.9",
@@ -3978,6 +3982,8 @@
"@types/node": "*" "@types/node": "*"
} }
}, },
<<<<<<< HEAD
=======
"node_modules/@types/cytoscape": { "node_modules/@types/cytoscape": {
"version": "3.21.9", "version": "3.21.9",
"resolved": "https://registry.npmjs.org/@types/cytoscape/-/cytoscape-3.21.9.tgz", "resolved": "https://registry.npmjs.org/@types/cytoscape/-/cytoscape-3.21.9.tgz",
@@ -3985,6 +3991,7 @@
"dev": true, "dev": true,
"license": "MIT" "license": "MIT"
}, },
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
"node_modules/@types/d3-array": { "node_modules/@types/d3-array": {
"version": "3.2.2", "version": "3.2.2",
"resolved": "https://registry.npmjs.org/@types/d3-array/-/d3-array-3.2.2.tgz", "resolved": "https://registry.npmjs.org/@types/d3-array/-/d3-array-3.2.2.tgz",
@@ -4048,6 +4055,8 @@
"integrity": "sha512-Ps3T8E8dZDam6fUyNiMkekK3XUsaUEik+idO9/YjPtfj2qruF8tFBXS7XhtE4iIXBLxhmLjP3SXpLhVf21I9Lw==", "integrity": "sha512-Ps3T8E8dZDam6fUyNiMkekK3XUsaUEik+idO9/YjPtfj2qruF8tFBXS7XhtE4iIXBLxhmLjP3SXpLhVf21I9Lw==",
"license": "MIT" "license": "MIT"
}, },
<<<<<<< HEAD
=======
"node_modules/@types/debug": { "node_modules/@types/debug": {
"version": "4.1.12", "version": "4.1.12",
"resolved": "https://registry.npmjs.org/@types/debug/-/debug-4.1.12.tgz", "resolved": "https://registry.npmjs.org/@types/debug/-/debug-4.1.12.tgz",
@@ -4057,6 +4066,7 @@
"@types/ms": "*" "@types/ms": "*"
} }
}, },
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
"node_modules/@types/eslint": { "node_modules/@types/eslint": {
"version": "8.56.12", "version": "8.56.12",
"resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.56.12.tgz", "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.56.12.tgz",
@@ -4415,12 +4425,15 @@
"integrity": "sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==", "integrity": "sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==",
"license": "MIT" "license": "MIT"
}, },
<<<<<<< HEAD
=======
"node_modules/@types/unist": { "node_modules/@types/unist": {
"version": "3.0.3", "version": "3.0.3",
"resolved": "https://registry.npmjs.org/@types/unist/-/unist-3.0.3.tgz", "resolved": "https://registry.npmjs.org/@types/unist/-/unist-3.0.3.tgz",
"integrity": "sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==", "integrity": "sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==",
"license": "MIT" "license": "MIT"
}, },
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
"node_modules/@types/use-sync-external-store": { "node_modules/@types/use-sync-external-store": {
"version": "0.0.6", "version": "0.0.6",
"resolved": "https://registry.npmjs.org/@types/use-sync-external-store/-/use-sync-external-store-0.0.6.tgz", "resolved": "https://registry.npmjs.org/@types/use-sync-external-store/-/use-sync-external-store-0.0.6.tgz",
@@ -7026,6 +7039,8 @@
"integrity": "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==", "integrity": "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==",
"license": "MIT" "license": "MIT"
}, },
<<<<<<< HEAD
=======
"node_modules/cytoscape": { "node_modules/cytoscape": {
"version": "3.33.1", "version": "3.33.1",
"resolved": "https://registry.npmjs.org/cytoscape/-/cytoscape-3.33.1.tgz", "resolved": "https://registry.npmjs.org/cytoscape/-/cytoscape-3.33.1.tgz",
@@ -7060,6 +7075,7 @@
"cytoscape": "^3.2.22" "cytoscape": "^3.2.22"
} }
}, },
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
"node_modules/d3-array": { "node_modules/d3-array": {
"version": "3.2.4", "version": "3.2.4",
"resolved": "https://registry.npmjs.org/d3-array/-/d3-array-3.2.4.tgz", "resolved": "https://registry.npmjs.org/d3-array/-/d3-array-3.2.4.tgz",
@@ -7081,6 +7097,8 @@
"node": ">=12" "node": ">=12"
} }
}, },
<<<<<<< HEAD
=======
"node_modules/d3-dispatch": { "node_modules/d3-dispatch": {
"version": "1.0.6", "version": "1.0.6",
"resolved": "https://registry.npmjs.org/d3-dispatch/-/d3-dispatch-1.0.6.tgz", "resolved": "https://registry.npmjs.org/d3-dispatch/-/d3-dispatch-1.0.6.tgz",
@@ -7097,6 +7115,7 @@
"d3-selection": "1" "d3-selection": "1"
} }
}, },
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
"node_modules/d3-ease": { "node_modules/d3-ease": {
"version": "3.0.1", "version": "3.0.1",
"resolved": "https://registry.npmjs.org/d3-ease/-/d3-ease-3.0.1.tgz", "resolved": "https://registry.npmjs.org/d3-ease/-/d3-ease-3.0.1.tgz",
@@ -7152,12 +7171,15 @@
"node": ">=12" "node": ">=12"
} }
}, },
<<<<<<< HEAD
=======
"node_modules/d3-selection": { "node_modules/d3-selection": {
"version": "1.4.2", "version": "1.4.2",
"resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-1.4.2.tgz", "resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-1.4.2.tgz",
"integrity": "sha512-SJ0BqYihzOjDnnlfyeHT0e30k0K1+5sR3d5fNueCNeuhZTnGw4M4o8mqJchSwgKMXCNFo+e2VTChiSJ0vYtXkg==", "integrity": "sha512-SJ0BqYihzOjDnnlfyeHT0e30k0K1+5sR3d5fNueCNeuhZTnGw4M4o8mqJchSwgKMXCNFo+e2VTChiSJ0vYtXkg==",
"license": "BSD-3-Clause" "license": "BSD-3-Clause"
}, },
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
"node_modules/d3-shape": { "node_modules/d3-shape": {
"version": "3.2.0", "version": "3.2.0",
"resolved": "https://registry.npmjs.org/d3-shape/-/d3-shape-3.2.0.tgz", "resolved": "https://registry.npmjs.org/d3-shape/-/d3-shape-3.2.0.tgz",
@@ -7203,6 +7225,8 @@
"node": ">=12" "node": ">=12"
} }
}, },
<<<<<<< HEAD
=======
"node_modules/dagre": { "node_modules/dagre": {
"version": "0.8.5", "version": "0.8.5",
"resolved": "https://registry.npmjs.org/dagre/-/dagre-0.8.5.tgz", "resolved": "https://registry.npmjs.org/dagre/-/dagre-0.8.5.tgz",
@@ -7213,6 +7237,7 @@
"lodash": "^4.17.15" "lodash": "^4.17.15"
} }
}, },
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
"node_modules/damerau-levenshtein": { "node_modules/damerau-levenshtein": {
"version": "1.0.8", "version": "1.0.8",
"resolved": "https://registry.npmjs.org/damerau-levenshtein/-/damerau-levenshtein-1.0.8.tgz", "resolved": "https://registry.npmjs.org/damerau-levenshtein/-/damerau-levenshtein-1.0.8.tgz",
@@ -7313,6 +7338,8 @@
"integrity": "sha512-qIMFpTMZmny+MMIitAB6D7iVPEorVw6YQRWkvarTkT4tBeSLLiHzcwj6q0MmYSFCiVpiqPJTJEYIrpcPzVEIvg==", "integrity": "sha512-qIMFpTMZmny+MMIitAB6D7iVPEorVw6YQRWkvarTkT4tBeSLLiHzcwj6q0MmYSFCiVpiqPJTJEYIrpcPzVEIvg==",
"license": "MIT" "license": "MIT"
}, },
<<<<<<< HEAD
=======
"node_modules/decode-named-character-reference": { "node_modules/decode-named-character-reference": {
"version": "1.3.0", "version": "1.3.0",
"resolved": "https://registry.npmjs.org/decode-named-character-reference/-/decode-named-character-reference-1.3.0.tgz", "resolved": "https://registry.npmjs.org/decode-named-character-reference/-/decode-named-character-reference-1.3.0.tgz",
@@ -7326,6 +7353,7 @@
"url": "https://github.com/sponsors/wooorm" "url": "https://github.com/sponsors/wooorm"
} }
}, },
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
"node_modules/dedent": { "node_modules/dedent": {
"version": "0.7.0", "version": "0.7.0",
"resolved": "https://registry.npmjs.org/dedent/-/dedent-0.7.0.tgz", "resolved": "https://registry.npmjs.org/dedent/-/dedent-0.7.0.tgz",
@@ -15841,6 +15869,29 @@
} }
} }
}, },
"node_modules/react-redux": {
"version": "9.2.0",
"resolved": "https://registry.npmjs.org/react-redux/-/react-redux-9.2.0.tgz",
"integrity": "sha512-ROY9fvHhwOD9ySfrF0wmvu//bKCQ6AeZZq1nJNtbDC+kk5DuSuNX/n6YWYF/SYy7bSba4D4FSz8DJeKY/S/r+g==",
"license": "MIT",
"dependencies": {
"@types/use-sync-external-store": "^0.0.6",
"use-sync-external-store": "^1.4.0"
},
"peerDependencies": {
"@types/react": "^18.2.25 || ^19",
"react": "^18.0 || ^19",
"redux": "^5.0.0"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"redux": {
"optional": true
}
}
},
"node_modules/react-refresh": { "node_modules/react-refresh": {
"version": "0.11.0", "version": "0.11.0",
"resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.11.0.tgz", "resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.11.0.tgz",
@@ -16074,8 +16125,12 @@
"version": "5.0.1", "version": "5.0.1",
"resolved": "https://registry.npmjs.org/redux/-/redux-5.0.1.tgz", "resolved": "https://registry.npmjs.org/redux/-/redux-5.0.1.tgz",
"integrity": "sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==", "integrity": "sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==",
<<<<<<< HEAD
"license": "MIT"
=======
"license": "MIT", "license": "MIT",
"peer": true "peer": true
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
}, },
"node_modules/redux-thunk": { "node_modules/redux-thunk": {
"version": "3.1.0", "version": "3.1.0",
@@ -18783,6 +18838,8 @@
"node": ">= 0.8" "node": ">= 0.8"
} }
}, },
<<<<<<< HEAD
=======
"node_modules/vfile": { "node_modules/vfile": {
"version": "6.0.3", "version": "6.0.3",
"resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.3.tgz", "resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.3.tgz",
@@ -18811,6 +18868,7 @@
"url": "https://opencollective.com/unified" "url": "https://opencollective.com/unified"
} }
}, },
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
"node_modules/victory-vendor": { "node_modules/victory-vendor": {
"version": "37.3.6", "version": "37.3.6",
"resolved": "https://registry.npmjs.org/victory-vendor/-/victory-vendor-37.3.6.tgz", "resolved": "https://registry.npmjs.org/victory-vendor/-/victory-vendor-37.3.6.tgz",

View File

@@ -19,9 +19,13 @@
"react-markdown": "^10.1.0", "react-markdown": "^10.1.0",
"react-router-dom": "^7.13.0", "react-router-dom": "^7.13.0",
"react-scripts": "5.0.1", "react-scripts": "5.0.1",
<<<<<<< HEAD
"recharts": "^3.7.0"
=======
"recharts": "^3.7.0", "recharts": "^3.7.0",
"remark-gfm": "^4.0.1", "remark-gfm": "^4.0.1",
"yaml": "^2.8.2" "yaml": "^2.8.2"
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
}, },
"scripts": { "scripts": {
"start": "react-scripts start", "start": "react-scripts start",

View File

@@ -2,10 +2,11 @@
* ThreatHunt — MUI-powered analyst-assist platform. * ThreatHunt — MUI-powered analyst-assist platform.
*/ */
import React, { useState, useCallback } from 'react'; import React, { useState, useCallback, Suspense } from 'react';
import { BrowserRouter, Routes, Route, useNavigate, useLocation } from 'react-router-dom'; import { BrowserRouter, Routes, Route, useNavigate, useLocation } from 'react-router-dom';
import { ThemeProvider, CssBaseline, Box, AppBar, Toolbar, Typography, IconButton, import { ThemeProvider, CssBaseline, Box, AppBar, Toolbar, Typography, IconButton,
Drawer, List, ListItemButton, ListItemIcon, ListItemText, Divider, Chip } from '@mui/material'; Drawer, List, ListItemButton, ListItemIcon, ListItemText, Divider, Chip,
CircularProgress } from '@mui/material';
import MenuIcon from '@mui/icons-material/Menu'; import MenuIcon from '@mui/icons-material/Menu';
import DashboardIcon from '@mui/icons-material/Dashboard'; import DashboardIcon from '@mui/icons-material/Dashboard';
import SearchIcon from '@mui/icons-material/Search'; import SearchIcon from '@mui/icons-material/Search';
@@ -18,6 +19,13 @@ import ScienceIcon from '@mui/icons-material/Science';
import CompareArrowsIcon from '@mui/icons-material/CompareArrows'; import CompareArrowsIcon from '@mui/icons-material/CompareArrows';
import GppMaybeIcon from '@mui/icons-material/GppMaybe'; import GppMaybeIcon from '@mui/icons-material/GppMaybe';
import HubIcon from '@mui/icons-material/Hub'; import HubIcon from '@mui/icons-material/Hub';
<<<<<<< HEAD
import AssessmentIcon from '@mui/icons-material/Assessment';
import TimelineIcon from '@mui/icons-material/Timeline';
import PlaylistAddCheckIcon from '@mui/icons-material/PlaylistAddCheck';
import BookmarksIcon from '@mui/icons-material/Bookmarks';
import ShieldIcon from '@mui/icons-material/Shield';
=======
import DevicesIcon from '@mui/icons-material/Devices'; import DevicesIcon from '@mui/icons-material/Devices';
import AccountTreeIcon from '@mui/icons-material/AccountTree'; import AccountTreeIcon from '@mui/icons-material/AccountTree';
import TimelineIcon from '@mui/icons-material/Timeline'; import TimelineIcon from '@mui/icons-material/Timeline';
@@ -29,9 +37,11 @@ import WorkIcon from '@mui/icons-material/Work';
import NotificationsActiveIcon from '@mui/icons-material/NotificationsActive'; import NotificationsActiveIcon from '@mui/icons-material/NotificationsActive';
import MenuBookIcon from '@mui/icons-material/MenuBook'; import MenuBookIcon from '@mui/icons-material/MenuBook';
import PlaylistPlayIcon from '@mui/icons-material/PlaylistPlay'; import PlaylistPlayIcon from '@mui/icons-material/PlaylistPlay';
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
import { SnackbarProvider } from 'notistack'; import { SnackbarProvider } from 'notistack';
import theme from './theme'; import theme from './theme';
/* -- Eager imports (lightweight, always needed) -- */
import Dashboard from './components/Dashboard'; import Dashboard from './components/Dashboard';
import HuntManager from './components/HuntManager'; import HuntManager from './components/HuntManager';
import DatasetViewer from './components/DatasetViewer'; import DatasetViewer from './components/DatasetViewer';
@@ -42,6 +52,16 @@ import AnnotationPanel from './components/AnnotationPanel';
import HypothesisTracker from './components/HypothesisTracker'; import HypothesisTracker from './components/HypothesisTracker';
import CorrelationView from './components/CorrelationView'; import CorrelationView from './components/CorrelationView';
import AUPScanner from './components/AUPScanner'; import AUPScanner from './components/AUPScanner';
<<<<<<< HEAD
/* -- Lazy imports (heavy: charts, network graph, new feature pages) -- */
const NetworkMap = React.lazy(() => import('./components/NetworkMap'));
const AnalysisDashboard = React.lazy(() => import('./components/AnalysisDashboard'));
const MitreMatrix = React.lazy(() => import('./components/MitreMatrix'));
const TimelineView = React.lazy(() => import('./components/TimelineView'));
const PlaybookManager = React.lazy(() => import('./components/PlaybookManager'));
const SavedSearches = React.lazy(() => import('./components/SavedSearches'));
=======
import NetworkMap from './components/NetworkMap'; import NetworkMap from './components/NetworkMap';
import NetworkPicture from './components/NetworkPicture'; import NetworkPicture from './components/NetworkPicture';
import ProcessTree from './components/ProcessTree'; import ProcessTree from './components/ProcessTree';
@@ -54,12 +74,31 @@ import CaseManager from './components/CaseManager';
import AlertPanel from './components/AlertPanel'; import AlertPanel from './components/AlertPanel';
import InvestigationNotebook from './components/InvestigationNotebook'; import InvestigationNotebook from './components/InvestigationNotebook';
import PlaybookManager from './components/PlaybookManager'; import PlaybookManager from './components/PlaybookManager';
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
const DRAWER_WIDTH = 240; const DRAWER_WIDTH = 240;
interface NavItem { label: string; path: string; icon: React.ReactNode } interface NavItem { label: string; path: string; icon: React.ReactNode }
const NAV: NavItem[] = [ const NAV: NavItem[] = [
<<<<<<< HEAD
{ label: 'Dashboard', path: '/', icon: <DashboardIcon /> },
{ label: 'Hunts', path: '/hunts', icon: <SearchIcon /> },
{ label: 'Datasets', path: '/datasets', icon: <StorageIcon /> },
{ label: 'Upload', path: '/upload', icon: <UploadFileIcon /> },
{ label: 'AI Analysis', path: '/analysis', icon: <AssessmentIcon /> },
{ label: 'Agent', path: '/agent', icon: <SmartToyIcon /> },
{ label: 'Enrichment', path: '/enrichment', icon: <SecurityIcon /> },
{ label: 'Annotations', path: '/annotations', icon: <BookmarkIcon /> },
{ label: 'Hypotheses', path: '/hypotheses', icon: <ScienceIcon /> },
{ label: 'Correlation', path: '/correlation', icon: <CompareArrowsIcon /> },
{ label: 'Network Map', path: '/network', icon: <HubIcon /> },
{ label: 'AUP Scanner', path: '/aup', icon: <GppMaybeIcon /> },
{ label: 'MITRE Matrix', path: '/mitre', icon: <ShieldIcon /> },
{ label: 'Timeline', path: '/timeline', icon: <TimelineIcon /> },
{ label: 'Playbooks', path: '/playbooks', icon: <PlaylistAddCheckIcon /> },
{ label: 'Saved Searches', path: '/saved-searches', icon: <BookmarksIcon /> },
=======
{ label: 'Dashboard', path: '/', icon: <DashboardIcon /> }, { label: 'Dashboard', path: '/', icon: <DashboardIcon /> },
{ label: 'Hunts', path: '/hunts', icon: <SearchIcon /> }, { label: 'Hunts', path: '/hunts', icon: <SearchIcon /> },
{ label: 'Datasets', path: '/datasets', icon: <StorageIcon /> }, { label: 'Datasets', path: '/datasets', icon: <StorageIcon /> },
@@ -82,8 +121,17 @@ const NAV: NavItem[] = [
{ label: 'Notebooks', path: '/notebooks', icon: <MenuBookIcon /> }, { label: 'Notebooks', path: '/notebooks', icon: <MenuBookIcon /> },
{ label: 'Playbooks', path: '/playbooks', icon: <PlaylistPlayIcon /> }, { label: 'Playbooks', path: '/playbooks', icon: <PlaylistPlayIcon /> },
{ label: 'AUP Scanner', path: '/aup', icon: <GppMaybeIcon /> }, { label: 'AUP Scanner', path: '/aup', icon: <GppMaybeIcon /> },
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
]; ];
function LazyFallback() {
return (
<Box sx={{ display: 'flex', justifyContent: 'center', alignItems: 'center', minHeight: 200 }}>
<CircularProgress />
</Box>
);
}
function Shell() { function Shell() {
const [open, setOpen] = useState(true); const [open, setOpen] = useState(true);
const navigate = useNavigate(); const navigate = useNavigate();
@@ -102,7 +150,7 @@ function Shell() {
<Typography variant="h6" noWrap sx={{ flexGrow: 1 }}> <Typography variant="h6" noWrap sx={{ flexGrow: 1 }}>
ThreatHunt ThreatHunt
</Typography> </Typography>
<Chip label="v0.3.0" size="small" color="primary" variant="outlined" /> <Chip label="v0.4.0" size="small" color="primary" variant="outlined" />
</Toolbar> </Toolbar>
</AppBar> </AppBar>
@@ -137,6 +185,28 @@ function Shell() {
ml: open ? 0 : `-${DRAWER_WIDTH}px`, ml: open ? 0 : `-${DRAWER_WIDTH}px`,
transition: 'margin 225ms cubic-bezier(0,0,0.2,1)', transition: 'margin 225ms cubic-bezier(0,0,0.2,1)',
}}> }}>
<<<<<<< HEAD
<Suspense fallback={<LazyFallback />}>
<Routes>
<Route path="/" element={<Dashboard />} />
<Route path="/hunts" element={<HuntManager />} />
<Route path="/datasets" element={<DatasetViewer />} />
<Route path="/upload" element={<FileUpload />} />
<Route path="/analysis" element={<AnalysisDashboard />} />
<Route path="/agent" element={<AgentPanel />} />
<Route path="/enrichment" element={<EnrichmentPanel />} />
<Route path="/annotations" element={<AnnotationPanel />} />
<Route path="/hypotheses" element={<HypothesisTracker />} />
<Route path="/correlation" element={<CorrelationView />} />
<Route path="/network" element={<NetworkMap />} />
<Route path="/aup" element={<AUPScanner />} />
<Route path="/mitre" element={<MitreMatrix />} />
<Route path="/timeline" element={<TimelineView />} />
<Route path="/playbooks" element={<PlaybookManager />} />
<Route path="/saved-searches" element={<SavedSearches />} />
</Routes>
</Suspense>
=======
<Routes> <Routes>
<Route path="/" element={<Dashboard />} /> <Route path="/" element={<Dashboard />} />
<Route path="/hunts" element={<HuntManager />} /> <Route path="/hunts" element={<HuntManager />} />
@@ -161,6 +231,7 @@ function Shell() {
<Route path="/playbooks" element={<PlaybookManager />} /> <Route path="/playbooks" element={<PlaybookManager />} />
<Route path="/aup" element={<AUPScanner />} /> <Route path="/aup" element={<AUPScanner />} />
</Routes> </Routes>
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
</Box> </Box>
</Box> </Box>
); );

View File

@@ -71,8 +71,24 @@ export interface Hunt {
dataset_count: number; hypothesis_count: number; dataset_count: number; hypothesis_count: number;
} }
<<<<<<< HEAD
export interface HuntProgress {
hunt_id: string;
status: 'idle' | 'processing' | 'ready';
progress_percent: number;
dataset_total: number;
dataset_completed: number;
dataset_processing: number;
dataset_errors: number;
active_jobs: number;
queued_jobs: number;
network_status: 'none' | 'building' | 'ready';
stages: Record<string, any>;
}
=======
/** Alias kept for backward-compat with components that import HuntOut */ /** Alias kept for backward-compat with components that import HuntOut */
export type HuntOut = Hunt; export type HuntOut = Hunt;
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
export const hunts = { export const hunts = {
list: (skip = 0, limit = 50) => list: (skip = 0, limit = 50) =>
@@ -83,6 +99,7 @@ export const hunts = {
update: (id: string, data: Partial<{ name: string; description: string; status: string }>) => update: (id: string, data: Partial<{ name: string; description: string; status: string }>) =>
api<Hunt>(`/api/hunts/${id}`, { method: 'PUT', body: JSON.stringify(data) }), api<Hunt>(`/api/hunts/${id}`, { method: 'PUT', body: JSON.stringify(data) }),
delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }), delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),
progress: (id: string) => api<HuntProgress>(`/api/hunts/${id}/progress`),
}; };
// ── Datasets ───────────────────────────────────────────────────────── // ── Datasets ─────────────────────────────────────────────────────────
@@ -167,6 +184,8 @@ export interface AssistRequest {
active_hypotheses?: string[]; annotations_summary?: string; active_hypotheses?: string[]; annotations_summary?: string;
enrichment_summary?: string; mode?: 'quick' | 'deep' | 'debate'; enrichment_summary?: string; mode?: 'quick' | 'deep' | 'debate';
model_override?: string; conversation_id?: string; hunt_id?: string; model_override?: string; conversation_id?: string; hunt_id?: string;
execution_preference?: 'auto' | 'force' | 'off';
learning_mode?: boolean;
} }
export interface AssistResponse { export interface AssistResponse {
@@ -175,6 +194,15 @@ export interface AssistResponse {
sans_references: string[]; model_used: string; node_used: string; sans_references: string[]; model_used: string; node_used: string;
latency_ms: number; perspectives: Record<string, any>[] | null; latency_ms: number; perspectives: Record<string, any>[] | null;
conversation_id: string | null; conversation_id: string | null;
execution?: {
scope: string;
datasets_scanned: string[];
policy_hits: number;
result_count: number;
top_domains: string[];
top_user_hosts: string[];
elapsed_ms: number;
} | null;
} }
export interface NodeInfo { url: string; available: boolean } export interface NodeInfo { url: string; available: boolean }
@@ -570,10 +598,12 @@ export interface ScanHit {
theme_name: string; theme_color: string; keyword: string; theme_name: string; theme_color: string; keyword: string;
source_type: string; source_id: string | number; field: string; source_type: string; source_id: string | number; field: string;
matched_value: string; row_index: number | null; dataset_name: string | null; matched_value: string; row_index: number | null; dataset_name: string | null;
hostname?: string | null; username?: string | null;
} }
export interface ScanResponse { export interface ScanResponse {
total_hits: number; hits: ScanHit[]; themes_scanned: number; total_hits: number; hits: ScanHit[]; themes_scanned: number;
keywords_scanned: number; rows_scanned: number; keywords_scanned: number; rows_scanned: number;
cache_used?: boolean; cache_status?: string; cached_at?: string | null;
} }
export const keywords = { export const keywords = {
@@ -607,6 +637,7 @@ export const keywords = {
scan: (opts: { scan: (opts: {
dataset_ids?: string[]; theme_ids?: string[]; dataset_ids?: string[]; theme_ids?: string[];
scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean; scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;
prefer_cache?: boolean; force_rescan?: boolean;
}) => }) =>
api<ScanResponse>('/api/keywords/scan', { api<ScanResponse>('/api/keywords/scan', {
method: 'POST', body: JSON.stringify(opts), method: 'POST', body: JSON.stringify(opts),
@@ -865,6 +896,224 @@ export const notebooks = {
api(`/api/notebooks/${id}`, { method: 'DELETE' }), api(`/api/notebooks/${id}`, { method: 'DELETE' }),
}; };
<<<<<<< HEAD
export interface HostInventory {
hosts: InventoryHost[];
connections: InventoryConnection[];
stats: InventoryStats;
}
export interface InventoryStatus {
hunt_id: string;
status: 'ready' | 'building' | 'none';
}
export interface NetworkSummaryHost {
id: string;
hostname: string;
row_count: number;
ip_count: number;
user_count: number;
}
export interface NetworkSummary {
stats: InventoryStats;
top_hosts: NetworkSummaryHost[];
top_edges: InventoryConnection[];
status?: 'building' | 'deferred';
message?: string;
}
export const network = {
hostInventory: (huntId: string, force = false) =>
api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`),
summary: (huntId: string, topN = 20) =>
api<NetworkSummary | { status: 'building' | 'deferred'; message?: string }>(`/api/network/summary?hunt_id=${encodeURIComponent(huntId)}&top_n=${topN}`),
subgraph: (huntId: string, maxHosts = 250, maxEdges = 1500, nodeId?: string) => {
let qs = `/api/network/subgraph?hunt_id=${encodeURIComponent(huntId)}&max_hosts=${maxHosts}&max_edges=${maxEdges}`;
if (nodeId) qs += `&node_id=${encodeURIComponent(nodeId)}`;
return api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(qs);
},
inventoryStatus: (huntId: string) =>
api<InventoryStatus>(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`),
rebuildInventory: (huntId: string) =>
api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }),
};
// -- MITRE ATT&CK Coverage (Feature 1) --
export interface MitreTechnique {
id: string;
tactic: string;
sources: { type: string; risk_score?: number; hostname?: string; title?: string }[];
count: number;
}
export interface MitreCoverage {
tactics: string[];
technique_count: number;
detection_count: number;
tactic_coverage: Record<string, { techniques: MitreTechnique[]; count: number }>;
all_techniques: MitreTechnique[];
}
export const mitre = {
coverage: (huntId?: string) => {
const q = huntId ? `?hunt_id=${encodeURIComponent(huntId)}` : '';
return api<MitreCoverage>(`/api/mitre/coverage${q}`);
},
};
// -- Timeline (Feature 2) --
export interface TimelineEvent {
timestamp: string;
dataset_id: string;
dataset_name: string;
artifact_type: string;
row_index: number;
hostname: string;
process: string;
summary: string;
data: Record<string, string>;
}
export interface TimelineData {
hunt_id: string;
hunt_name: string;
event_count: number;
datasets: { id: string; name: string; artifact_type: string; row_count: number }[];
events: TimelineEvent[];
}
export const timeline = {
getHuntTimeline: (huntId: string, limit = 2000) =>
api<TimelineData>(`/api/timeline/hunt/${huntId}?limit=${limit}`),
};
// -- Playbooks (Feature 3) --
export interface PlaybookStep {
id: number;
order_index: number;
title: string;
description: string | null;
step_type: string;
target_route: string | null;
is_completed: boolean;
completed_at: string | null;
notes: string | null;
}
export interface PlaybookSummary {
id: string;
name: string;
description: string | null;
is_template: boolean;
hunt_id: string | null;
status: string;
total_steps: number;
completed_steps: number;
created_at: string | null;
}
export interface PlaybookDetail {
id: string;
name: string;
description: string | null;
is_template: boolean;
hunt_id: string | null;
status: string;
created_at: string | null;
steps: PlaybookStep[];
}
export interface PlaybookTemplate {
name: string;
description: string;
steps: { title: string; description: string; step_type: string; target_route: string }[];
}
export const playbooks = {
list: (huntId?: string) => {
const q = huntId ? `?hunt_id=${encodeURIComponent(huntId)}` : '';
return api<{ playbooks: PlaybookSummary[] }>(`/api/playbooks${q}`);
},
templates: () => api<{ templates: PlaybookTemplate[] }>('/api/playbooks/templates'),
get: (id: string) => api<PlaybookDetail>(`/api/playbooks/${id}`),
create: (data: { name: string; description?: string; hunt_id?: string; is_template?: boolean; steps?: { title: string; description?: string; step_type?: string; target_route?: string }[] }) =>
api<PlaybookDetail>('/api/playbooks', { method: 'POST', body: JSON.stringify(data) }),
update: (id: string, data: { name?: string; description?: string; status?: string }) =>
api(`/api/playbooks/${id}`, { method: 'PUT', body: JSON.stringify(data) }),
delete: (id: string) => api(`/api/playbooks/${id}`, { method: 'DELETE' }),
updateStep: (stepId: number, data: { is_completed?: boolean; notes?: string }) =>
api(`/api/playbooks/steps/${stepId}`, { method: 'PUT', body: JSON.stringify(data) }),
};
// -- Saved Searches (Feature 5) --
export interface SavedSearchData {
id: string;
name: string;
description: string | null;
search_type: string;
query_params: Record<string, any>;
hunt_id?: string | null;
threshold: number | null;
last_run_at: string | null;
last_result_count: number | null;
created_at: string | null;
}
export interface SearchRunResult {
search_id: string;
search_name: string;
search_type: string;
result_count: number;
previous_count: number;
delta: number;
results: any[];
}
export const savedSearches = {
list: (searchType?: string) => {
const q = searchType ? `?search_type=${encodeURIComponent(searchType)}` : '';
return api<{ searches: SavedSearchData[] }>(`/api/searches${q}`);
},
get: (id: string) => api<SavedSearchData>(`/api/searches/${id}`),
create: (data: { name: string; description?: string; search_type: string; query_params: Record<string, any>; threshold?: number; hunt_id?: string }) =>
api<SavedSearchData>('/api/searches', { method: 'POST', body: JSON.stringify(data) }),
update: (id: string, data: { name?: string; description?: string; search_type?: string; query_params?: Record<string, any>; threshold?: number; hunt_id?: string }) =>
api(`/api/searches/${id}`, { method: 'PUT', body: JSON.stringify(data) }),
delete: (id: string) => api(`/api/searches/${id}`, { method: 'DELETE' }),
run: (id: string) => api<SearchRunResult>(`/api/searches/${id}/run`, { method: 'POST' }),
};
// -- STIX Export --
export const stixExport = {
/** Download a STIX 2.1 bundle JSON for a given hunt */
download: async (huntId: string): Promise<void> => {
const headers: Record<string, string> = {};
if (authToken) headers['Authorization'] = `Bearer ${authToken}`;
const res = await fetch(`${BASE}/api/export/stix/${huntId}`, { headers });
if (!res.ok) {
const body = await res.json().catch(() => ({}));
throw new Error(body.detail || `HTTP ${res.status}`);
}
const blob = await res.blob();
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `hunt-${huntId}-stix-bundle.json`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
},
};
=======
export const playbooks = { export const playbooks = {
templates: () => templates: () =>
api<{ templates: PlaybookTemplate[] }>('/api/notebooks/playbooks/templates'), api<{ templates: PlaybookTemplate[] }>('/api/notebooks/playbooks/templates'),
@@ -887,3 +1136,4 @@ export const playbooks = {
abortRun: (runId: string) => abortRun: (runId: string) =>
api<PlaybookRunData>(`/api/notebooks/playbooks/runs/${runId}/abort`, { method: 'POST' }), api<PlaybookRunData>(`/api/notebooks/playbooks/runs/${runId}/abort`, { method: 'POST' }),
}; };
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2

View File

@@ -188,11 +188,13 @@ const RESULT_COLUMNS: GridColDef[] = [
), ),
}, },
{ field: 'keyword', headerName: 'Keyword', width: 140 }, { field: 'keyword', headerName: 'Keyword', width: 140 },
{ field: 'source_type', headerName: 'Source', width: 120 }, { field: 'dataset_name', headerName: 'Dataset', width: 170 },
{ field: 'dataset_name', headerName: 'Dataset', width: 150 }, { field: 'hostname', headerName: 'Hostname', width: 170, valueGetter: (v, row) => row.hostname || '' },
{ field: 'username', headerName: 'User', width: 160, valueGetter: (v, row) => row.username || '' },
{ field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 220 },
{ field: 'field', headerName: 'Field', width: 130 }, { field: 'field', headerName: 'Field', width: 130 },
{ field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 200 }, { field: 'source_type', headerName: 'Source', width: 120 },
{ field: 'row_index', headerName: 'Row #', width: 80, type: 'number' }, { field: 'row_index', headerName: 'Row #', width: 90, type: 'number' },
]; ];
export default function AUPScanner() { export default function AUPScanner() {
@@ -210,9 +212,9 @@ export default function AUPScanner() {
// Scan options // Scan options
const [selectedDs, setSelectedDs] = useState<Set<string>>(new Set()); const [selectedDs, setSelectedDs] = useState<Set<string>>(new Set());
const [selectedThemes, setSelectedThemes] = useState<Set<string>>(new Set()); const [selectedThemes, setSelectedThemes] = useState<Set<string>>(new Set());
const [scanHunts, setScanHunts] = useState(true); const [scanHunts, setScanHunts] = useState(false);
const [scanAnnotations, setScanAnnotations] = useState(true); const [scanAnnotations, setScanAnnotations] = useState(false);
const [scanMessages, setScanMessages] = useState(true); const [scanMessages, setScanMessages] = useState(false);
// Load themes + hunts // Load themes + hunts
const loadData = useCallback(async () => { const loadData = useCallback(async () => {
@@ -224,9 +226,13 @@ export default function AUPScanner() {
]); ]);
setThemes(tRes.themes); setThemes(tRes.themes);
setHuntList(hRes.hunts); setHuntList(hRes.hunts);
if (!selectedHuntId && hRes.hunts.length > 0) {
const best = hRes.hunts.find(h => h.dataset_count > 0) || hRes.hunts[0];
setSelectedHuntId(best.id);
}
} catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); } } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
setLoading(false); setLoading(false);
}, [enqueueSnackbar]); }, [enqueueSnackbar, selectedHuntId]);
useEffect(() => { loadData(); }, [loadData]); useEffect(() => { loadData(); }, [loadData]);
@@ -237,7 +243,7 @@ export default function AUPScanner() {
datasets.list(0, 500, selectedHuntId).then(res => { datasets.list(0, 500, selectedHuntId).then(res => {
if (cancelled) return; if (cancelled) return;
setDsList(res.datasets); setDsList(res.datasets);
setSelectedDs(new Set(res.datasets.map(d => d.id))); setSelectedDs(new Set(res.datasets.slice(0, 3).map(d => d.id)));
}).catch(() => {}); }).catch(() => {});
return () => { cancelled = true; }; return () => { cancelled = true; };
}, [selectedHuntId]); }, [selectedHuntId]);
@@ -251,6 +257,15 @@ export default function AUPScanner() {
// Run scan // Run scan
const runScan = useCallback(async () => { const runScan = useCallback(async () => {
if (!selectedHuntId) {
enqueueSnackbar('Please select a hunt before running AUP scan', { variant: 'warning' });
return;
}
if (selectedDs.size === 0) {
enqueueSnackbar('No datasets selected for this hunt', { variant: 'warning' });
return;
}
setScanning(true); setScanning(true);
setScanResult(null); setScanResult(null);
try { try {
@@ -260,6 +275,7 @@ export default function AUPScanner() {
scan_hunts: scanHunts, scan_hunts: scanHunts,
scan_annotations: scanAnnotations, scan_annotations: scanAnnotations,
scan_messages: scanMessages, scan_messages: scanMessages,
prefer_cache: true,
}); });
setScanResult(res); setScanResult(res);
enqueueSnackbar(`Scan complete — ${res.total_hits} hits found`, { enqueueSnackbar(`Scan complete — ${res.total_hits} hits found`, {
@@ -267,7 +283,7 @@ export default function AUPScanner() {
}); });
} catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); } } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
setScanning(false); setScanning(false);
}, [selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]); }, [selectedHuntId, selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);
if (loading) return <Box sx={{ p: 4, textAlign: 'center' }}><CircularProgress /></Box>; if (loading) return <Box sx={{ p: 4, textAlign: 'center' }}><CircularProgress /></Box>;
@@ -316,9 +332,38 @@ export default function AUPScanner() {
)} )}
{!selectedHuntId && ( {!selectedHuntId && (
<Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}> <Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
All datasets will be scanned if no hunt is selected Select a hunt to enable scoped scanning
</Typography> </Typography>
)} )}
<FormControl size="small" fullWidth sx={{ mt: 1.2 }} disabled={!selectedHuntId || dsList.length === 0}>
<InputLabel id="aup-dataset-label">Datasets</InputLabel>
<Select
labelId="aup-dataset-label"
multiple
value={Array.from(selectedDs)}
label="Datasets"
renderValue={(selected) => `${(selected as string[]).length} selected`}
onChange={(e) => setSelectedDs(new Set(e.target.value as string[]))}
>
{dsList.map(d => (
<MenuItem key={d.id} value={d.id}>
<Checkbox size="small" checked={selectedDs.has(d.id)} />
<Typography variant="body2" sx={{ ml: 0.5 }}>
{d.name} ({d.row_count.toLocaleString()} rows)
</Typography>
</MenuItem>
))}
</Select>
</FormControl>
{selectedHuntId && dsList.length > 0 && (
<Stack direction="row" spacing={1} sx={{ mt: 1 }}>
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.slice(0, 3).map(d => d.id)))}>Top 3</Button>
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.map(d => d.id)))}>All</Button>
<Button size="small" onClick={() => setSelectedDs(new Set())}>Clear</Button>
</Stack>
)}
</Box> </Box>
{/* Theme selector */} {/* Theme selector */}
@@ -372,7 +417,7 @@ export default function AUPScanner() {
<Button <Button
variant="contained" color="warning" size="large" variant="contained" color="warning" size="large"
startIcon={scanning ? <CircularProgress size={20} color="inherit" /> : <PlayArrowIcon />} startIcon={scanning ? <CircularProgress size={20} color="inherit" /> : <PlayArrowIcon />}
onClick={runScan} disabled={scanning} onClick={runScan} disabled={scanning || !selectedHuntId || selectedDs.size === 0}
> >
{scanning ? 'Scanning…' : 'Run Scan'} {scanning ? 'Scanning…' : 'Run Scan'}
</Button> </Button>
@@ -392,6 +437,15 @@ export default function AUPScanner() {
<strong>{scanResult.total_hits}</strong> hits across{' '} <strong>{scanResult.total_hits}</strong> hits across{' '}
<strong>{scanResult.rows_scanned}</strong> rows |{' '} <strong>{scanResult.rows_scanned}</strong> rows |{' '}
{scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned {scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned
{scanResult.cache_status && (
<Chip
size="small"
label={scanResult.cache_status === 'hit' ? 'Cached' : 'Live'}
sx={{ ml: 1, height: 20 }}
color={scanResult.cache_status === 'hit' ? 'success' : 'default'}
variant="outlined"
/>
)}
</Alert> </Alert>
)} )}

Some files were not shown because too many files have changed in this diff Show More