diff --git a/Dockerfile.frontend b/Dockerfile.frontend index c99551c..4b80981 100644 --- a/Dockerfile.frontend +++ b/Dockerfile.frontend @@ -17,7 +17,7 @@ COPY frontend/tsconfig.json ./ # Build application RUN npm run build -# Production stage — nginx reverse-proxy + static files +# Production stage — nginx reverse-proxy + static files FROM nginx:alpine # Copy built React app diff --git a/_add_label_filter_networkmap.py b/_add_label_filter_networkmap.py new file mode 100644 index 0000000..49af706 --- /dev/null +++ b/_add_label_filter_networkmap.py @@ -0,0 +1,148 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx') +t=p.read_text(encoding='utf-8') + +# 1) Add label mode type near graph types +marker="interface GEdge { source: string; target: string; weight: number }\ninterface Graph { nodes: GNode[]; edges: GEdge[] }\n" +if marker in t and "type LabelMode" not in t: + t=t.replace(marker, marker+"\ntype LabelMode = 'all' | 'highlight' | 'none';\n") + +# 2) extend drawLabels signature +old_sig="""function drawLabels( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, + search: string, matchSet: Set, vp: Viewport, + simplify: boolean, +) { +""" +new_sig="""function drawLabels( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, + search: string, matchSet: Set, vp: Viewport, + simplify: boolean, labelMode: LabelMode, +) { +""" +if old_sig in t: + t=t.replace(old_sig,new_sig) + +# 3) label mode guards inside drawLabels +old_guard=""" const dimmed = search.length > 0; + if (simplify && !search && !hovered && !selected) { + return; + } +""" +new_guard=""" if (labelMode === 'none') return; + const dimmed = search.length > 0; + if (labelMode === 'highlight' && !search && !hovered && !selected) return; + if (simplify && labelMode !== 'all' && !search && !hovered && !selected) { + return; + } +""" +if old_guard in t: + t=t.replace(old_guard,new_guard) + +old_show=""" const isHighlight = hovered === n.id || selected === n.id || matchSet.has(n.id); + const show = isHighlight || n.meta.type === 'host' || n.count >= 2; + if (!show) continue; +""" +new_show=""" const isHighlight = hovered === n.id || selected === n.id || matchSet.has(n.id); + const show = labelMode === 'all' + ? (isHighlight || n.meta.type === 'host' || n.count >= 2) + : isHighlight; + if (!show) continue; +""" +if old_show in t: + t=t.replace(old_show,new_show) + +# 4) drawGraph signature and call site +old_graph_sig="""function drawGraph( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, search: string, + vp: Viewport, animTime: number, dpr: number, +) { +""" +new_graph_sig="""function drawGraph( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, search: string, + vp: Viewport, animTime: number, dpr: number, labelMode: LabelMode, +) { +""" +if old_graph_sig in t: + t=t.replace(old_graph_sig,new_graph_sig) + +old_drawlabels_call="drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify);" +new_drawlabels_call="drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify, labelMode);" +if old_drawlabels_call in t: + t=t.replace(old_drawlabels_call,new_drawlabels_call) + +# 5) state for label mode +state_anchor=" const [selectedNode, setSelectedNode] = useState(null);\n const [search, setSearch] = useState('');\n" +state_new=" const [selectedNode, setSelectedNode] = useState(null);\n const [search, setSearch] = useState('');\n const [labelMode, setLabelMode] = useState('highlight');\n" +if state_anchor in t: + t=t.replace(state_anchor,state_new) + +# 6) pass labelMode in draw calls +old_tick_draw="drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? null, searchRef.current, vpRef.current, ts, dpr);" +new_tick_draw="drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? null, searchRef.current, vpRef.current, ts, dpr, labelMode);" +if old_tick_draw in t: + t=t.replace(old_tick_draw,new_tick_draw) + +old_redraw_draw="if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? null, search, vpRef.current, animTimeRef.current, dpr);" +new_redraw_draw="if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? null, search, vpRef.current, animTimeRef.current, dpr, labelMode);" +if old_redraw_draw in t: + t=t.replace(old_redraw_draw,new_redraw_draw) + +# 7) include labelMode in redraw deps +old_redraw_dep="] , [graph, hovered, selectedNode, search]);" +if old_redraw_dep in t: + t=t.replace(old_redraw_dep, "] , [graph, hovered, selectedNode, search, labelMode]);") +else: + t=t.replace(" }, [graph, hovered, selectedNode, search]);"," }, [graph, hovered, selectedNode, search, labelMode]);") + +# 8) Add toolbar selector after search field +search_block=""" setSearch(e.target.value)} + sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }} + slotProps={{ + input: { + startAdornment: , + }, + }} + /> +""" +label_block=""" setSearch(e.target.value)} + sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }} + slotProps={{ + input: { + startAdornment: , + }, + }} + /> + + + Labels + + +""" +if search_block in t: + t=t.replace(search_block,label_block) + +p.write_text(t,encoding='utf-8') +print('added network map label filter control and renderer modes') diff --git a/_add_scanner_budget_config.py b/_add_scanner_budget_config.py new file mode 100644 index 0000000..9c8c99f --- /dev/null +++ b/_add_scanner_budget_config.py @@ -0,0 +1,18 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py') +t=p.read_text(encoding='utf-8') +old=''' # -- Scanner settings ----------------------------------------------- + SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch") +''' +new=''' # -- Scanner settings ----------------------------------------------- + SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch") + SCANNER_MAX_ROWS_PER_SCAN: int = Field( + default=300000, + description="Global row budget for a single AUP scan request (0 = unlimited)", + ) +''' +if old not in t: + raise SystemExit('scanner settings block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('added SCANNER_MAX_ROWS_PER_SCAN config') diff --git a/_apply_frontend_scale_patch.py b/_apply_frontend_scale_patch.py new file mode 100644 index 0000000..a2374be --- /dev/null +++ b/_apply_frontend_scale_patch.py @@ -0,0 +1,46 @@ +from pathlib import Path + +root = Path(r"d:\Projects\Dev\ThreatHunt") + +# -------- client.ts -------- +client = root / "frontend/src/api/client.ts" +text = client.read_text(encoding="utf-8") + +if "export interface NetworkSummary" not in text: + insert_after = "export interface InventoryStatus {\n hunt_id: string;\n status: 'ready' | 'building' | 'none';\n}\n" + addition = insert_after + "\nexport interface NetworkSummaryHost {\n id: string;\n hostname: string;\n row_count: number;\n ip_count: number;\n user_count: number;\n}\n\nexport interface NetworkSummary {\n stats: InventoryStats;\n top_hosts: NetworkSummaryHost[];\n top_edges: InventoryConnection[];\n status?: 'building' | 'deferred';\n message?: string;\n}\n" + text = text.replace(insert_after, addition) + +net_old = """export const network = {\n hostInventory: (huntId: string, force = false) =>\n api(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`),\n inventoryStatus: (huntId: string) =>\n api(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`),\n rebuildInventory: (huntId: string) =>\n api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }),\n};""" +net_new = """export const network = {\n hostInventory: (huntId: string, force = false) =>\n api(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`),\n summary: (huntId: string, topN = 20) =>\n api(`/api/network/summary?hunt_id=${encodeURIComponent(huntId)}&top_n=${topN}`),\n subgraph: (huntId: string, maxHosts = 250, maxEdges = 1500, nodeId?: string) => {\n let qs = `/api/network/subgraph?hunt_id=${encodeURIComponent(huntId)}&max_hosts=${maxHosts}&max_edges=${maxEdges}`;\n if (nodeId) qs += `&node_id=${encodeURIComponent(nodeId)}`;\n return api(qs);\n },\n inventoryStatus: (huntId: string) =>\n api(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`),\n rebuildInventory: (huntId: string) =>\n api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }),\n};""" +if net_old in text: + text = text.replace(net_old, net_new) + +client.write_text(text, encoding="utf-8") + +# -------- NetworkMap.tsx -------- +nm = root / "frontend/src/components/NetworkMap.tsx" +text = nm.read_text(encoding="utf-8") + +# add constants +if "LARGE_HUNT_HOST_THRESHOLD" not in text: + text = text.replace("let lastSelectedHuntId = '';\n", "let lastSelectedHuntId = '';\nconst LARGE_HUNT_HOST_THRESHOLD = 400;\nconst LARGE_HUNT_SUBGRAPH_HOSTS = 350;\nconst LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n") + +# inject helper in component after sleep +marker = " const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms));\n" +if "loadScaleAwareGraph" not in text: + helper = marker + "\n const loadScaleAwareGraph = useCallback(async (huntId: string, forceRefresh = false) => {\n setLoading(true); setError(''); setGraph(null); setStats(null);\n setSelectedNode(null); setPopoverAnchor(null);\n\n const waitReadyThen = async (fn: () => Promise): Promise => {\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const out: any = await fn();\n if (out && !out.status) return out as T;\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') {\n const out2: any = await fn();\n if (out2 && !out2.status) return out2 as T;\n }\n if (Date.now() - startedAt > 5 * 60 * 1000) throw new Error('Network data build timed out after 5 minutes');\n const jitter = Math.floor(Math.random() * 250);\n await sleep(delayMs + jitter);\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n };\n\n try {\n setProgress('Loading network summary');\n const summary: any = await waitReadyThen(() => network.summary(huntId, 20));\n const totalHosts = summary?.stats?.total_hosts || 0;\n\n if (totalHosts > LARGE_HUNT_HOST_THRESHOLD) {\n setProgress(`Large hunt detected (${totalHosts} hosts). Loading focused subgraph`);\n const sub: any = await waitReadyThen(() => network.subgraph(huntId, LARGE_HUNT_SUBGRAPH_HOSTS, LARGE_HUNT_SUBGRAPH_EDGES));\n if (!sub?.hosts || sub.hosts.length === 0) {\n setError('No hosts found for subgraph.');\n return;\n }\n const { w, h } = canvasSizeRef.current;\n const g = buildGraphFromInventory(sub.hosts, sub.connections || [], w, h);\n simulate(g, w / 2, h / 2, 60);\n simAlphaRef.current = 0.3;\n setStats(summary.stats);\n graphCache.set(huntId, { graph: g, stats: summary.stats, ts: Date.now() });\n setGraph(g);\n return;\n }\n\n // Small/medium hunts: load full inventory\n setProgress('Loading host inventory');\n const inv: any = await waitReadyThen(() => network.hostInventory(huntId, forceRefresh));\n if (!inv?.hosts || inv.hosts.length === 0) {\n setError('No hosts found. Upload CSV files with host-identifying columns (ClientId, Fqdn, Hostname) to this hunt.');\n return;\n }\n const { w, h } = canvasSizeRef.current;\n const g = buildGraphFromInventory(inv.hosts, inv.connections || [], w, h);\n simulate(g, w / 2, h / 2, 60);\n simAlphaRef.current = 0.3;\n setStats(summary.stats || inv.stats);\n graphCache.set(huntId, { graph: g, stats: summary.stats || inv.stats, ts: Date.now() });\n setGraph(g);\n } catch (e: any) {\n console.error('[NetworkMap] scale-aware load error:', e);\n setError(e.message || 'Failed to load network data');\n } finally {\n setLoading(false);\n setProgress('');\n }\n }, []);\n" + text = text.replace(marker, helper) + +# simplify existing loadGraph function body to delegate +pattern_start = text.find(" // Load host inventory for selected hunt (with cache).") +if pattern_start != -1: + # replace the whole loadGraph useCallback block by simple delegator + import re + block_re = re.compile(r" // Load host inventory for selected hunt \(with cache\)\.[\s\S]*?\n \}, \[\]\); // Stable - reads canvasSizeRef, no state deps\n", re.M) + repl = " // Load graph data for selected hunt (delegates to scale-aware loader).\n const loadGraph = useCallback(async (huntId: string, forceRefresh = false) => {\n if (!huntId) return;\n\n // Check module-level cache first (5 min TTL)\n if (!forceRefresh) {\n const cached = graphCache.get(huntId);\n if (cached && Date.now() - cached.ts < 5 * 60 * 1000) {\n setGraph(cached.graph);\n setStats(cached.stats);\n setError('');\n simAlphaRef.current = 0;\n return;\n }\n }\n\n await loadScaleAwareGraph(huntId, forceRefresh);\n // eslint-disable-next-line react-hooks/exhaustive-deps\n }, []); // Stable - reads canvasSizeRef, no state deps\n" + text = block_re.sub(repl, text, count=1) + +nm.write_text(text, encoding="utf-8") + +print("Patched frontend client + NetworkMap for scale-aware loading") \ No newline at end of file diff --git a/_apply_phase1_patch.py b/_apply_phase1_patch.py new file mode 100644 index 0000000..a446c58 --- /dev/null +++ b/_apply_phase1_patch.py @@ -0,0 +1,206 @@ +from pathlib import Path + +root = Path(r"d:\Projects\Dev\ThreatHunt") + +# 1) config.py additions +cfg = root / "backend/app/config.py" +text = cfg.read_text(encoding="utf-8") +needle = " # -- Scanner settings -----------------------------------------------\n SCANNER_BATCH_SIZE: int = Field(default=500, description=\"Rows per scanner batch\")\n" +insert = " # -- Scanner settings -----------------------------------------------\n SCANNER_BATCH_SIZE: int = Field(default=500, description=\"Rows per scanner batch\")\n\n # -- Job queue settings ----------------------------------------------\n JOB_QUEUE_MAX_BACKLOG: int = Field(\n default=2000, description=\"Soft cap for queued background jobs\"\n )\n JOB_QUEUE_RETAIN_COMPLETED: int = Field(\n default=3000, description=\"Maximum completed/failed jobs to retain in memory\"\n )\n JOB_QUEUE_CLEANUP_INTERVAL_SECONDS: int = Field(\n default=60, description=\"How often to run in-memory job cleanup\"\n )\n JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(\n default=3600, description=\"Age threshold for in-memory completed job cleanup\"\n )\n" +if needle in text: + text = text.replace(needle, insert) +cfg.write_text(text, encoding="utf-8") + +# 2) scanner.py default scope = dataset-only +scanner = root / "backend/app/services/scanner.py" +text = scanner.read_text(encoding="utf-8") +text = text.replace(" scan_hunts: bool = True,", " scan_hunts: bool = False,") +text = text.replace(" scan_annotations: bool = True,", " scan_annotations: bool = False,") +text = text.replace(" scan_messages: bool = True,", " scan_messages: bool = False,") +scanner.write_text(text, encoding="utf-8") + +# 3) keywords.py defaults = dataset-only +kw = root / "backend/app/api/routes/keywords.py" +text = kw.read_text(encoding="utf-8") +text = text.replace(" scan_hunts: bool = True", " scan_hunts: bool = False") +text = text.replace(" scan_annotations: bool = True", " scan_annotations: bool = False") +text = text.replace(" scan_messages: bool = True", " scan_messages: bool = False") +kw.write_text(text, encoding="utf-8") + +# 4) job_queue.py dedupe + periodic cleanup +jq = root / "backend/app/services/job_queue.py" +text = jq.read_text(encoding="utf-8") + +text = text.replace( +"from typing import Any, Callable, Coroutine, Optional\n", +"from typing import Any, Callable, Coroutine, Optional\n\nfrom app.config import settings\n" +) + +text = text.replace( +" self._completion_callbacks: list[Callable[[Job], Coroutine]] = []\n", +" self._completion_callbacks: list[Callable[[Job], Coroutine]] = []\n self._cleanup_task: asyncio.Task | None = None\n" +) + +start_old = ''' async def start(self): + if self._started: + return + self._started = True + for i in range(self._max_workers): + task = asyncio.create_task(self._worker(i)) + self._workers.append(task) + logger.info(f"Job queue started with {self._max_workers} workers") +''' +start_new = ''' async def start(self): + if self._started: + return + self._started = True + for i in range(self._max_workers): + task = asyncio.create_task(self._worker(i)) + self._workers.append(task) + if not self._cleanup_task or self._cleanup_task.done(): + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info(f"Job queue started with {self._max_workers} workers") +''' +text = text.replace(start_old, start_new) + +stop_old = ''' async def stop(self): + self._started = False + for w in self._workers: + w.cancel() + await asyncio.gather(*self._workers, return_exceptions=True) + self._workers.clear() + logger.info("Job queue stopped") +''' +stop_new = ''' async def stop(self): + self._started = False + for w in self._workers: + w.cancel() + await asyncio.gather(*self._workers, return_exceptions=True) + self._workers.clear() + if self._cleanup_task: + self._cleanup_task.cancel() + await asyncio.gather(self._cleanup_task, return_exceptions=True) + self._cleanup_task = None + logger.info("Job queue stopped") +''' +text = text.replace(stop_old, stop_new) + +submit_old = ''' def submit(self, job_type: JobType, **params) -> Job: + job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params) + self._jobs[job.id] = job + self._queue.put_nowait(job.id) + logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}") + return job +''' +submit_new = ''' def submit(self, job_type: JobType, **params) -> Job: + # Soft backpressure: prefer dedupe over queue amplification + dedupe_job = self._find_active_duplicate(job_type, params) + if dedupe_job is not None: + logger.info( + f"Job deduped: reusing {dedupe_job.id} ({job_type.value}) params={params}" + ) + return dedupe_job + + if self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG: + logger.warning( + "Job queue backlog high (%d >= %d). Accepting job but system may be degraded.", + self._queue.qsize(), settings.JOB_QUEUE_MAX_BACKLOG, + ) + + job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params) + self._jobs[job.id] = job + self._queue.put_nowait(job.id) + logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}") + return job +''' +text = text.replace(submit_old, submit_new) + +insert_methods_after = " def get_job(self, job_id: str) -> Job | None:\n return self._jobs.get(job_id)\n" +new_methods = ''' def get_job(self, job_id: str) -> Job | None: + return self._jobs.get(job_id) + + def _find_active_duplicate(self, job_type: JobType, params: dict) -> Job | None: + """Return queued/running job with same key workload to prevent duplicate storms.""" + key_fields = ["dataset_id", "hunt_id", "hostname", "question", "mode"] + sig = tuple((k, params.get(k)) for k in key_fields if params.get(k) is not None) + if not sig: + return None + for j in self._jobs.values(): + if j.job_type != job_type: + continue + if j.status not in (JobStatus.QUEUED, JobStatus.RUNNING): + continue + other_sig = tuple((k, j.params.get(k)) for k in key_fields if j.params.get(k) is not None) + if sig == other_sig: + return j + return None +''' +text = text.replace(insert_methods_after, new_methods) + +cleanup_old = ''' def cleanup(self, max_age_seconds: float = 3600): + now = time.time() + to_remove = [ + jid for jid, j in self._jobs.items() + if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED) + and (now - j.created_at) > max_age_seconds + ] + for jid in to_remove: + del self._jobs[jid] + if to_remove: + logger.info(f"Cleaned up {len(to_remove)} old jobs") +''' +cleanup_new = ''' def cleanup(self, max_age_seconds: float = 3600): + now = time.time() + terminal_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED) + to_remove = [ + jid for jid, j in self._jobs.items() + if j.status in terminal_states and (now - j.created_at) > max_age_seconds + ] + + # Also cap retained terminal jobs to avoid unbounded memory growth + terminal_jobs = sorted( + [j for j in self._jobs.values() if j.status in terminal_states], + key=lambda j: j.created_at, + reverse=True, + ) + overflow = terminal_jobs[settings.JOB_QUEUE_RETAIN_COMPLETED :] + to_remove.extend([j.id for j in overflow]) + + removed = 0 + for jid in set(to_remove): + if jid in self._jobs: + del self._jobs[jid] + removed += 1 + if removed: + logger.info(f"Cleaned up {removed} old jobs") + + async def _cleanup_loop(self): + interval = max(10, settings.JOB_QUEUE_CLEANUP_INTERVAL_SECONDS) + while self._started: + try: + self.cleanup(max_age_seconds=settings.JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS) + except Exception as e: + logger.warning(f"Job queue cleanup loop error: {e}") + await asyncio.sleep(interval) +''' +text = text.replace(cleanup_old, cleanup_new) + +jq.write_text(text, encoding="utf-8") + +# 5) NetworkMap polling backoff/jitter max wait +nm = root / "frontend/src/components/NetworkMap.tsx" +text = nm.read_text(encoding="utf-8") + +text = text.replace( +" // Poll until ready, then re-fetch\n for (;;) {\n await new Promise(r => setTimeout(r, 2000));\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') break;\n }\n", +" // Poll until ready (exponential backoff), then re-fetch\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const jitter = Math.floor(Math.random() * 250);\n await new Promise(r => setTimeout(r, delayMs + jitter));\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') break;\n if (Date.now() - startedAt > 5 * 60 * 1000) {\n throw new Error('Host inventory build timed out after 5 minutes');\n }\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n" +) + +text = text.replace( +" const waitUntilReady = async (): Promise => {\n // Poll inventory-status every 2s until 'ready' (or cancelled)\n setProgress('Host inventory is being prepared in the background');\n setLoading(true);\n for (;;) {\n await new Promise(r => setTimeout(r, 2000));\n if (cancelled) return false;\n try {\n const st = await network.inventoryStatus(selectedHuntId);\n if (cancelled) return false;\n if (st.status === 'ready') return true;\n // still building or none (job may not have started yet) - keep polling\n } catch { if (cancelled) return false; }\n }\n };\n", +" const waitUntilReady = async (): Promise => {\n // Poll inventory-status with exponential backoff until 'ready' (or cancelled)\n setProgress('Host inventory is being prepared in the background');\n setLoading(true);\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const jitter = Math.floor(Math.random() * 250);\n await new Promise(r => setTimeout(r, delayMs + jitter));\n if (cancelled) return false;\n try {\n const st = await network.inventoryStatus(selectedHuntId);\n if (cancelled) return false;\n if (st.status === 'ready') return true;\n if (Date.now() - startedAt > 5 * 60 * 1000) {\n setError('Host inventory build timed out. Please retry.');\n return false;\n }\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n // still building or none (job may not have started yet) - keep polling\n } catch {\n if (cancelled) return false;\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n }\n };\n" +) + +nm.write_text(text, encoding="utf-8") + +print("Patched: config.py, scanner.py, keywords.py, job_queue.py, NetworkMap.tsx") \ No newline at end of file diff --git a/_apply_phase2_patch.py b/_apply_phase2_patch.py new file mode 100644 index 0000000..52cae81 --- /dev/null +++ b/_apply_phase2_patch.py @@ -0,0 +1,207 @@ +from pathlib import Path +import re + +root = Path(r"d:\Projects\Dev\ThreatHunt") + +# ---------- config.py ---------- +cfg = root / "backend/app/config.py" +text = cfg.read_text(encoding="utf-8") +marker = " JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(\n default=3600, description=\"Age threshold for in-memory completed job cleanup\"\n )\n" +add = marker + "\n # -- Startup throttling ------------------------------------------------\n STARTUP_WARMUP_MAX_HUNTS: int = Field(\n default=5, description=\"Max hunts to warm inventory cache for at startup\"\n )\n STARTUP_REPROCESS_MAX_DATASETS: int = Field(\n default=25, description=\"Max unprocessed datasets to enqueue at startup\"\n )\n\n # -- Network API scale guards -----------------------------------------\n NETWORK_SUBGRAPH_MAX_HOSTS: int = Field(\n default=400, description=\"Hard cap for hosts returned by network subgraph endpoint\"\n )\n NETWORK_SUBGRAPH_MAX_EDGES: int = Field(\n default=3000, description=\"Hard cap for edges returned by network subgraph endpoint\"\n )\n" +if marker in text and "STARTUP_WARMUP_MAX_HUNTS" not in text: + text = text.replace(marker, add) +cfg.write_text(text, encoding="utf-8") + +# ---------- job_queue.py ---------- +jq = root / "backend/app/services/job_queue.py" +text = jq.read_text(encoding="utf-8") + +# add helper methods after get_stats +anchor = " def get_stats(self) -> dict:\n by_status = {}\n for j in self._jobs.values():\n by_status[j.status.value] = by_status.get(j.status.value, 0) + 1\n return {\n \"total\": len(self._jobs),\n \"queued\": self._queue.qsize(),\n \"by_status\": by_status,\n \"workers\": self._max_workers,\n \"active_workers\": sum(1 for j in self._jobs.values() if j.status == JobStatus.RUNNING),\n }\n" +if "def is_backlogged(" not in text: + insert = anchor + "\n def is_backlogged(self) -> bool:\n return self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG\n\n def can_accept(self, reserve: int = 0) -> bool:\n return (self._queue.qsize() + max(0, reserve)) < settings.JOB_QUEUE_MAX_BACKLOG\n" + text = text.replace(anchor, insert) + +jq.write_text(text, encoding="utf-8") + +# ---------- host_inventory.py keyset pagination ---------- +hi = root / "backend/app/services/host_inventory.py" +text = hi.read_text(encoding="utf-8") +old = ''' batch_size = 5000 + offset = 0 + while True: + rr = await db.execute( + select(DatasetRow) + .where(DatasetRow.dataset_id == ds.id) + .order_by(DatasetRow.row_index) + .offset(offset).limit(batch_size) + ) + rows = rr.scalars().all() + if not rows: + break +''' +new = ''' batch_size = 5000 + last_row_index = -1 + while True: + rr = await db.execute( + select(DatasetRow) + .where(DatasetRow.dataset_id == ds.id) + .where(DatasetRow.row_index > last_row_index) + .order_by(DatasetRow.row_index) + .limit(batch_size) + ) + rows = rr.scalars().all() + if not rows: + break +''' +if old in text: + text = text.replace(old, new) +text = text.replace(" offset += batch_size\n if len(rows) < batch_size:\n break\n", " last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n") +hi.write_text(text, encoding="utf-8") + +# ---------- network.py add summary/subgraph + backpressure ---------- +net = root / "backend/app/api/routes/network.py" +text = net.read_text(encoding="utf-8") +text = text.replace("from fastapi import APIRouter, Depends, HTTPException, Query", "from fastapi import APIRouter, Depends, HTTPException, Query") +if "from app.config import settings" not in text: + text = text.replace("from app.db import get_db\n", "from app.config import settings\nfrom app.db import get_db\n") + +# add helpers and endpoints before inventory-status endpoint +if "def _build_summary" not in text: + helper_block = ''' + +def _build_summary(inv: dict, top_n: int = 20) -> dict: + hosts = inv.get("hosts", []) + conns = inv.get("connections", []) + top_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:top_n] + top_edges = sorted(conns, key=lambda c: c.get("count", 0), reverse=True)[:top_n] + return { + "stats": inv.get("stats", {}), + "top_hosts": [ + { + "id": h.get("id"), + "hostname": h.get("hostname"), + "row_count": h.get("row_count", 0), + "ip_count": len(h.get("ips", [])), + "user_count": len(h.get("users", [])), + } + for h in top_hosts + ], + "top_edges": top_edges, + } + + +def _build_subgraph(inv: dict, node_id: str | None, max_hosts: int, max_edges: int) -> dict: + hosts = inv.get("hosts", []) + conns = inv.get("connections", []) + + max_hosts = max(1, min(max_hosts, settings.NETWORK_SUBGRAPH_MAX_HOSTS)) + max_edges = max(1, min(max_edges, settings.NETWORK_SUBGRAPH_MAX_EDGES)) + + if node_id: + rel_edges = [c for c in conns if c.get("source") == node_id or c.get("target") == node_id] + rel_edges = sorted(rel_edges, key=lambda c: c.get("count", 0), reverse=True)[:max_edges] + ids = {node_id} + for c in rel_edges: + ids.add(c.get("source")) + ids.add(c.get("target")) + rel_hosts = [h for h in hosts if h.get("id") in ids][:max_hosts] + else: + rel_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:max_hosts] + allowed = {h.get("id") for h in rel_hosts} + rel_edges = [ + c for c in sorted(conns, key=lambda c: c.get("count", 0), reverse=True) + if c.get("source") in allowed and c.get("target") in allowed + ][:max_edges] + + return { + "hosts": rel_hosts, + "connections": rel_edges, + "stats": { + **inv.get("stats", {}), + "subgraph_hosts": len(rel_hosts), + "subgraph_connections": len(rel_edges), + "truncated": len(rel_hosts) < len(hosts) or len(rel_edges) < len(conns), + }, + } + + +@router.get("/summary") +async def get_inventory_summary( + hunt_id: str = Query(..., description="Hunt ID"), + top_n: int = Query(20, ge=1, le=200), +): + """Return a lightweight summary view for large hunts.""" + cached = inventory_cache.get(hunt_id) + if cached is None: + if not inventory_cache.is_building(hunt_id): + if job_queue.is_backlogged(): + return JSONResponse( + status_code=202, + content={"status": "deferred", "message": "Queue busy, retry shortly"}, + ) + job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id) + return JSONResponse(status_code=202, content={"status": "building"}) + return _build_summary(cached, top_n=top_n) + + +@router.get("/subgraph") +async def get_inventory_subgraph( + hunt_id: str = Query(..., description="Hunt ID"), + node_id: str | None = Query(None, description="Optional focal node"), + max_hosts: int = Query(200, ge=1, le=5000), + max_edges: int = Query(1500, ge=1, le=20000), +): + """Return a bounded subgraph for scale-safe rendering.""" + cached = inventory_cache.get(hunt_id) + if cached is None: + if not inventory_cache.is_building(hunt_id): + if job_queue.is_backlogged(): + return JSONResponse( + status_code=202, + content={"status": "deferred", "message": "Queue busy, retry shortly"}, + ) + job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id) + return JSONResponse(status_code=202, content={"status": "building"}) + return _build_subgraph(cached, node_id=node_id, max_hosts=max_hosts, max_edges=max_edges) +''' + text = text.replace("\n\n@router.get(\"/inventory-status\")", helper_block + "\n\n@router.get(\"/inventory-status\")") + +# add backpressure in host-inventory enqueue points +text = text.replace( +" if not inventory_cache.is_building(hunt_id):\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)", +" if not inventory_cache.is_building(hunt_id):\n if job_queue.is_backlogged():\n return JSONResponse(status_code=202, content={\"status\": \"deferred\", \"message\": \"Queue busy, retry shortly\"})\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)" +) +text = text.replace( +" if not inventory_cache.is_building(hunt_id):\n logger.info(f\"Cache miss for {hunt_id}, triggering background build\")\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)", +" if not inventory_cache.is_building(hunt_id):\n logger.info(f\"Cache miss for {hunt_id}, triggering background build\")\n if job_queue.is_backlogged():\n return JSONResponse(status_code=202, content={\"status\": \"deferred\", \"message\": \"Queue busy, retry shortly\"})\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)" +) + +net.write_text(text, encoding="utf-8") + +# ---------- analysis.py backpressure on manual submit ---------- +analysis = root / "backend/app/api/routes/analysis.py" +text = analysis.read_text(encoding="utf-8") +text = text.replace( +" job = job_queue.submit(jt, **params)\n return {\"job_id\": job.id, \"status\": job.status.value, \"job_type\": job_type}", +" if not job_queue.can_accept():\n raise HTTPException(status_code=429, detail=\"Job queue is busy. Retry shortly.\")\n job = job_queue.submit(jt, **params)\n return {\"job_id\": job.id, \"status\": job.status.value, \"job_type\": job_type}" +) +analysis.write_text(text, encoding="utf-8") + +# ---------- main.py startup throttles ---------- +main = root / "backend/app/main.py" +text = main.read_text(encoding="utf-8") + +text = text.replace( +" for hid in hunt_ids:\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)\n if hunt_ids:\n logger.info(f\"Queued host inventory warm-up for {len(hunt_ids)} hunts\")", +" warm_hunts = hunt_ids[: settings.STARTUP_WARMUP_MAX_HUNTS]\n for hid in warm_hunts:\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)\n if warm_hunts:\n logger.info(f\"Queued host inventory warm-up for {len(warm_hunts)} hunts (total hunts with data: {len(hunt_ids)})\")" +) + +text = text.replace( +" if unprocessed_ids:\n for ds_id in unprocessed_ids:\n job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)\n job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)\n job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)\n job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)\n logger.info(f\"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets\")\n async with async_session_factory() as update_db:\n from sqlalchemy import update\n from app.db.models import Dataset\n await update_db.execute(\n update(Dataset)\n .where(Dataset.id.in_(unprocessed_ids))\n .values(processing_status=\"processing\")\n )\n await update_db.commit()", +" if unprocessed_ids:\n to_reprocess = unprocessed_ids[: settings.STARTUP_REPROCESS_MAX_DATASETS]\n for ds_id in to_reprocess:\n job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)\n job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)\n job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)\n job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)\n logger.info(f\"Queued processing pipeline for {len(to_reprocess)} datasets at startup (unprocessed total: {len(unprocessed_ids)})\")\n async with async_session_factory() as update_db:\n from sqlalchemy import update\n from app.db.models import Dataset\n await update_db.execute(\n update(Dataset)\n .where(Dataset.id.in_(to_reprocess))\n .values(processing_status=\"processing\")\n )\n await update_db.commit()" +) + +main.write_text(text, encoding="utf-8") + +print("Patched Phase 2 files") \ No newline at end of file diff --git a/_aup_add_dataset_scope_ui.py b/_aup_add_dataset_scope_ui.py new file mode 100644 index 0000000..cf413bf --- /dev/null +++ b/_aup_add_dataset_scope_ui.py @@ -0,0 +1,75 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx') +t=p.read_text(encoding='utf-8') + +# default selection when hunt changes: first 3 datasets instead of all +old=''' datasets.list(0, 500, selectedHuntId).then(res => { + if (cancelled) return; + setDsList(res.datasets); + setSelectedDs(new Set(res.datasets.map(d => d.id))); + }).catch(() => {}); +''' +new=''' datasets.list(0, 500, selectedHuntId).then(res => { + if (cancelled) return; + setDsList(res.datasets); + setSelectedDs(new Set(res.datasets.slice(0, 3).map(d => d.id))); + }).catch(() => {}); +''' +if old not in t: + raise SystemExit('hunt-change dataset init block not found') +t=t.replace(old,new) + +# insert dataset scope multi-select under hunt info +anchor=''' {!selectedHuntId && ( + + All datasets will be scanned if no hunt is selected + + )} + + + {/* Theme selector */} +''' +insert=''' {!selectedHuntId && ( + + Select a hunt to enable scoped scanning + + )} + + + Datasets + + + + {selectedHuntId && dsList.length > 0 && ( + + + + + + )} + + + {/* Theme selector */} +''' +if anchor not in t: + raise SystemExit('dataset scope anchor not found') +t=t.replace(anchor,insert) + +p.write_text(t,encoding='utf-8') +print('added AUP dataset multi-select scoping and safer defaults') diff --git a/_aup_add_host_user_to_hits.py b/_aup_add_host_user_to_hits.py new file mode 100644 index 0000000..ab3e45c --- /dev/null +++ b/_aup_add_host_user_to_hits.py @@ -0,0 +1,182 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py') +t=p.read_text(encoding='utf-8') + +# 1) Extend ScanHit dataclass +old='''@dataclass +class ScanHit: + theme_name: str + theme_color: str + keyword: str + source_type: str # dataset_row | hunt | annotation | message + source_id: str | int + field: str + matched_value: str + row_index: int | None = None + dataset_name: str | None = None +''' +new='''@dataclass +class ScanHit: + theme_name: str + theme_color: str + keyword: str + source_type: str # dataset_row | hunt | annotation | message + source_id: str | int + field: str + matched_value: str + row_index: int | None = None + dataset_name: str | None = None + hostname: str | None = None + username: str | None = None +''' +if old not in t: + raise SystemExit('ScanHit dataclass block not found') +t=t.replace(old,new) + +# 2) Add helper to infer hostname/user from a row +insert_after='''BATCH_SIZE = 200 + + +@dataclass +class ScanHit: +''' +helper='''BATCH_SIZE = 200 + + +def _infer_hostname_and_user(data: dict) -> tuple[str | None, str | None]: + """Best-effort extraction of hostname and user from a dataset row.""" + if not data: + return None, None + + host_keys = ( + 'hostname', 'host_name', 'host', 'computer_name', 'computer', + 'fqdn', 'client_id', 'agent_id', 'endpoint_id', + ) + user_keys = ( + 'username', 'user_name', 'user', 'account_name', + 'logged_in_user', 'samaccountname', 'sam_account_name', + ) + + def pick(keys): + for k in keys: + for actual_key, v in data.items(): + if actual_key.lower() == k and v not in (None, ''): + return str(v) + return None + + return pick(host_keys), pick(user_keys) + + +@dataclass +class ScanHit: +''' +if insert_after in t and '_infer_hostname_and_user' not in t: + t=t.replace(insert_after,helper) + +# 3) Extend _match_text signature and ScanHit construction +old_sig=''' def _match_text( + self, + text: str, + patterns: dict, + source_type: str, + source_id: str | int, + field_name: str, + hits: list[ScanHit], + row_index: int | None = None, + dataset_name: str | None = None, + ) -> None: +''' +new_sig=''' def _match_text( + self, + text: str, + patterns: dict, + source_type: str, + source_id: str | int, + field_name: str, + hits: list[ScanHit], + row_index: int | None = None, + dataset_name: str | None = None, + hostname: str | None = None, + username: str | None = None, + ) -> None: +''' +if old_sig not in t: + raise SystemExit('_match_text signature not found') +t=t.replace(old_sig,new_sig) + +old_hit=''' hits.append(ScanHit( + theme_name=theme_name, + theme_color=theme_color, + keyword=kw_value, + source_type=source_type, + source_id=source_id, + field=field_name, + matched_value=matched_preview, + row_index=row_index, + dataset_name=dataset_name, + )) +''' +new_hit=''' hits.append(ScanHit( + theme_name=theme_name, + theme_color=theme_color, + keyword=kw_value, + source_type=source_type, + source_id=source_id, + field=field_name, + matched_value=matched_preview, + row_index=row_index, + dataset_name=dataset_name, + hostname=hostname, + username=username, + )) +''' +if old_hit not in t: + raise SystemExit('ScanHit append block not found') +t=t.replace(old_hit,new_hit) + +# 4) Pass inferred hostname/username in dataset scan path +old_call=''' for row in rows: + result.rows_scanned += 1 + data = row.data or {} + for col_name, cell_value in data.items(): + if cell_value is None: + continue + text = str(cell_value) + self._match_text( + text, + patterns, + "dataset_row", + row.id, + col_name, + result.hits, + row_index=row.row_index, + dataset_name=ds_name, + ) +''' +new_call=''' for row in rows: + result.rows_scanned += 1 + data = row.data or {} + hostname, username = _infer_hostname_and_user(data) + for col_name, cell_value in data.items(): + if cell_value is None: + continue + text = str(cell_value) + self._match_text( + text, + patterns, + "dataset_row", + row.id, + col_name, + result.hits, + row_index=row.row_index, + dataset_name=ds_name, + hostname=hostname, + username=username, + ) +''' +if old_call not in t: + raise SystemExit('dataset _match_text call block not found') +t=t.replace(old_call,new_call) + +p.write_text(t,encoding='utf-8') +print('updated scanner hits with hostname+username context') diff --git a/_aup_extend_scanhit_api.py b/_aup_extend_scanhit_api.py new file mode 100644 index 0000000..e7ecac6 --- /dev/null +++ b/_aup_extend_scanhit_api.py @@ -0,0 +1,32 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py') +t=p.read_text(encoding='utf-8') +old='''class ScanHit(BaseModel): + theme_name: str + theme_color: str + keyword: str + source_type: str + source_id: str | int + field: str + matched_value: str + row_index: int | None = None + dataset_name: str | None = None +''' +new='''class ScanHit(BaseModel): + theme_name: str + theme_color: str + keyword: str + source_type: str + source_id: str | int + field: str + matched_value: str + row_index: int | None = None + dataset_name: str | None = None + hostname: str | None = None + username: str | None = None +''' +if old not in t: + raise SystemExit('ScanHit pydantic model block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('extended API ScanHit model with hostname+username') diff --git a/_aup_extend_scanhit_frontend_type.py b/_aup_extend_scanhit_frontend_type.py new file mode 100644 index 0000000..a0f85e7 --- /dev/null +++ b/_aup_extend_scanhit_frontend_type.py @@ -0,0 +1,21 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/api/client.ts') +t=p.read_text(encoding='utf-8') +old='''export interface ScanHit { + theme_name: string; theme_color: string; keyword: string; + source_type: string; source_id: string | number; field: string; + matched_value: string; row_index: number | null; dataset_name: string | null; +} +''' +new='''export interface ScanHit { + theme_name: string; theme_color: string; keyword: string; + source_type: string; source_id: string | number; field: string; + matched_value: string; row_index: number | null; dataset_name: string | null; + hostname?: string | null; username?: string | null; +} +''' +if old not in t: + raise SystemExit('frontend ScanHit interface block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('extended frontend ScanHit type with hostname+username') diff --git a/_aup_keywords_scope_and_missing.py b/_aup_keywords_scope_and_missing.py new file mode 100644 index 0000000..f03410e --- /dev/null +++ b/_aup_keywords_scope_and_missing.py @@ -0,0 +1,57 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py') +t=p.read_text(encoding='utf-8') + +# add fast guard against unscoped global dataset scans +insert_after='''async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):\n scanner = KeywordScanner(db)\n\n''' +if insert_after not in t: + raise SystemExit('run_scan header block not found') +if 'Select at least one dataset' not in t: + guard=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:\n raise HTTPException(400, "Select at least one dataset or enable additional sources (hunts/annotations/messages)")\n\n''' + t=t.replace(insert_after, insert_after+guard) + +old=''' if missing: + missing_entries: list[dict] = [] + for dataset_id in missing: + partial = await scanner.scan(dataset_ids=[dataset_id], theme_ids=body.theme_ids) + keyword_scan_cache.put(dataset_id, partial) + missing_entries.append({"result": partial, "built_at": None}) + + merged = _merge_cached_results( + cached_entries + missing_entries, + allowed_theme_names if body.theme_ids else None, + ) + return { + "total_hits": merged["total_hits"], + "hits": merged["hits"], + "themes_scanned": len(themes), + "keywords_scanned": keywords_scanned, + "rows_scanned": merged["rows_scanned"], + "cache_used": len(cached_entries) > 0, + "cache_status": "partial" if cached_entries else "miss", + "cached_at": merged["cached_at"], + } +''' +new=''' if missing: + partial = await scanner.scan(dataset_ids=missing, theme_ids=body.theme_ids) + merged = _merge_cached_results( + cached_entries + [{"result": partial, "built_at": None}], + allowed_theme_names if body.theme_ids else None, + ) + return { + "total_hits": merged["total_hits"], + "hits": merged["hits"], + "themes_scanned": len(themes), + "keywords_scanned": keywords_scanned, + "rows_scanned": merged["rows_scanned"], + "cache_used": len(cached_entries) > 0, + "cache_status": "partial" if cached_entries else "miss", + "cached_at": merged["cached_at"], + } +''' +if old not in t: + raise SystemExit('partial-cache missing block not found') +t=t.replace(old,new) + +p.write_text(t,encoding='utf-8') +print('hardened keywords scan scope + optimized missing-cache path') diff --git a/_aup_reduce_budget.py b/_aup_reduce_budget.py new file mode 100644 index 0000000..22be366 --- /dev/null +++ b/_aup_reduce_budget.py @@ -0,0 +1,18 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py') +t=p.read_text(encoding='utf-8') +old=''' SCANNER_MAX_ROWS_PER_SCAN: int = Field( + default=300000, + description="Global row budget for a single AUP scan request (0 = unlimited)", + ) +''' +new=''' SCANNER_MAX_ROWS_PER_SCAN: int = Field( + default=120000, + description="Global row budget for a single AUP scan request (0 = unlimited)", + ) +''' +if old not in t: + raise SystemExit('SCANNER_MAX_ROWS_PER_SCAN block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('reduced SCANNER_MAX_ROWS_PER_SCAN default to 120000') diff --git a/_aup_update_grid_columns.py b/_aup_update_grid_columns.py new file mode 100644 index 0000000..7d58804 --- /dev/null +++ b/_aup_update_grid_columns.py @@ -0,0 +1,42 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx') +t=p.read_text(encoding='utf-8') +old='''const RESULT_COLUMNS: GridColDef[] = [ + { + field: 'theme_name', headerName: 'Theme', width: 140, + renderCell: (params) => ( + + ), + }, + { field: 'keyword', headerName: 'Keyword', width: 140 }, + { field: 'source_type', headerName: 'Source', width: 120 }, + { field: 'dataset_name', headerName: 'Dataset', width: 150 }, + { field: 'field', headerName: 'Field', width: 130 }, + { field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 200 }, + { field: 'row_index', headerName: 'Row #', width: 80, type: 'number' }, +]; +''' +new='''const RESULT_COLUMNS: GridColDef[] = [ + { + field: 'theme_name', headerName: 'Theme', width: 140, + renderCell: (params) => ( + + ), + }, + { field: 'keyword', headerName: 'Keyword', width: 140 }, + { field: 'dataset_name', headerName: 'Dataset', width: 170 }, + { field: 'hostname', headerName: 'Hostname', width: 170, valueGetter: (v, row) => row.hostname || '' }, + { field: 'username', headerName: 'User', width: 160, valueGetter: (v, row) => row.username || '' }, + { field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 220 }, + { field: 'field', headerName: 'Field', width: 130 }, + { field: 'source_type', headerName: 'Source', width: 120 }, + { field: 'row_index', headerName: 'Row #', width: 90, type: 'number' }, +]; +''' +if old not in t: + raise SystemExit('RESULT_COLUMNS block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('updated AUP results grid columns with dataset/hostname/user/matched value focus') diff --git a/_edit_aup.py b/_edit_aup.py new file mode 100644 index 0000000..7466c9c --- /dev/null +++ b/_edit_aup.py @@ -0,0 +1,40 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx') +t=p.read_text(encoding='utf-8') +t=t.replace(' const [scanHunts, setScanHunts] = useState(true);',' const [scanHunts, setScanHunts] = useState(false);') +t=t.replace(' const [scanAnnotations, setScanAnnotations] = useState(true);',' const [scanAnnotations, setScanAnnotations] = useState(false);') +t=t.replace(' const [scanMessages, setScanMessages] = useState(true);',' const [scanMessages, setScanMessages] = useState(false);') +t=t.replace(' scan_messages: scanMessages,\n });',' scan_messages: scanMessages,\n prefer_cache: true,\n });') +# add cache chip in summary alert +old=''' {scanResult && ( + 0 ? 'warning' : 'success'} sx={{ py: 0.5 }}> + {scanResult.total_hits} hits across{' '} + {scanResult.rows_scanned} rows |{' '} + {scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned + + )} +''' +new=''' {scanResult && ( + 0 ? 'warning' : 'success'} sx={{ py: 0.5 }}> + {scanResult.total_hits} hits across{' '} + {scanResult.rows_scanned} rows |{' '} + {scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned + {scanResult.cache_status && ( + + )} + + )} +''' +if old in t: + t=t.replace(old,new) +else: + print('warning: summary block not replaced') + +p.write_text(t,encoding='utf-8') +print('updated AUPScanner.tsx') diff --git a/_edit_client.py b/_edit_client.py new file mode 100644 index 0000000..bb189fb --- /dev/null +++ b/_edit_client.py @@ -0,0 +1,36 @@ +from pathlib import Path +import re +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/api/client.ts') +t=p.read_text(encoding='utf-8') +# Add HuntProgress interface after Hunt interface +if 'export interface HuntProgress' not in t: + insert = '''export interface HuntProgress { + hunt_id: string; + status: 'idle' | 'processing' | 'ready'; + progress_percent: number; + dataset_total: number; + dataset_completed: number; + dataset_processing: number; + dataset_errors: number; + active_jobs: number; + queued_jobs: number; + network_status: 'none' | 'building' | 'ready'; + stages: Record; +} + +''' + t=t.replace('export interface Hunt {\n id: string; name: string; description: string | null; status: string;\n owner_id: string | null; created_at: string; updated_at: string;\n dataset_count: number; hypothesis_count: number;\n}\n\n', 'export interface Hunt {\n id: string; name: string; description: string | null; status: string;\n owner_id: string | null; created_at: string; updated_at: string;\n dataset_count: number; hypothesis_count: number;\n}\n\n'+insert) + +# Add hunts.progress method +if 'progress: (id: string)' not in t: + t=t.replace(" delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),\n};", " delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),\n progress: (id: string) => api(`/api/hunts/${id}/progress`),\n};") + +# Extend ScanResponse +if 'cache_used?: boolean' not in t: + t=t.replace('export interface ScanResponse {\n total_hits: number; hits: ScanHit[]; themes_scanned: number;\n keywords_scanned: number; rows_scanned: number;\n}\n', 'export interface ScanResponse {\n total_hits: number; hits: ScanHit[]; themes_scanned: number;\n keywords_scanned: number; rows_scanned: number;\n cache_used?: boolean; cache_status?: string; cached_at?: string | null;\n}\n') + +# Extend keywords.scan opts +t=t.replace(' scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;\n }) =>', ' scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;\n prefer_cache?: boolean; force_rescan?: boolean;\n }) =>') + +p.write_text(t,encoding='utf-8') +print('updated client.ts') diff --git a/_edit_config_reconcile.py b/_edit_config_reconcile.py new file mode 100644 index 0000000..e80382b --- /dev/null +++ b/_edit_config_reconcile.py @@ -0,0 +1,20 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py') +t=p.read_text(encoding='utf-8') +anchor=''' STARTUP_REPROCESS_MAX_DATASETS: int = Field( + default=25, description="Max unprocessed datasets to enqueue at startup" + ) +''' +insert=''' STARTUP_REPROCESS_MAX_DATASETS: int = Field( + default=25, description="Max unprocessed datasets to enqueue at startup" + ) + STARTUP_RECONCILE_STALE_TASKS: bool = Field( + default=True, + description="Mark stale queued/running processing tasks as failed on startup", + ) +''' +if anchor not in t: + raise SystemExit('startup anchor not found') +t=t.replace(anchor,insert) +p.write_text(t,encoding='utf-8') +print('updated config with STARTUP_RECONCILE_STALE_TASKS') diff --git a/_edit_datasets.py b/_edit_datasets.py new file mode 100644 index 0000000..9b45f58 --- /dev/null +++ b/_edit_datasets.py @@ -0,0 +1,39 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py') +t=p.read_text(encoding='utf-8') +if 'from app.services.scanner import keyword_scan_cache' not in t: + t=t.replace('from app.services.host_inventory import inventory_cache','from app.services.host_inventory import inventory_cache\nfrom app.services.scanner import keyword_scan_cache') +old='''@router.delete( + "/{dataset_id}", + summary="Delete a dataset", +) +async def delete_dataset( + dataset_id: str, + db: AsyncSession = Depends(get_db), +): + repo = DatasetRepository(db) + deleted = await repo.delete_dataset(dataset_id) + if not deleted: + raise HTTPException(status_code=404, detail="Dataset not found") + return {"message": "Dataset deleted", "id": dataset_id} +''' +new='''@router.delete( + "/{dataset_id}", + summary="Delete a dataset", +) +async def delete_dataset( + dataset_id: str, + db: AsyncSession = Depends(get_db), +): + repo = DatasetRepository(db) + deleted = await repo.delete_dataset(dataset_id) + if not deleted: + raise HTTPException(status_code=404, detail="Dataset not found") + keyword_scan_cache.invalidate_dataset(dataset_id) + return {"message": "Dataset deleted", "id": dataset_id} +''' +if old not in t: + raise SystemExit('delete block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('updated datasets.py') diff --git a/_edit_datasets_tasks.py b/_edit_datasets_tasks.py new file mode 100644 index 0000000..7521ad2 --- /dev/null +++ b/_edit_datasets_tasks.py @@ -0,0 +1,110 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py') +t=p.read_text(encoding='utf-8') +if 'ProcessingTask' not in t: + t=t.replace('from app.db.models import', 'from app.db.models import ProcessingTask\n# from app.db.models import') + t=t.replace('from app.services.scanner import keyword_scan_cache','from app.services.scanner import keyword_scan_cache') +# clean import replacement to proper single line +if '# from app.db.models import' in t: + t=t.replace('from app.db.models import ProcessingTask\n# from app.db.models import', 'from app.db.models import ProcessingTask') + +old=''' # 1. AI Triage (chains to HOST_PROFILE automatically on completion) + job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id) + jobs_queued.append("triage") + + # 2. Anomaly detection (embedding-based outlier detection) + job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id) + jobs_queued.append("anomaly") + + # 3. AUP keyword scan + job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id) + jobs_queued.append("keyword_scan") + + # 4. IOC extraction + job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id) + jobs_queued.append("ioc_extract") + + # 5. Host inventory (network map) - requires hunt_id + if hunt_id: + inventory_cache.invalidate(hunt_id) + job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id) + jobs_queued.append("host_inventory") +''' +new=''' task_rows: list[ProcessingTask] = [] + + # 1. AI Triage (chains to HOST_PROFILE automatically on completion) + triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id) + jobs_queued.append("triage") + task_rows.append(ProcessingTask( + hunt_id=hunt_id, + dataset_id=dataset.id, + job_id=triage_job.id, + stage="triage", + status="queued", + progress=0.0, + message="Queued", + )) + + # 2. Anomaly detection (embedding-based outlier detection) + anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id) + jobs_queued.append("anomaly") + task_rows.append(ProcessingTask( + hunt_id=hunt_id, + dataset_id=dataset.id, + job_id=anomaly_job.id, + stage="anomaly", + status="queued", + progress=0.0, + message="Queued", + )) + + # 3. AUP keyword scan + kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id) + jobs_queued.append("keyword_scan") + task_rows.append(ProcessingTask( + hunt_id=hunt_id, + dataset_id=dataset.id, + job_id=kw_job.id, + stage="keyword_scan", + status="queued", + progress=0.0, + message="Queued", + )) + + # 4. IOC extraction + ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id) + jobs_queued.append("ioc_extract") + task_rows.append(ProcessingTask( + hunt_id=hunt_id, + dataset_id=dataset.id, + job_id=ioc_job.id, + stage="ioc_extract", + status="queued", + progress=0.0, + message="Queued", + )) + + # 5. Host inventory (network map) - requires hunt_id + if hunt_id: + inventory_cache.invalidate(hunt_id) + inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id) + jobs_queued.append("host_inventory") + task_rows.append(ProcessingTask( + hunt_id=hunt_id, + dataset_id=dataset.id, + job_id=inv_job.id, + stage="host_inventory", + status="queued", + progress=0.0, + message="Queued", + )) + + if task_rows: + db.add_all(task_rows) + await db.flush() +''' +if old not in t: + raise SystemExit('queue block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('updated datasets upload queue + processing tasks') diff --git a/_edit_hunts.py b/_edit_hunts.py new file mode 100644 index 0000000..2286ef1 --- /dev/null +++ b/_edit_hunts.py @@ -0,0 +1,254 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/hunts.py') +new='''"""API routes for hunt management.""" + +import logging + +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel, Field +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import get_db +from app.db.models import Hunt, Dataset +from app.services.job_queue import job_queue +from app.services.host_inventory import inventory_cache + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/hunts", tags=["hunts"]) + + +class HuntCreate(BaseModel): + name: str = Field(..., max_length=256) + description: str | None = None + + +class HuntUpdate(BaseModel): + name: str | None = None + description: str | None = None + status: str | None = None + + +class HuntResponse(BaseModel): + id: str + name: str + description: str | None + status: str + owner_id: str | None + created_at: str + updated_at: str + dataset_count: int = 0 + hypothesis_count: int = 0 + + +class HuntListResponse(BaseModel): + hunts: list[HuntResponse] + total: int + + +class HuntProgressResponse(BaseModel): + hunt_id: str + status: str + progress_percent: float + dataset_total: int + dataset_completed: int + dataset_processing: int + dataset_errors: int + active_jobs: int + queued_jobs: int + network_status: str + stages: dict + + +@router.post("", response_model=HuntResponse, summary="Create a new hunt") +async def create_hunt(body: HuntCreate, db: AsyncSession = Depends(get_db)): + hunt = Hunt(name=body.name, description=body.description) + db.add(hunt) + await db.flush() + return HuntResponse( + id=hunt.id, + name=hunt.name, + description=hunt.description, + status=hunt.status, + owner_id=hunt.owner_id, + created_at=hunt.created_at.isoformat(), + updated_at=hunt.updated_at.isoformat(), + ) + + +@router.get("", response_model=HuntListResponse, summary="List hunts") +async def list_hunts( + status: str | None = Query(None), + limit: int = Query(50, ge=1, le=500), + offset: int = Query(0, ge=0), + db: AsyncSession = Depends(get_db), +): + stmt = select(Hunt).order_by(Hunt.updated_at.desc()) + if status: + stmt = stmt.where(Hunt.status == status) + stmt = stmt.limit(limit).offset(offset) + result = await db.execute(stmt) + hunts = result.scalars().all() + + count_stmt = select(func.count(Hunt.id)) + if status: + count_stmt = count_stmt.where(Hunt.status == status) + total = (await db.execute(count_stmt)).scalar_one() + + return HuntListResponse( + hunts=[ + HuntResponse( + id=h.id, + name=h.name, + description=h.description, + status=h.status, + owner_id=h.owner_id, + created_at=h.created_at.isoformat(), + updated_at=h.updated_at.isoformat(), + dataset_count=len(h.datasets) if h.datasets else 0, + hypothesis_count=len(h.hypotheses) if h.hypotheses else 0, + ) + for h in hunts + ], + total=total, + ) + + +@router.get("/{hunt_id}", response_model=HuntResponse, summary="Get hunt details") +async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)): + result = await db.execute(select(Hunt).where(Hunt.id == hunt_id)) + hunt = result.scalar_one_or_none() + if not hunt: + raise HTTPException(status_code=404, detail="Hunt not found") + return HuntResponse( + id=hunt.id, + name=hunt.name, + description=hunt.description, + status=hunt.status, + owner_id=hunt.owner_id, + created_at=hunt.created_at.isoformat(), + updated_at=hunt.updated_at.isoformat(), + dataset_count=len(hunt.datasets) if hunt.datasets else 0, + hypothesis_count=len(hunt.hypotheses) if hunt.hypotheses else 0, + ) + + +@router.get("/{hunt_id}/progress", response_model=HuntProgressResponse, summary="Get hunt processing progress") +async def get_hunt_progress(hunt_id: str, db: AsyncSession = Depends(get_db)): + hunt = await db.get(Hunt, hunt_id) + if not hunt: + raise HTTPException(status_code=404, detail="Hunt not found") + + ds_rows = await db.execute( + select(Dataset.id, Dataset.processing_status) + .where(Dataset.hunt_id == hunt_id) + ) + datasets = ds_rows.all() + dataset_ids = {row[0] for row in datasets} + + dataset_total = len(datasets) + dataset_completed = sum(1 for _, st in datasets if st == "completed") + dataset_errors = sum(1 for _, st in datasets if st == "completed_with_errors") + dataset_processing = max(0, dataset_total - dataset_completed - dataset_errors) + + jobs = job_queue.list_jobs(limit=5000) + relevant_jobs = [ + j for j in jobs + if j.get("params", {}).get("hunt_id") == hunt_id + or j.get("params", {}).get("dataset_id") in dataset_ids + ] + active_jobs = sum(1 for j in relevant_jobs if j.get("status") == "running") + queued_jobs = sum(1 for j in relevant_jobs if j.get("status") == "queued") + + if inventory_cache.get(hunt_id) is not None: + network_status = "ready" + network_ratio = 1.0 + elif inventory_cache.is_building(hunt_id): + network_status = "building" + network_ratio = 0.5 + else: + network_status = "none" + network_ratio = 0.0 + + dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0 + overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15)) + progress_percent = round(overall_ratio * 100.0, 1) + + status = "ready" + if dataset_total == 0: + status = "idle" + elif progress_percent < 100: + status = "processing" + + stages = { + "datasets": { + "total": dataset_total, + "completed": dataset_completed, + "processing": dataset_processing, + "errors": dataset_errors, + "percent": round(dataset_ratio * 100.0, 1), + }, + "network": { + "status": network_status, + "percent": round(network_ratio * 100.0, 1), + }, + "jobs": { + "active": active_jobs, + "queued": queued_jobs, + "total_seen": len(relevant_jobs), + }, + } + + return HuntProgressResponse( + hunt_id=hunt_id, + status=status, + progress_percent=progress_percent, + dataset_total=dataset_total, + dataset_completed=dataset_completed, + dataset_processing=dataset_processing, + dataset_errors=dataset_errors, + active_jobs=active_jobs, + queued_jobs=queued_jobs, + network_status=network_status, + stages=stages, + ) + + +@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt") +async def update_hunt( + hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db) +): + result = await db.execute(select(Hunt).where(Hunt.id == hunt_id)) + hunt = result.scalar_one_or_none() + if not hunt: + raise HTTPException(status_code=404, detail="Hunt not found") + if body.name is not None: + hunt.name = body.name + if body.description is not None: + hunt.description = body.description + if body.status is not None: + hunt.status = body.status + await db.flush() + return HuntResponse( + id=hunt.id, + name=hunt.name, + description=hunt.description, + status=hunt.status, + owner_id=hunt.owner_id, + created_at=hunt.created_at.isoformat(), + updated_at=hunt.updated_at.isoformat(), + ) + + +@router.delete("/{hunt_id}", summary="Delete a hunt") +async def delete_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)): + result = await db.execute(select(Hunt).where(Hunt.id == hunt_id)) + hunt = result.scalar_one_or_none() + if not hunt: + raise HTTPException(status_code=404, detail="Hunt not found") + await db.delete(hunt) + return {"message": "Hunt deleted", "id": hunt_id} +''' +p.write_text(new,encoding='utf-8') +print('updated hunts.py') diff --git a/_edit_hunts_progress_tasks.py b/_edit_hunts_progress_tasks.py new file mode 100644 index 0000000..64f065e --- /dev/null +++ b/_edit_hunts_progress_tasks.py @@ -0,0 +1,102 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/hunts.py') +t=p.read_text(encoding='utf-8') +if 'ProcessingTask' not in t: + t=t.replace('from app.db.models import Hunt, Dataset','from app.db.models import Hunt, Dataset, ProcessingTask') + +old=''' jobs = job_queue.list_jobs(limit=5000) + relevant_jobs = [ + j for j in jobs + if j.get("params", {}).get("hunt_id") == hunt_id + or j.get("params", {}).get("dataset_id") in dataset_ids + ] + active_jobs = sum(1 for j in relevant_jobs if j.get("status") == "running") + queued_jobs = sum(1 for j in relevant_jobs if j.get("status") == "queued") + + if inventory_cache.get(hunt_id) is not None: +''' +new=''' jobs = job_queue.list_jobs(limit=5000) + relevant_jobs = [ + j for j in jobs + if j.get("params", {}).get("hunt_id") == hunt_id + or j.get("params", {}).get("dataset_id") in dataset_ids + ] + active_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "running") + queued_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "queued") + + task_rows = await db.execute( + select(ProcessingTask.stage, ProcessingTask.status, ProcessingTask.progress) + .where(ProcessingTask.hunt_id == hunt_id) + ) + tasks = task_rows.all() + + task_total = len(tasks) + task_done = sum(1 for _, st, _ in tasks if st in ("completed", "failed", "cancelled")) + task_running = sum(1 for _, st, _ in tasks if st == "running") + task_queued = sum(1 for _, st, _ in tasks if st == "queued") + task_ratio = (task_done / task_total) if task_total > 0 else None + + active_jobs = max(active_jobs_mem, task_running) + queued_jobs = max(queued_jobs_mem, task_queued) + + stage_rollup: dict[str, dict] = {} + for stage, status, progress in tasks: + bucket = stage_rollup.setdefault(stage, {"total": 0, "done": 0, "running": 0, "queued": 0, "progress_sum": 0.0}) + bucket["total"] += 1 + if status in ("completed", "failed", "cancelled"): + bucket["done"] += 1 + elif status == "running": + bucket["running"] += 1 + elif status == "queued": + bucket["queued"] += 1 + bucket["progress_sum"] += float(progress or 0.0) + + for stage_name, bucket in stage_rollup.items(): + total = max(1, bucket["total"]) + bucket["percent"] = round(bucket["progress_sum"] / total, 1) + + if inventory_cache.get(hunt_id) is not None: +''' +if old not in t: + raise SystemExit('job block not found') +t=t.replace(old,new) + +old2=''' dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0 + overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15)) + progress_percent = round(overall_ratio * 100.0, 1) +''' +new2=''' dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0 + if task_ratio is None: + overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15)) + else: + overall_ratio = min(1.0, (dataset_ratio * 0.50) + (task_ratio * 0.35) + (network_ratio * 0.15)) + progress_percent = round(overall_ratio * 100.0, 1) +''' +if old2 not in t: + raise SystemExit('ratio block not found') +t=t.replace(old2,new2) + +old3=''' "jobs": { + "active": active_jobs, + "queued": queued_jobs, + "total_seen": len(relevant_jobs), + }, + } +''' +new3=''' "jobs": { + "active": active_jobs, + "queued": queued_jobs, + "total_seen": len(relevant_jobs), + "task_total": task_total, + "task_done": task_done, + "task_percent": round((task_ratio or 0.0) * 100.0, 1) if task_total else None, + }, + "task_stages": stage_rollup, + } +''' +if old3 not in t: + raise SystemExit('stages jobs block not found') +t=t.replace(old3,new3) + +p.write_text(t,encoding='utf-8') +print('updated hunt progress to merge persistent processing tasks') diff --git a/_edit_job_queue.py b/_edit_job_queue.py new file mode 100644 index 0000000..84548ad --- /dev/null +++ b/_edit_job_queue.py @@ -0,0 +1,46 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py') +t=p.read_text(encoding='utf-8') +old='''async def _handle_keyword_scan(job: Job): + """AUP keyword scan handler.""" + from app.db import async_session_factory + from app.services.scanner import KeywordScanner + + dataset_id = job.params.get("dataset_id") + job.message = f"Running AUP keyword scan on dataset {dataset_id}" + + async with async_session_factory() as db: + scanner = KeywordScanner(db) + result = await scanner.scan(dataset_ids=[dataset_id]) + + hits = result.get("total_hits", 0) + job.message = f"Keyword scan complete: {hits} hits" + logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows") + return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)} +''' +new='''async def _handle_keyword_scan(job: Job): + """AUP keyword scan handler.""" + from app.db import async_session_factory + from app.services.scanner import KeywordScanner, keyword_scan_cache + + dataset_id = job.params.get("dataset_id") + job.message = f"Running AUP keyword scan on dataset {dataset_id}" + + async with async_session_factory() as db: + scanner = KeywordScanner(db) + result = await scanner.scan(dataset_ids=[dataset_id]) + + # Cache dataset-only result for fast API reuse + if dataset_id: + keyword_scan_cache.put(dataset_id, result) + + hits = result.get("total_hits", 0) + job.message = f"Keyword scan complete: {hits} hits" + logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows") + return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)} +''' +if old not in t: + raise SystemExit('target block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('updated job_queue keyword scan handler') diff --git a/_edit_jobqueue_reconcile.py b/_edit_jobqueue_reconcile.py new file mode 100644 index 0000000..1730f09 --- /dev/null +++ b/_edit_jobqueue_reconcile.py @@ -0,0 +1,13 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py') +t=p.read_text(encoding='utf-8') +marker='''def register_all_handlers(): + """Register all job handlers and completion callbacks.""" +''' +ins='''\n\nasync def reconcile_stale_processing_tasks() -> int:\n """Mark queued/running processing tasks from prior runs as failed."""\n from datetime import datetime, timezone\n from sqlalchemy import update\n\n try:\n from app.db import async_session_factory\n from app.db.models import ProcessingTask\n\n now = datetime.now(timezone.utc)\n async with async_session_factory() as db:\n result = await db.execute(\n update(ProcessingTask)\n .where(ProcessingTask.status.in_([\"queued\", \"running\"]))\n .values(\n status=\"failed\",\n error=\"Recovered after service restart before task completion\",\n message=\"Recovered stale task after restart\",\n completed_at=now,\n )\n )\n await db.commit()\n updated = int(result.rowcount or 0)\n\n if updated:\n logger.warning(\n \"Reconciled %d stale processing tasks (queued/running -> failed) during startup\",\n updated,\n )\n return updated\n except Exception as e:\n logger.warning(f\"Failed to reconcile stale processing tasks: {e}\")\n return 0\n\n\n''' +if ins.strip() not in t: + if marker not in t: + raise SystemExit('register marker not found') + t=t.replace(marker,ins+marker) +p.write_text(t,encoding='utf-8') +print('added reconcile_stale_processing_tasks to job_queue') diff --git a/_edit_jobqueue_sync.py b/_edit_jobqueue_sync.py new file mode 100644 index 0000000..c6ff0ca --- /dev/null +++ b/_edit_jobqueue_sync.py @@ -0,0 +1,64 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py') +t=p.read_text(encoding='utf-8') +ins='''\n\nasync def _sync_processing_task(job: Job):\n """Persist latest job state into processing_tasks (if linked by job_id)."""\n from datetime import datetime, timezone\n from sqlalchemy import update\n\n try:\n from app.db import async_session_factory\n from app.db.models import ProcessingTask\n\n values = {\n "status": job.status.value,\n "progress": float(job.progress),\n "message": job.message,\n "error": job.error,\n }\n if job.started_at:\n values["started_at"] = datetime.fromtimestamp(job.started_at, tz=timezone.utc)\n if job.completed_at:\n values["completed_at"] = datetime.fromtimestamp(job.completed_at, tz=timezone.utc)\n\n async with async_session_factory() as db:\n await db.execute(\n update(ProcessingTask)\n .where(ProcessingTask.job_id == job.id)\n .values(**values)\n )\n await db.commit()\n except Exception as e:\n logger.warning(f"Failed to sync processing task for job {job.id}: {e}")\n''' +marker='\n\n# -- Singleton + job handlers --\n' +if ins.strip() not in t: + t=t.replace(marker, ins+marker) + +old=''' job.status = JobStatus.RUNNING + job.started_at = time.time() + job.message = "Running..." + logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})") + + try: +''' +new=''' job.status = JobStatus.RUNNING + job.started_at = time.time() + if job.progress <= 0: + job.progress = 5.0 + job.message = "Running..." + await _sync_processing_task(job) + logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})") + + try: +''' +if old not in t: + raise SystemExit('worker running block not found') +t=t.replace(old,new) + +old2=''' job.completed_at = time.time() + logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms") + except Exception as e: + if not job.is_cancelled: + job.status = JobStatus.FAILED + job.error = str(e) + job.message = f"Failed: {e}" + job.completed_at = time.time() + logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True) + + # Fire completion callbacks +''' +new2=''' job.completed_at = time.time() + logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms") + except Exception as e: + if not job.is_cancelled: + job.status = JobStatus.FAILED + job.error = str(e) + job.message = f"Failed: {e}" + job.completed_at = time.time() + logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True) + + if job.is_cancelled and not job.completed_at: + job.completed_at = time.time() + + await _sync_processing_task(job) + + # Fire completion callbacks +''' +if old2 not in t: + raise SystemExit('worker completion block not found') +t=t.replace(old2,new2) + +p.write_text(t, encoding='utf-8') +print('updated job_queue persistent task syncing') diff --git a/_edit_jobqueue_triage_task.py b/_edit_jobqueue_triage_task.py new file mode 100644 index 0000000..22b6694 --- /dev/null +++ b/_edit_jobqueue_triage_task.py @@ -0,0 +1,39 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py') +t=p.read_text(encoding='utf-8') +old=''' if hunt_id: + job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id) + logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}") + except Exception as e: +''' +new=''' if hunt_id: + hp_job = job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id) + try: + from sqlalchemy import select + from app.db.models import ProcessingTask + async with async_session_factory() as db: + existing = await db.execute( + select(ProcessingTask.id).where(ProcessingTask.job_id == hp_job.id) + ) + if existing.first() is None: + db.add(ProcessingTask( + hunt_id=hunt_id, + dataset_id=dataset_id, + job_id=hp_job.id, + stage="host_profile", + status="queued", + progress=0.0, + message="Queued", + )) + await db.commit() + except Exception as persist_err: + logger.warning(f"Failed to persist chained HOST_PROFILE task: {persist_err}") + + logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}") + except Exception as e: +''' +if old not in t: + raise SystemExit('triage chain block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('updated triage chain to persist host_profile task row') diff --git a/_edit_keywords.py b/_edit_keywords.py new file mode 100644 index 0000000..fba2c20 --- /dev/null +++ b/_edit_keywords.py @@ -0,0 +1,321 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py') +new_text='''"""API routes for AUP keyword themes, keyword CRUD, and scanning.""" + +import logging + +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import get_db +from app.db.models import KeywordTheme, Keyword +from app.services.scanner import KeywordScanner, keyword_scan_cache + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/keywords", tags=["keywords"]) + + +class ThemeCreate(BaseModel): + name: str = Field(..., min_length=1, max_length=128) + color: str = Field(default="#9e9e9e", max_length=16) + enabled: bool = True + + +class ThemeUpdate(BaseModel): + name: str | None = None + color: str | None = None + enabled: bool | None = None + + +class KeywordOut(BaseModel): + id: int + theme_id: str + value: str + is_regex: bool + created_at: str + + +class ThemeOut(BaseModel): + id: str + name: str + color: str + enabled: bool + is_builtin: bool + created_at: str + keyword_count: int + keywords: list[KeywordOut] + + +class ThemeListResponse(BaseModel): + themes: list[ThemeOut] + total: int + + +class KeywordCreate(BaseModel): + value: str = Field(..., min_length=1, max_length=256) + is_regex: bool = False + + +class KeywordBulkCreate(BaseModel): + values: list[str] = Field(..., min_items=1) + is_regex: bool = False + + +class ScanRequest(BaseModel): + dataset_ids: list[str] | None = None + theme_ids: list[str] | None = None + scan_hunts: bool = False + scan_annotations: bool = False + scan_messages: bool = False + prefer_cache: bool = True + force_rescan: bool = False + + +class ScanHit(BaseModel): + theme_name: str + theme_color: str + keyword: str + source_type: str + source_id: str | int + field: str + matched_value: str + row_index: int | None = None + dataset_name: str | None = None + + +class ScanResponse(BaseModel): + total_hits: int + hits: list[ScanHit] + themes_scanned: int + keywords_scanned: int + rows_scanned: int + cache_used: bool = False + cache_status: str = "miss" + cached_at: str | None = None + + +def _theme_to_out(t: KeywordTheme) -> ThemeOut: + return ThemeOut( + id=t.id, + name=t.name, + color=t.color, + enabled=t.enabled, + is_builtin=t.is_builtin, + created_at=t.created_at.isoformat(), + keyword_count=len(t.keywords), + keywords=[ + KeywordOut( + id=k.id, + theme_id=k.theme_id, + value=k.value, + is_regex=k.is_regex, + created_at=k.created_at.isoformat(), + ) + for k in t.keywords + ], + ) + + +def _merge_cached_results(entries: list[dict], allowed_theme_names: set[str] | None = None) -> dict: + hits: list[dict] = [] + total_rows = 0 + cached_at: str | None = None + + for entry in entries: + result = entry["result"] + total_rows += int(result.get("rows_scanned", 0) or 0) + if entry.get("built_at"): + if not cached_at or entry["built_at"] > cached_at: + cached_at = entry["built_at"] + for h in result.get("hits", []): + if allowed_theme_names is not None and h.get("theme_name") not in allowed_theme_names: + continue + hits.append(h) + + return { + "total_hits": len(hits), + "hits": hits, + "rows_scanned": total_rows, + "cached_at": cached_at, + } + + +@router.get("/themes", response_model=ThemeListResponse) +async def list_themes(db: AsyncSession = Depends(get_db)): + result = await db.execute(select(KeywordTheme).order_by(KeywordTheme.name)) + themes = result.scalars().all() + return ThemeListResponse(themes=[_theme_to_out(t) for t in themes], total=len(themes)) + + +@router.post("/themes", response_model=ThemeOut, status_code=201) +async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)): + exists = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == body.name)) + if exists: + raise HTTPException(409, f"Theme '{body.name}' already exists") + theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled) + db.add(theme) + await db.flush() + await db.refresh(theme) + keyword_scan_cache.clear() + return _theme_to_out(theme) + + +@router.put("/themes/{theme_id}", response_model=ThemeOut) +async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)): + theme = await db.get(KeywordTheme, theme_id) + if not theme: + raise HTTPException(404, "Theme not found") + if body.name is not None: + dup = await db.scalar( + select(KeywordTheme.id).where(KeywordTheme.name == body.name, KeywordTheme.id != theme_id) + ) + if dup: + raise HTTPException(409, f"Theme '{body.name}' already exists") + theme.name = body.name + if body.color is not None: + theme.color = body.color + if body.enabled is not None: + theme.enabled = body.enabled + await db.flush() + await db.refresh(theme) + keyword_scan_cache.clear() + return _theme_to_out(theme) + + +@router.delete("/themes/{theme_id}", status_code=204) +async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)): + theme = await db.get(KeywordTheme, theme_id) + if not theme: + raise HTTPException(404, "Theme not found") + await db.delete(theme) + keyword_scan_cache.clear() + + +@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201) +async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)): + theme = await db.get(KeywordTheme, theme_id) + if not theme: + raise HTTPException(404, "Theme not found") + kw = Keyword(theme_id=theme_id, value=body.value, is_regex=body.is_regex) + db.add(kw) + await db.flush() + await db.refresh(kw) + keyword_scan_cache.clear() + return KeywordOut( + id=kw.id, theme_id=kw.theme_id, value=kw.value, + is_regex=kw.is_regex, created_at=kw.created_at.isoformat(), + ) + + +@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201) +async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)): + theme = await db.get(KeywordTheme, theme_id) + if not theme: + raise HTTPException(404, "Theme not found") + added = 0 + for val in body.values: + val = val.strip() + if not val: + continue + db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex)) + added += 1 + await db.flush() + keyword_scan_cache.clear() + return {"added": added, "theme_id": theme_id} + + +@router.delete("/keywords/{keyword_id}", status_code=204) +async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)): + kw = await db.get(Keyword, keyword_id) + if not kw: + raise HTTPException(404, "Keyword not found") + await db.delete(kw) + keyword_scan_cache.clear() + + +@router.post("/scan", response_model=ScanResponse) +async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)): + scanner = KeywordScanner(db) + + can_use_cache = ( + body.prefer_cache + and not body.force_rescan + and bool(body.dataset_ids) + and not body.scan_hunts + and not body.scan_annotations + and not body.scan_messages + ) + + if can_use_cache: + themes = await scanner._load_themes(body.theme_ids) + allowed_theme_names = {t.name for t in themes} + keywords_scanned = sum(len(theme.keywords) for theme in themes) + + cached_entries: list[dict] = [] + missing: list[str] = [] + for dataset_id in (body.dataset_ids or []): + entry = keyword_scan_cache.get(dataset_id) + if not entry: + missing.append(dataset_id) + continue + cached_entries.append({"result": entry.result, "built_at": entry.built_at}) + + if not missing and cached_entries: + merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None) + return { + "total_hits": merged["total_hits"], + "hits": merged["hits"], + "themes_scanned": len(themes), + "keywords_scanned": keywords_scanned, + "rows_scanned": merged["rows_scanned"], + "cache_used": True, + "cache_status": "hit", + "cached_at": merged["cached_at"], + } + + result = await scanner.scan( + dataset_ids=body.dataset_ids, + theme_ids=body.theme_ids, + scan_hunts=body.scan_hunts, + scan_annotations=body.scan_annotations, + scan_messages=body.scan_messages, + ) + + return { + **result, + "cache_used": False, + "cache_status": "miss", + "cached_at": None, + } + + +@router.get("/scan/quick", response_model=ScanResponse) +async def quick_scan( + dataset_id: str = Query(..., description="Dataset to scan"), + db: AsyncSession = Depends(get_db), +): + entry = keyword_scan_cache.get(dataset_id) + if entry is not None: + result = entry.result + return { + **result, + "cache_used": True, + "cache_status": "hit", + "cached_at": entry.built_at, + } + + scanner = KeywordScanner(db) + result = await scanner.scan(dataset_ids=[dataset_id]) + keyword_scan_cache.put(dataset_id, result) + return { + **result, + "cache_used": False, + "cache_status": "miss", + "cached_at": None, + } +''' +p.write_text(new_text,encoding='utf-8') +print('updated keywords.py') diff --git a/_edit_main_reconcile.py b/_edit_main_reconcile.py new file mode 100644 index 0000000..bf975b6 --- /dev/null +++ b/_edit_main_reconcile.py @@ -0,0 +1,31 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/main.py') +t=p.read_text(encoding='utf-8') +old=''' # Start job queue + from app.services.job_queue import job_queue, register_all_handlers, JobType + register_all_handlers() + await job_queue.start() + logger.info("Job queue started (%d workers)", job_queue._max_workers) +''' +new=''' # Start job queue + from app.services.job_queue import ( + job_queue, + register_all_handlers, + reconcile_stale_processing_tasks, + JobType, + ) + + if settings.STARTUP_RECONCILE_STALE_TASKS: + reconciled = await reconcile_stale_processing_tasks() + if reconciled: + logger.info("Startup reconciliation marked %d stale tasks", reconciled) + + register_all_handlers() + await job_queue.start() + logger.info("Job queue started (%d workers)", job_queue._max_workers) +''' +if old not in t: + raise SystemExit('startup queue block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('wired startup reconciliation in main lifespan') diff --git a/_edit_models_processing.py b/_edit_models_processing.py new file mode 100644 index 0000000..7da6284 --- /dev/null +++ b/_edit_models_processing.py @@ -0,0 +1,45 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/db/models.py') +t=p.read_text(encoding='utf-8') +if 'class ProcessingTask(Base):' in t: + print('processing task model already exists') + raise SystemExit(0) +insert=''' + +# -- Persistent Processing Tasks (Phase 2) --- + +class ProcessingTask(Base): + __tablename__ = "processing_tasks" + + id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id) + hunt_id: Mapped[Optional[str]] = mapped_column( + String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True, index=True + ) + dataset_id: Mapped[Optional[str]] = mapped_column( + String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True, index=True + ) + job_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True) + stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True) + status: Mapped[str] = mapped_column(String(20), default="queued", index=True) + progress: Mapped[float] = mapped_column(Float, default=0.0) + message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + error: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) + started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=_utcnow, onupdate=_utcnow + ) + + __table_args__ = ( + Index("ix_processing_tasks_hunt_stage", "hunt_id", "stage"), + Index("ix_processing_tasks_dataset_stage", "dataset_id", "stage"), + ) +''' +# insert before Playbook section +marker='\n\n# -- Playbook / Investigation Templates (Feature 3) ---\n' +if marker not in t: + raise SystemExit('marker not found for insertion') +t=t.replace(marker, insert+marker) +p.write_text(t,encoding='utf-8') +print('added ProcessingTask model') diff --git a/_edit_networkmap_hit.py b/_edit_networkmap_hit.py new file mode 100644 index 0000000..20c0b8c --- /dev/null +++ b/_edit_networkmap_hit.py @@ -0,0 +1,59 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx') +t=p.read_text(encoding='utf-8') +insert=''' +function isPointOnNodeLabel(node: GNode, wx: number, wy: number, vp: Viewport): boolean { + const fontSize = Math.max(9, Math.round(12 / vp.scale)); + const approxCharW = Math.max(5, fontSize * 0.58); + const line1 = node.label || ''; + const line2 = node.meta.ips.length > 0 ? node.meta.ips[0] : ''; + const tw = Math.max(line1.length * approxCharW, line2 ? line2.length * approxCharW : 0); + const px = 5, py = 2; + const totalH = line2 ? fontSize * 2 + py * 2 : fontSize + py * 2; + const lx = node.x, ly = node.y - node.radius - 6; + const rx = lx - tw / 2 - px; + const ry = ly - totalH; + const rw = tw + px * 2; + const rh = totalH; + return wx >= rx && wx <= (rx + rw) && wy >= ry && wy <= (ry + rh); +} + +''' +if 'function isPointOnNodeLabel' not in t: + t=t.replace('// == Hit-test =============================================================\n', '// == Hit-test =============================================================\n'+insert) + +old='''function hitTest( + graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport, +): GNode | null { + const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp); + for (const n of graph.nodes) { + const dx = n.x - wx, dy = n.y - wy; + if (dx * dx + dy * dy < (n.radius + 5) ** 2) return n; + } + return null; +} +''' +new='''function hitTest( + graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport, +): GNode | null { + const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp); + + // Node-circle hit has priority + for (const n of graph.nodes) { + const dx = n.x - wx, dy = n.y - wy; + if (dx * dx + dy * dy < (n.radius + 5) ** 2) return n; + } + + // Then label hit (so clicking text works too) + for (const n of graph.nodes) { + if (isPointOnNodeLabel(n, wx, wy, vp)) return n; + } + + return null; +} +''' +if old not in t: + raise SystemExit('hitTest block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('updated NetworkMap hit-test for labels') diff --git a/_edit_scanner.py b/_edit_scanner.py new file mode 100644 index 0000000..af6fc1b --- /dev/null +++ b/_edit_scanner.py @@ -0,0 +1,272 @@ +from pathlib import Path + +p = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py') +text = p.read_text(encoding='utf-8') +new_text = '''"""AUP Keyword Scanner searches dataset rows, hunts, annotations, and +messages for keyword matches. + +Scanning is done in Python (not SQL LIKE on JSON columns) for portability +across SQLite / PostgreSQL and to provide per-cell match context. +""" + +import logging +import re +from dataclasses import dataclass, field +from datetime import datetime, timezone + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db.models import ( + KeywordTheme, + DatasetRow, + Dataset, + Hunt, + Annotation, + Message, +) + +logger = logging.getLogger(__name__) + +BATCH_SIZE = 200 + + +@dataclass +class ScanHit: + theme_name: str + theme_color: str + keyword: str + source_type: str # dataset_row | hunt | annotation | message + source_id: str | int + field: str + matched_value: str + row_index: int | None = None + dataset_name: str | None = None + + +@dataclass +class ScanResult: + total_hits: int = 0 + hits: list[ScanHit] = field(default_factory=list) + themes_scanned: int = 0 + keywords_scanned: int = 0 + rows_scanned: int = 0 + + +@dataclass +class KeywordScanCacheEntry: + dataset_id: str + result: dict + built_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + +class KeywordScanCache: + """In-memory per-dataset cache for dataset-only keyword scans. + + This enables fast-path reads when users run AUP scans against datasets that + were already scanned during upload pipeline processing. + """ + + def __init__(self): + self._entries: dict[str, KeywordScanCacheEntry] = {} + + def put(self, dataset_id: str, result: dict): + self._entries[dataset_id] = KeywordScanCacheEntry(dataset_id=dataset_id, result=result) + + def get(self, dataset_id: str) -> KeywordScanCacheEntry | None: + return self._entries.get(dataset_id) + + def invalidate_dataset(self, dataset_id: str): + self._entries.pop(dataset_id, None) + + def clear(self): + self._entries.clear() + + +keyword_scan_cache = KeywordScanCache() + + +class KeywordScanner: + """Scans multiple data sources for keyword/regex matches.""" + + def __init__(self, db: AsyncSession): + self.db = db + + # Public API + + async def scan( + self, + dataset_ids: list[str] | None = None, + theme_ids: list[str] | None = None, + scan_hunts: bool = False, + scan_annotations: bool = False, + scan_messages: bool = False, + ) -> dict: + """Run a full AUP scan and return dict matching ScanResponse.""" + # Load themes + keywords + themes = await self._load_themes(theme_ids) + if not themes: + return ScanResult().__dict__ + + # Pre-compile patterns per theme + patterns = self._compile_patterns(themes) + result = ScanResult( + themes_scanned=len(themes), + keywords_scanned=sum(len(kws) for kws in patterns.values()), + ) + + # Scan dataset rows + await self._scan_datasets(patterns, result, dataset_ids) + + # Scan hunts + if scan_hunts: + await self._scan_hunts(patterns, result) + + # Scan annotations + if scan_annotations: + await self._scan_annotations(patterns, result) + + # Scan messages + if scan_messages: + await self._scan_messages(patterns, result) + + result.total_hits = len(result.hits) + return { + "total_hits": result.total_hits, + "hits": [h.__dict__ for h in result.hits], + "themes_scanned": result.themes_scanned, + "keywords_scanned": result.keywords_scanned, + "rows_scanned": result.rows_scanned, + } + + # Internal + + async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]: + q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712 + if theme_ids: + q = q.where(KeywordTheme.id.in_(theme_ids)) + result = await self.db.execute(q) + return list(result.scalars().all()) + + def _compile_patterns( + self, themes: list[KeywordTheme] + ) -> dict[tuple[str, str, str], list[tuple[str, re.Pattern]]]: + """Returns {(theme_id, theme_name, theme_color): [(keyword_value, compiled_pattern), ...]}""" + patterns: dict[tuple[str, str, str], list[tuple[str, re.Pattern]]] = {} + for theme in themes: + key = (theme.id, theme.name, theme.color) + compiled = [] + for kw in theme.keywords: + try: + if kw.is_regex: + pat = re.compile(kw.value, re.IGNORECASE) + else: + pat = re.compile(re.escape(kw.value), re.IGNORECASE) + compiled.append((kw.value, pat)) + except re.error: + logger.warning("Invalid regex pattern '%s' in theme '%s', skipping", + kw.value, theme.name) + patterns[key] = compiled + return patterns + + def _match_text( + self, + text: str, + patterns: dict, + source_type: str, + source_id: str | int, + field_name: str, + hits: list[ScanHit], + row_index: int | None = None, + dataset_name: str | None = None, + ) -> None: + """Check text against all compiled patterns, append hits.""" + if not text: + return + for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items(): + for kw_value, pat in keyword_patterns: + if pat.search(text): + matched_preview = text[:200] + ("" if len(text) > 200 else "") + hits.append(ScanHit( + theme_name=theme_name, + theme_color=theme_color, + keyword=kw_value, + source_type=source_type, + source_id=source_id, + field=field_name, + matched_value=matched_preview, + row_index=row_index, + dataset_name=dataset_name, + )) + + async def _scan_datasets( + self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None + ) -> None: + """Scan dataset rows in batches.""" + ds_q = select(Dataset.id, Dataset.name) + if dataset_ids: + ds_q = ds_q.where(Dataset.id.in_(dataset_ids)) + ds_result = await self.db.execute(ds_q) + ds_map = {r[0]: r[1] for r in ds_result.fetchall()} + + if not ds_map: + return + + offset = 0 + row_q_base = select(DatasetRow).where( + DatasetRow.dataset_id.in_(list(ds_map.keys())) + ).order_by(DatasetRow.id) + + while True: + rows_result = await self.db.execute( + row_q_base.offset(offset).limit(BATCH_SIZE) + ) + rows = rows_result.scalars().all() + if not rows: + break + + for row in rows: + result.rows_scanned += 1 + data = row.data or {} + for col_name, cell_value in data.items(): + if cell_value is None: + continue + text = str(cell_value) + self._match_text( + text, patterns, "dataset_row", row.id, + col_name, result.hits, + row_index=row.row_index, + dataset_name=ds_map.get(row.dataset_id), + ) + + offset += BATCH_SIZE + import asyncio + await asyncio.sleep(0) + if len(rows) < BATCH_SIZE: + break + + async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None: + """Scan hunt names and descriptions.""" + hunts_result = await self.db.execute(select(Hunt)) + for hunt in hunts_result.scalars().all(): + self._match_text(hunt.name, patterns, "hunt", hunt.id, "name", result.hits) + if hunt.description: + self._match_text(hunt.description, patterns, "hunt", hunt.id, "description", result.hits) + + async def _scan_annotations(self, patterns: dict, result: ScanResult) -> None: + """Scan annotation text.""" + ann_result = await self.db.execute(select(Annotation)) + for ann in ann_result.scalars().all(): + self._match_text(ann.text, patterns, "annotation", ann.id, "text", result.hits) + + async def _scan_messages(self, patterns: dict, result: ScanResult) -> None: + """Scan conversation messages (user messages only).""" + msg_result = await self.db.execute( + select(Message).where(Message.role == "user") + ) + for msg in msg_result.scalars().all(): + self._match_text(msg.content, patterns, "message", msg.id, "content", result.hits) +''' + +p.write_text(new_text, encoding='utf-8') +print('updated scanner.py') diff --git a/_edit_test_api.py b/_edit_test_api.py new file mode 100644 index 0000000..e8158b4 --- /dev/null +++ b/_edit_test_api.py @@ -0,0 +1,31 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/tests/test_api.py') +t=p.read_text(encoding='utf-8') +insert=''' + async def test_hunt_progress(self, client): + create = await client.post("/api/hunts", json={"name": "Progress Hunt"}) + hunt_id = create.json()["id"] + + # attach one dataset so progress has scope + from tests.conftest import SAMPLE_CSV + import io + files = {"file": ("progress.csv", io.BytesIO(SAMPLE_CSV), "text/csv")} + up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files) + assert up.status_code == 200 + + res = await client.get(f"/api/hunts/{hunt_id}/progress") + assert res.status_code == 200 + body = res.json() + assert body["hunt_id"] == hunt_id + assert "progress_percent" in body + assert "dataset_total" in body + assert "network_status" in body +''' +needle=''' async def test_get_nonexistent_hunt(self, client): + resp = await client.get("/api/hunts/nonexistent-id") + assert resp.status_code == 404 +''' +if needle in t and 'test_hunt_progress' not in t: + t=t.replace(needle, needle+'\n'+insert) +p.write_text(t,encoding='utf-8') +print('updated test_api.py') diff --git a/_edit_test_keywords.py b/_edit_test_keywords.py new file mode 100644 index 0000000..a2ea565 --- /dev/null +++ b/_edit_test_keywords.py @@ -0,0 +1,32 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/tests/test_keywords.py') +t=p.read_text(encoding='utf-8') +add=''' + +@pytest.mark.asyncio +async def test_quick_scan_cache_hit(client: AsyncClient): + """Second quick scan should return cache hit metadata.""" + theme_res = await client.post("/api/keywords/themes", json={"name": "Quick Cache Theme", "color": "#00aa00"}) + tid = theme_res.json()["id"] + await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "chrome.exe"}) + + from tests.conftest import SAMPLE_CSV + import io + files = {"file": ("cache_quick.csv", io.BytesIO(SAMPLE_CSV), "text/csv")} + upload = await client.post("/api/datasets/upload", files=files) + ds_id = upload.json()["id"] + + first = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}") + assert first.status_code == 200 + assert first.json().get("cache_status") in ("miss", "hit") + + second = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}") + assert second.status_code == 200 + body = second.json() + assert body.get("cache_used") is True + assert body.get("cache_status") == "hit" +''' +if 'test_quick_scan_cache_hit' not in t: + t=t + add +p.write_text(t,encoding='utf-8') +print('updated test_keywords.py') diff --git a/_edit_upload.py b/_edit_upload.py new file mode 100644 index 0000000..f9d2af1 --- /dev/null +++ b/_edit_upload.py @@ -0,0 +1,26 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/FileUpload.tsx') +t=p.read_text(encoding='utf-8') +# import useEffect +t=t.replace("import React, { useState, useCallback, useRef } from 'react';","import React, { useState, useCallback, useRef, useEffect } from 'react';") +# import HuntProgress type +t=t.replace("import { datasets, hunts, type UploadResult, type Hunt } from '../api/client';","import { datasets, hunts, type UploadResult, type Hunt, type HuntProgress } from '../api/client';") +# add state +if 'const [huntProgress, setHuntProgress]' not in t: + t=t.replace(" const [huntList, setHuntList] = useState([]);\n const [huntId, setHuntId] = useState('');"," const [huntList, setHuntList] = useState([]);\n const [huntId, setHuntId] = useState('');\n const [huntProgress, setHuntProgress] = useState(null);") +# add polling effect after hunts list effect +marker=" React.useEffect(() => {\n hunts.list(0, 100).then(r => setHuntList(r.hunts)).catch(() => {});\n }, []);\n" +if marker in t and 'setInterval' not in t.split(marker,1)[1][:500]: + add='''\n useEffect(() => {\n let timer: any = null;\n let cancelled = false;\n\n const pull = async () => {\n if (!huntId) {\n if (!cancelled) setHuntProgress(null);\n return;\n }\n try {\n const p = await hunts.progress(huntId);\n if (!cancelled) setHuntProgress(p);\n } catch {\n if (!cancelled) setHuntProgress(null);\n }\n };\n\n pull();\n if (huntId) timer = setInterval(pull, 2000);\n return () => { cancelled = true; if (timer) clearInterval(timer); };\n }, [huntId, jobs.length]);\n''' + t=t.replace(marker, marker+add) + +# insert master progress UI after overall summary +insert_after=''' {overallTotal > 0 && (\n \n \n {overallDone + overallErr} / {overallTotal} files processed\n {overallErr > 0 && ` ({overallErr} failed)`}\n \n \n {overallDone + overallErr === overallTotal && overallTotal > 0 && (\n \n \n \n )}\n \n )}\n''' +add_block='''\n {huntId && huntProgress && (\n \n \n \n Master Processing Progress\n \n \n \n \n {huntProgress.progress_percent.toFixed(1)}%\n \n \n \n \n \n \n \n \n \n \n )}\n''' +if insert_after in t: + t=t.replace(insert_after, insert_after+add_block) +else: + print('warning: summary block not found') + +p.write_text(t,encoding='utf-8') +print('updated FileUpload.tsx') diff --git a/_edit_upload2.py b/_edit_upload2.py new file mode 100644 index 0000000..e1523ac --- /dev/null +++ b/_edit_upload2.py @@ -0,0 +1,42 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/FileUpload.tsx') +t=p.read_text(encoding='utf-8') +marker=''' {/* Per-file progress list */} +''' +add=''' {huntId && huntProgress && ( + + + + Master Processing Progress + + + + + {huntProgress.progress_percent.toFixed(1)}% + + + + + + + + + + + )} + +''' +if marker not in t: + raise SystemExit('marker not found') +t=t.replace(marker, add+marker) +p.write_text(t,encoding='utf-8') +print('inserted master progress block') diff --git a/_enforce_scanner_budget.py b/_enforce_scanner_budget.py new file mode 100644 index 0000000..264b6e8 --- /dev/null +++ b/_enforce_scanner_budget.py @@ -0,0 +1,55 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py') +t=p.read_text(encoding='utf-8') +if 'from app.config import settings' not in t: + t=t.replace('from sqlalchemy.ext.asyncio import AsyncSession\n','from sqlalchemy.ext.asyncio import AsyncSession\n\nfrom app.config import settings\n') + +old=''' import asyncio + + for ds_id, ds_name in ds_map.items(): + last_id = 0 + while True: +''' +new=''' import asyncio + + max_rows = max(0, int(settings.SCANNER_MAX_ROWS_PER_SCAN)) + budget_reached = False + + for ds_id, ds_name in ds_map.items(): + if max_rows and result.rows_scanned >= max_rows: + budget_reached = True + break + + last_id = 0 + while True: + if max_rows and result.rows_scanned >= max_rows: + budget_reached = True + break +''' +if old not in t: + raise SystemExit('scanner loop block not found') +t=t.replace(old,new) + +old2=''' if len(rows) < BATCH_SIZE: + break + +''' +new2=''' if len(rows) < BATCH_SIZE: + break + + if budget_reached: + break + + if budget_reached: + logger.warning( + "AUP scan row budget reached (%d rows). Returning partial results.", + result.rows_scanned, + ) + +''' +if old2 not in t: + raise SystemExit('scanner break block not found') +t=t.replace(old2,new2,1) + +p.write_text(t,encoding='utf-8') +print('added scanner global row budget enforcement') diff --git a/_fix_aup_dep.py b/_fix_aup_dep.py new file mode 100644 index 0000000..f4b034b --- /dev/null +++ b/_fix_aup_dep.py @@ -0,0 +1,12 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx') +t=p.read_text(encoding='utf-8') +old=''' }, [selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]); +''' +new=''' }, [selectedHuntId, selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]); +''' +if old not in t: + raise SystemExit('runScan deps block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('fixed AUPScanner runScan dependency list') diff --git a/_fix_import_datasets.py b/_fix_import_datasets.py new file mode 100644 index 0000000..9a152dc --- /dev/null +++ b/_fix_import_datasets.py @@ -0,0 +1,7 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py') +t=p.read_text(encoding='utf-8') +if 'from app.db.models import ProcessingTask' not in t: + t=t.replace('from app.db import get_db\n', 'from app.db import get_db\nfrom app.db.models import ProcessingTask\n') +p.write_text(t, encoding='utf-8') +print('added ProcessingTask import') diff --git a/_fix_keywords_empty_guard.py b/_fix_keywords_empty_guard.py new file mode 100644 index 0000000..7abd782 --- /dev/null +++ b/_fix_keywords_empty_guard.py @@ -0,0 +1,25 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py') +t=p.read_text(encoding='utf-8') +old=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages: + raise HTTPException(400, "Select at least one dataset or enable additional sources (hunts/annotations/messages)") + +''' +new=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages: + return { + "total_hits": 0, + "hits": [], + "themes_scanned": 0, + "keywords_scanned": 0, + "rows_scanned": 0, + "cache_used": False, + "cache_status": "miss", + "cached_at": None, + } + +''' +if old not in t: + raise SystemExit('scope guard block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('adjusted empty scan guard to return fast empty result (200)') diff --git a/_fix_label_selector_networkmap.py b/_fix_label_selector_networkmap.py new file mode 100644 index 0000000..d7ad011 --- /dev/null +++ b/_fix_label_selector_networkmap.py @@ -0,0 +1,47 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx') +t=p.read_text(encoding='utf-8') + +# Add label selector in toolbar before refresh button +insert_after=""" setSearch(e.target.value)} + sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }} + slotProps={{ + input: { + startAdornment: , + }, + }} + /> +""" +label_ctrl=""" + + Labels + + +""" +if 'label-mode-selector' not in t: + if insert_after not in t: + raise SystemExit('search block not found for label selector insertion') + t=t.replace(insert_after, insert_after+label_ctrl) + +# Fix useCallback dependency for startAnimLoop +old=' }, [canvasSize]);' +new=' }, [canvasSize, labelMode]);' +if old in t: + t=t.replace(old,new,1) + +p.write_text(t,encoding='utf-8') +print('inserted label selector UI and fixed callback dependency') diff --git a/_fix_last_dep_networkmap.py b/_fix_last_dep_networkmap.py new file mode 100644 index 0000000..0a7cd89 --- /dev/null +++ b/_fix_last_dep_networkmap.py @@ -0,0 +1,10 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx') +t=p.read_text(encoding='utf-8') +count=t.count('}, [canvasSize]);') +if count: + t=t.replace('}, [canvasSize]);','}, [canvasSize, labelMode]);') +# In case formatter created spaced variant +t=t.replace('}, [canvasSize ]);','}, [canvasSize, labelMode]);') +p.write_text(t,encoding='utf-8') +print('patched remaining canvasSize callback deps:', count) diff --git a/_harden_aup_scope_ui.py b/_harden_aup_scope_ui.py new file mode 100644 index 0000000..42cc1bd --- /dev/null +++ b/_harden_aup_scope_ui.py @@ -0,0 +1,71 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx') +t=p.read_text(encoding='utf-8') + +# Auto-select first hunt with datasets after load +old=''' const [tRes, hRes] = await Promise.all([ + keywords.listThemes(), + hunts.list(0, 200), + ]); + setThemes(tRes.themes); + setHuntList(hRes.hunts); +''' +new=''' const [tRes, hRes] = await Promise.all([ + keywords.listThemes(), + hunts.list(0, 200), + ]); + setThemes(tRes.themes); + setHuntList(hRes.hunts); + if (!selectedHuntId && hRes.hunts.length > 0) { + const best = hRes.hunts.find(h => h.dataset_count > 0) || hRes.hunts[0]; + setSelectedHuntId(best.id); + } +''' +if old not in t: + raise SystemExit('loadData block not found') +t=t.replace(old,new) + +# Guard runScan +old2=''' const runScan = useCallback(async () => { + setScanning(true); + setScanResult(null); + try { +''' +new2=''' const runScan = useCallback(async () => { + if (!selectedHuntId) { + enqueueSnackbar('Please select a hunt before running AUP scan', { variant: 'warning' }); + return; + } + if (selectedDs.size === 0) { + enqueueSnackbar('No datasets selected for this hunt', { variant: 'warning' }); + return; + } + + setScanning(true); + setScanResult(null); + try { +''' +if old2 not in t: + raise SystemExit('runScan header not found') +t=t.replace(old2,new2) + +# update loadData deps +old3=''' }, [enqueueSnackbar]); +''' +new3=''' }, [enqueueSnackbar, selectedHuntId]); +''' +if old3 not in t: + raise SystemExit('loadData deps not found') +t=t.replace(old3,new3,1) + +# disable button if no hunt or no datasets +old4=''' onClick={runScan} disabled={scanning} +''' +new4=''' onClick={runScan} disabled={scanning || !selectedHuntId || selectedDs.size === 0} +''' +if old4 not in t: + raise SystemExit('scan button props not found') +t=t.replace(old4,new4) + +p.write_text(t,encoding='utf-8') +print('hardened AUPScanner to require explicit hunt/dataset scope') diff --git a/_optimize_keywords_partial_cache.py b/_optimize_keywords_partial_cache.py new file mode 100644 index 0000000..39003ff --- /dev/null +++ b/_optimize_keywords_partial_cache.py @@ -0,0 +1,84 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py') +t=p.read_text(encoding='utf-8') +old=''' if can_use_cache: + themes = await scanner._load_themes(body.theme_ids) + allowed_theme_names = {t.name for t in themes} + keywords_scanned = sum(len(theme.keywords) for theme in themes) + + cached_entries: list[dict] = [] + missing: list[str] = [] + for dataset_id in (body.dataset_ids or []): + entry = keyword_scan_cache.get(dataset_id) + if not entry: + missing.append(dataset_id) + continue + cached_entries.append({"result": entry.result, "built_at": entry.built_at}) + + if not missing and cached_entries: + merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None) + return { + "total_hits": merged["total_hits"], + "hits": merged["hits"], + "themes_scanned": len(themes), + "keywords_scanned": keywords_scanned, + "rows_scanned": merged["rows_scanned"], + "cache_used": True, + "cache_status": "hit", + "cached_at": merged["cached_at"], + } +''' +new=''' if can_use_cache: + themes = await scanner._load_themes(body.theme_ids) + allowed_theme_names = {t.name for t in themes} + keywords_scanned = sum(len(theme.keywords) for theme in themes) + + cached_entries: list[dict] = [] + missing: list[str] = [] + for dataset_id in (body.dataset_ids or []): + entry = keyword_scan_cache.get(dataset_id) + if not entry: + missing.append(dataset_id) + continue + cached_entries.append({"result": entry.result, "built_at": entry.built_at}) + + if not missing and cached_entries: + merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None) + return { + "total_hits": merged["total_hits"], + "hits": merged["hits"], + "themes_scanned": len(themes), + "keywords_scanned": keywords_scanned, + "rows_scanned": merged["rows_scanned"], + "cache_used": True, + "cache_status": "hit", + "cached_at": merged["cached_at"], + } + + if missing: + missing_entries: list[dict] = [] + for dataset_id in missing: + partial = await scanner.scan(dataset_ids=[dataset_id], theme_ids=body.theme_ids) + keyword_scan_cache.put(dataset_id, partial) + missing_entries.append({"result": partial, "built_at": None}) + + merged = _merge_cached_results( + cached_entries + missing_entries, + allowed_theme_names if body.theme_ids else None, + ) + return { + "total_hits": merged["total_hits"], + "hits": merged["hits"], + "themes_scanned": len(themes), + "keywords_scanned": keywords_scanned, + "rows_scanned": merged["rows_scanned"], + "cache_used": len(cached_entries) > 0, + "cache_status": "partial" if cached_entries else "miss", + "cached_at": merged["cached_at"], + } +''' +if old not in t: + raise SystemExit('cache block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('updated keyword /scan to use partial cache + scan missing datasets only') diff --git a/_optimize_scanner_keyset.py b/_optimize_scanner_keyset.py new file mode 100644 index 0000000..bf96cd2 --- /dev/null +++ b/_optimize_scanner_keyset.py @@ -0,0 +1,61 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py') +t=p.read_text(encoding='utf-8') +start=t.index(' async def _scan_datasets(') +end=t.index(' async def _scan_hunts', start) +new_func=''' async def _scan_datasets( + self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None + ) -> None: + """Scan dataset rows in batches using keyset pagination (no OFFSET).""" + ds_q = select(Dataset.id, Dataset.name) + if dataset_ids: + ds_q = ds_q.where(Dataset.id.in_(dataset_ids)) + ds_result = await self.db.execute(ds_q) + ds_map = {r[0]: r[1] for r in ds_result.fetchall()} + + if not ds_map: + return + + import asyncio + + for ds_id, ds_name in ds_map.items(): + last_id = 0 + while True: + rows_result = await self.db.execute( + select(DatasetRow) + .where(DatasetRow.dataset_id == ds_id) + .where(DatasetRow.id > last_id) + .order_by(DatasetRow.id) + .limit(BATCH_SIZE) + ) + rows = rows_result.scalars().all() + if not rows: + break + + for row in rows: + result.rows_scanned += 1 + data = row.data or {} + for col_name, cell_value in data.items(): + if cell_value is None: + continue + text = str(cell_value) + self._match_text( + text, + patterns, + "dataset_row", + row.id, + col_name, + result.hits, + row_index=row.row_index, + dataset_name=ds_name, + ) + + last_id = rows[-1].id + await asyncio.sleep(0) + if len(rows) < BATCH_SIZE: + break + +''' +out=t[:start]+new_func+t[end:] +p.write_text(out,encoding='utf-8') +print('optimized scanner _scan_datasets to keyset pagination') diff --git a/_patch_inventory_stats.py b/_patch_inventory_stats.py new file mode 100644 index 0000000..59028dd --- /dev/null +++ b/_patch_inventory_stats.py @@ -0,0 +1,36 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py') +t=p.read_text(encoding='utf-8') +old=''' return { + "hosts": host_list, + "connections": conn_list, + "stats": { + "total_hosts": len(host_list), + "total_datasets_scanned": len(all_datasets), + "datasets_with_hosts": ds_with_hosts, + "total_rows_scanned": total_rows, + "hosts_with_ips": sum(1 for h in host_list if h['ips']), + "hosts_with_users": sum(1 for h in host_list if h['users']), + }, + } +''' +new=''' return { + "hosts": host_list, + "connections": conn_list, + "stats": { + "total_hosts": len(host_list), + "total_datasets_scanned": len(all_datasets), + "datasets_with_hosts": ds_with_hosts, + "total_rows_scanned": total_rows, + "hosts_with_ips": sum(1 for h in host_list if h['ips']), + "hosts_with_users": sum(1 for h in host_list if h['users']), + "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET, + "sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0, + }, + } +''' +if old not in t: + raise SystemExit('return block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('patched inventory stats metadata') diff --git a/_patch_inventory_stats2.py b/_patch_inventory_stats2.py new file mode 100644 index 0000000..c04e39a --- /dev/null +++ b/_patch_inventory_stats2.py @@ -0,0 +1,10 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py') +t=p.read_text(encoding='utf-8') +needle=' "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n' +if '"row_budget_per_dataset"' not in t: + if needle not in t: + raise SystemExit('needle not found') + t=t.replace(needle, needle + ' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,\n "sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0,\n') +p.write_text(t,encoding='utf-8') +print('inserted inventory budget stats lines') diff --git a/_patch_network_sleep.py b/_patch_network_sleep.py new file mode 100644 index 0000000..b9e89c7 --- /dev/null +++ b/_patch_network_sleep.py @@ -0,0 +1,14 @@ +from pathlib import Path + +p = Path(r"d:\Projects\Dev\ThreatHunt\frontend\src\components\NetworkMap.tsx") +text = p.read_text(encoding="utf-8") + +anchor = " useEffect(() => { canvasSizeRef.current = canvasSize; }, [canvasSize]);\n" +insert = anchor + "\n const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms));\n" +if "const sleep = (ms: number)" not in text and anchor in text: + text = text.replace(anchor, insert) + +text = text.replace("await new Promise(r => setTimeout(r, delayMs + jitter));", "await sleep(delayMs + jitter);") + +p.write_text(text, encoding="utf-8") +print("Patched sleep helper + polling awaits") \ No newline at end of file diff --git a/_patch_network_wait.py b/_patch_network_wait.py new file mode 100644 index 0000000..2c4ed89 --- /dev/null +++ b/_patch_network_wait.py @@ -0,0 +1,37 @@ +from pathlib import Path +import re + +p = Path(r"d:\Projects\Dev\ThreatHunt\frontend\src\components\NetworkMap.tsx") +text = p.read_text(encoding="utf-8") +pattern = re.compile(r"const waitUntilReady = async \(\): Promise => \{[\s\S]*?\n\s*\};", re.M) +replacement = '''const waitUntilReady = async (): Promise => { + // Poll inventory-status with exponential backoff until 'ready' (or cancelled) + setProgress('Host inventory is being prepared in the background'); + setLoading(true); + let delayMs = 1500; + const startedAt = Date.now(); + for (;;) { + const jitter = Math.floor(Math.random() * 250); + await new Promise(r => setTimeout(r, delayMs + jitter)); + if (cancelled) return false; + try { + const st = await network.inventoryStatus(selectedHuntId); + if (cancelled) return false; + if (st.status === 'ready') return true; + if (Date.now() - startedAt > 5 * 60 * 1000) { + setError('Host inventory build timed out. Please retry.'); + return false; + } + delayMs = Math.min(10000, Math.floor(delayMs * 1.5)); + // still building or none (job may not have started yet) - keep polling + } catch { + if (cancelled) return false; + delayMs = Math.min(10000, Math.floor(delayMs * 1.5)); + } + } + };''' +new_text, n = pattern.subn(replacement, text, count=1) +if n != 1: + raise SystemExit(f"Failed to patch waitUntilReady, matches={n}") +p.write_text(new_text, encoding="utf-8") +print("Patched waitUntilReady") \ No newline at end of file diff --git a/_perf_edit_config_inventory.py b/_perf_edit_config_inventory.py new file mode 100644 index 0000000..ecb550a --- /dev/null +++ b/_perf_edit_config_inventory.py @@ -0,0 +1,26 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py') +t=p.read_text(encoding='utf-8') +old=''' NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field( + default=25000, + description="Row budget per dataset when building host inventory (0 = unlimited)", + ) +''' +new=''' NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field( + default=5000, + description="Row budget per dataset when building host inventory (0 = unlimited)", + ) + NETWORK_INVENTORY_MAX_TOTAL_ROWS: int = Field( + default=120000, + description="Global row budget across all datasets for host inventory build (0 = unlimited)", + ) + NETWORK_INVENTORY_MAX_CONNECTIONS: int = Field( + default=120000, + description="Max unique connection tuples retained during host inventory build", + ) +''' +if old not in t: + raise SystemExit('network inventory block not found') +t=t.replace(old,new) +p.write_text(t,encoding='utf-8') +print('updated network inventory budgets in config') diff --git a/_perf_edit_host_inventory_budgets.py b/_perf_edit_host_inventory_budgets.py new file mode 100644 index 0000000..18f1716 --- /dev/null +++ b/_perf_edit_host_inventory_budgets.py @@ -0,0 +1,164 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py') +t=p.read_text(encoding='utf-8') + +# insert budget vars near existing counters +old=''' connections: dict[tuple, int] = defaultdict(int) + total_rows = 0 + ds_with_hosts = 0 +''' +new=''' connections: dict[tuple, int] = defaultdict(int) + total_rows = 0 + ds_with_hosts = 0 + sampled_dataset_count = 0 + total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS)) + max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS)) + global_budget_reached = False + dropped_connections = 0 +''' +if old not in t: + raise SystemExit('counter block not found') +t=t.replace(old,new) + +# update batch size and sampled count increments + global budget checks +old2=''' batch_size = 10000 + max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET)) + rows_scanned_this_dataset = 0 + sampled_dataset = False + last_row_index = -1 + while True: +''' +new2=''' batch_size = 5000 + max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET)) + rows_scanned_this_dataset = 0 + sampled_dataset = False + last_row_index = -1 + while True: + if total_row_budget and total_rows >= total_row_budget: + global_budget_reached = True + break +''' +if old2 not in t: + raise SystemExit('batch block not found') +t=t.replace(old2,new2) + +old3=''' if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset: + sampled_dataset = True + break + + data = ro.data or {} + total_rows += 1 + rows_scanned_this_dataset += 1 +''' +new3=''' if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset: + sampled_dataset = True + break + if total_row_budget and total_rows >= total_row_budget: + sampled_dataset = True + global_budget_reached = True + break + + data = ro.data or {} + total_rows += 1 + rows_scanned_this_dataset += 1 +''' +if old3 not in t: + raise SystemExit('row scan block not found') +t=t.replace(old3,new3) + +# cap connection map growth +old4=''' for c in cols['remote_ip']: + rip = _clean(data.get(c)) + if _is_valid_ip(rip): + rport = '' + for pc in cols['remote_port']: + rport = _clean(data.get(pc)) + if rport: + break + connections[(host_key, rip, rport)] += 1 +''' +new4=''' for c in cols['remote_ip']: + rip = _clean(data.get(c)) + if _is_valid_ip(rip): + rport = '' + for pc in cols['remote_port']: + rport = _clean(data.get(pc)) + if rport: + break + conn_key = (host_key, rip, rport) + if max_connections and len(connections) >= max_connections and conn_key not in connections: + dropped_connections += 1 + continue + connections[conn_key] += 1 +''' +if old4 not in t: + raise SystemExit('connection block not found') +t=t.replace(old4,new4) + +# sampled_dataset counter +old5=''' if sampled_dataset: + logger.info( + "Host inventory row budget reached for dataset %s (%d rows)", + ds.id, + rows_scanned_this_dataset, + ) + break +''' +new5=''' if sampled_dataset: + sampled_dataset_count += 1 + logger.info( + "Host inventory row budget reached for dataset %s (%d rows)", + ds.id, + rows_scanned_this_dataset, + ) + break +''' +if old5 not in t: + raise SystemExit('sampled block not found') +t=t.replace(old5,new5) + +# break dataset loop if global budget reached +old6=''' if len(rows) < batch_size: + break + + # Post-process hosts +''' +new6=''' if len(rows) < batch_size: + break + + if global_budget_reached: + logger.info( + "Host inventory global row budget reached for hunt %s at %d rows", + hunt_id, + total_rows, + ) + break + + # Post-process hosts +''' +if old6 not in t: + raise SystemExit('post-process boundary block not found') +t=t.replace(old6,new6) + +# add stats +old7=''' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET, + "sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0, + }, + } +''' +new7=''' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET, + "row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS, + "connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS, + "sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0, + "sampled_datasets": sampled_dataset_count, + "global_budget_reached": global_budget_reached, + "dropped_connections": dropped_connections, + }, + } +''' +if old7 not in t: + raise SystemExit('stats block not found') +t=t.replace(old7,new7) + +p.write_text(t,encoding='utf-8') +print('updated host inventory with global row and connection budgets') diff --git a/_perf_edit_networkmap_render.py b/_perf_edit_networkmap_render.py new file mode 100644 index 0000000..64eb34e --- /dev/null +++ b/_perf_edit_networkmap_render.py @@ -0,0 +1,39 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx') +t=p.read_text(encoding='utf-8') +repls={ +"const LARGE_HUNT_SUBGRAPH_HOSTS = 350;":"const LARGE_HUNT_SUBGRAPH_HOSTS = 220;", +"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;":"const LARGE_HUNT_SUBGRAPH_EDGES = 1200;", +"const RENDER_SIMPLIFY_NODE_THRESHOLD = 220;":"const RENDER_SIMPLIFY_NODE_THRESHOLD = 120;", +"const RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;":"const RENDER_SIMPLIFY_EDGE_THRESHOLD = 500;", +"const EDGE_DRAW_TARGET = 1000;":"const EDGE_DRAW_TARGET = 600;" +} +for a,b in repls.items(): + if a not in t: + raise SystemExit(f'missing constant: {a}') + t=t.replace(a,b) + +old=''' // Then label hit (so clicking text works too) + for (const n of graph.nodes) { + if (isPointOnNodeLabel(n, wx, wy, vp)) return n; + } +''' +new=''' // Then label hit (so clicking text works too on manageable graph sizes) + if (graph.nodes.length <= 220) { + for (const n of graph.nodes) { + if (isPointOnNodeLabel(n, wx, wy, vp)) return n; + } + } +''' +if old not in t: + raise SystemExit('label hit block not found') +t=t.replace(old,new) + +old2='simulate(g, w / 2, h / 2, 60);' +if t.count(old2) < 2: + raise SystemExit('expected two simulate calls') +t=t.replace(old2,'simulate(g, w / 2, h / 2, 20);',1) +t=t.replace(old2,'simulate(g, w / 2, h / 2, 30);',1) + +p.write_text(t,encoding='utf-8') +print('tightened network map rendering + load limits') diff --git a/_perf_patch_backend.py b/_perf_patch_backend.py new file mode 100644 index 0000000..6dd38fa --- /dev/null +++ b/_perf_patch_backend.py @@ -0,0 +1,107 @@ +from pathlib import Path +# config updates +cfg=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py') +t=cfg.read_text(encoding='utf-8') +anchor=''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field( + default=3000, description="Hard cap for edges returned by network subgraph endpoint" + ) +''' +ins=''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field( + default=3000, description="Hard cap for edges returned by network subgraph endpoint" + ) + NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field( + default=200000, + description="Row budget per dataset when building host inventory (0 = unlimited)", + ) +''' +if 'NETWORK_INVENTORY_MAX_ROWS_PER_DATASET' not in t: + if anchor not in t: + raise SystemExit('config network anchor not found') + t=t.replace(anchor,ins) +cfg.write_text(t,encoding='utf-8') + +# host inventory updates +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py') +t=p.read_text(encoding='utf-8') +if 'from app.config import settings' not in t: + t=t.replace('from app.db.models import Dataset, DatasetRow\n', 'from app.db.models import Dataset, DatasetRow\nfrom app.config import settings\n') + +t=t.replace(' batch_size = 5000\n last_row_index = -1\n while True:\n', ' batch_size = 10000\n max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))\n rows_scanned_this_dataset = 0\n sampled_dataset = False\n last_row_index = -1\n while True:\n') + +old=''' for ro in rows: + data = ro.data or {} + total_rows += 1 + + fqdn = '' +''' +new=''' for ro in rows: + if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset: + sampled_dataset = True + break + + data = ro.data or {} + total_rows += 1 + rows_scanned_this_dataset += 1 + + fqdn = '' +''' +if old not in t: + raise SystemExit('row loop anchor not found') +t=t.replace(old,new) + +old2=''' last_row_index = rows[-1].row_index + if len(rows) < batch_size: + break +''' +new2=''' if sampled_dataset: + logger.info( + "Host inventory row budget reached for dataset %s (%d rows)", + ds.id, + rows_scanned_this_dataset, + ) + break + + last_row_index = rows[-1].row_index + if len(rows) < batch_size: + break +''' +if old2 not in t: + raise SystemExit('batch loop end anchor not found') +t=t.replace(old2,new2) + +old3=''' return { + "hosts": host_list, + "connections": conn_list, + "stats": { + "total_hosts": len(host_list), + "total_datasets_scanned": len(all_datasets), + "datasets_with_hosts": ds_with_hosts, + "total_rows_scanned": total_rows, + "hosts_with_ips": sum(1 for h in host_list if h['ips']), + "hosts_with_users": sum(1 for h in host_list if h['users']), + }, + } +''' +new3=''' sampled = settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 + + return { + "hosts": host_list, + "connections": conn_list, + "stats": { + "total_hosts": len(host_list), + "total_datasets_scanned": len(all_datasets), + "datasets_with_hosts": ds_with_hosts, + "total_rows_scanned": total_rows, + "hosts_with_ips": sum(1 for h in host_list if h['ips']), + "hosts_with_users": sum(1 for h in host_list if h['users']), + "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET, + "sampled_mode": sampled, + }, + } +''' +if old3 not in t: + raise SystemExit('return stats anchor not found') +t=t.replace(old3,new3) + +p.write_text(t,encoding='utf-8') +print('patched config + host inventory row budget') diff --git a/_perf_patch_backend2.py b/_perf_patch_backend2.py new file mode 100644 index 0000000..004f254 --- /dev/null +++ b/_perf_patch_backend2.py @@ -0,0 +1,38 @@ +from pathlib import Path +cfg=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py') +t=cfg.read_text(encoding='utf-8') +if 'NETWORK_INVENTORY_MAX_ROWS_PER_DATASET' not in t: + t=t.replace( +''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field( + default=3000, description="Hard cap for edges returned by network subgraph endpoint" + ) +''', +''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field( + default=3000, description="Hard cap for edges returned by network subgraph endpoint" + ) + NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field( + default=200000, + description="Row budget per dataset when building host inventory (0 = unlimited)", + ) +''') +cfg.write_text(t,encoding='utf-8') + +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py') +t=p.read_text(encoding='utf-8') +if 'from app.config import settings' not in t: + t=t.replace('from app.db.models import Dataset, DatasetRow\n','from app.db.models import Dataset, DatasetRow\nfrom app.config import settings\n') + +t=t.replace(' batch_size = 5000\n last_row_index = -1\n while True:\n', + ' batch_size = 10000\n max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))\n rows_scanned_this_dataset = 0\n sampled_dataset = False\n last_row_index = -1\n while True:\n') + +t=t.replace(' for ro in rows:\n data = ro.data or {}\n total_rows += 1\n\n', + ' for ro in rows:\n if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:\n sampled_dataset = True\n break\n\n data = ro.data or {}\n total_rows += 1\n rows_scanned_this_dataset += 1\n\n') + +t=t.replace(' last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n', + ' if sampled_dataset:\n logger.info(\n "Host inventory row budget reached for dataset %s (%d rows)",\n ds.id,\n rows_scanned_this_dataset,\n )\n break\n\n last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n') + +t=t.replace(' return {\n "hosts": host_list,\n "connections": conn_list,\n "stats": {\n "total_hosts": len(host_list),\n "total_datasets_scanned": len(all_datasets),\n "datasets_with_hosts": ds_with_hosts,\n "total_rows_scanned": total_rows,\n "hosts_with_ips": sum(1 for h in host_list if h[\'ips\']),\n "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n },\n }\n', + ' sampled = settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0\n\n return {\n "hosts": host_list,\n "connections": conn_list,\n "stats": {\n "total_hosts": len(host_list),\n "total_datasets_scanned": len(all_datasets),\n "datasets_with_hosts": ds_with_hosts,\n "total_rows_scanned": total_rows,\n "hosts_with_ips": sum(1 for h in host_list if h[\'ips\']),\n "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,\n "sampled_mode": sampled,\n },\n }\n') + +p.write_text(t,encoding='utf-8') +print('patched backend inventory performance settings') diff --git a/_perf_patch_networkmap.py b/_perf_patch_networkmap.py new file mode 100644 index 0000000..3379eb0 --- /dev/null +++ b/_perf_patch_networkmap.py @@ -0,0 +1,220 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx') +t=p.read_text(encoding='utf-8') + +# constants +if 'RENDER_SIMPLIFY_NODE_THRESHOLD' not in t: + t=t.replace( +"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n", +"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\nconst RENDER_SIMPLIFY_NODE_THRESHOLD = 220;\nconst RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;\nconst EDGE_DRAW_TARGET = 1000;\n") + +# drawBackground signature + t_old='''function drawBackground( + ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number, +) { +''' +if t_old in t: + t=t.replace(t_old, +'''function drawBackground( + ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number, + simplify: boolean, +) { +''') + +# skip grid when simplify +if 'if (!simplify) {' not in t: + t=t.replace( +''' ctx.save(); + ctx.translate(vp.x * dpr, vp.y * dpr); + ctx.scale(vp.scale * dpr, vp.scale * dpr); + const startX = -vp.x / vp.scale - GRID_SPACING; + const startY = -vp.y / vp.scale - GRID_SPACING; + const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2; + const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2; + ctx.fillStyle = GRID_DOT_COLOR; + for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) { + for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) { + ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill(); + } + } + ctx.restore(); +''', +''' if (!simplify) { + ctx.save(); + ctx.translate(vp.x * dpr, vp.y * dpr); + ctx.scale(vp.scale * dpr, vp.scale * dpr); + const startX = -vp.x / vp.scale - GRID_SPACING; + const startY = -vp.y / vp.scale - GRID_SPACING; + const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2; + const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2; + ctx.fillStyle = GRID_DOT_COLOR; + for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) { + for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) { + ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill(); + } + } + ctx.restore(); + } +''') + +# drawEdges signature + t=t.replace('''function drawEdges( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, + nodeMap: Map, animTime: number, +) { + for (const e of graph.edges) { +''', +'''function drawEdges( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, + nodeMap: Map, animTime: number, + simplify: boolean, +) { + const edgeStep = simplify ? Math.max(1, Math.ceil(graph.edges.length / EDGE_DRAW_TARGET)) : 1; + for (let ei = 0; ei < graph.edges.length; ei += edgeStep) { + const e = graph.edges[ei]; +''') + +# simplify edge path + t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);', + 'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }') + +t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);', + 'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }') + +# reduce glow when simplify + t=t.replace(''' ctx.save(); + ctx.shadowColor = 'rgba(96,165,250,0.5)'; ctx.shadowBlur = 8; + ctx.strokeStyle = 'rgba(96,165,250,0.3)'; + ctx.lineWidth = Math.min(5, 2 + e.weight * 0.2); + ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); } + ctx.stroke(); ctx.restore(); +''', +''' if (!simplify) { + ctx.save(); + ctx.shadowColor = 'rgba(96,165,250,0.5)'; ctx.shadowBlur = 8; + ctx.strokeStyle = 'rgba(96,165,250,0.3)'; + ctx.lineWidth = Math.min(5, 2 + e.weight * 0.2); + ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); } + ctx.stroke(); ctx.restore(); + } +''') + +# drawLabels signature and early return + t=t.replace('''function drawLabels( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, + search: string, matchSet: Set, vp: Viewport, +) { +''', +'''function drawLabels( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, + search: string, matchSet: Set, vp: Viewport, + simplify: boolean, +) { +''') + +if 'if (simplify && !search && !hovered && !selected) {' not in t: + t=t.replace(' const dimmed = search.length > 0;\n', + ' const dimmed = search.length > 0;\n if (simplify && !search && !hovered && !selected) {\n return;\n }\n') + +# drawGraph adapt + t=t.replace(''' drawBackground(ctx, w, h, vp, dpr); + ctx.save(); + ctx.translate(vp.x * dpr, vp.y * dpr); + ctx.scale(vp.scale * dpr, vp.scale * dpr); + drawEdges(ctx, graph, hovered, selected, nodeMap, animTime); + drawNodes(ctx, graph, hovered, selected, search, matchSet); + drawLabels(ctx, graph, hovered, selected, search, matchSet, vp); + ctx.restore(); +''', +''' const simplify = graph.nodes.length > RENDER_SIMPLIFY_NODE_THRESHOLD || graph.edges.length > RENDER_SIMPLIFY_EDGE_THRESHOLD; + drawBackground(ctx, w, h, vp, dpr, simplify); + ctx.save(); + ctx.translate(vp.x * dpr, vp.y * dpr); + ctx.scale(vp.scale * dpr, vp.scale * dpr); + drawEdges(ctx, graph, hovered, selected, nodeMap, animTime, simplify); + drawNodes(ctx, graph, hovered, selected, search, matchSet); + drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify); + ctx.restore(); +''') + +# hover RAF ref +if 'const hoverRafRef = useRef(0);' not in t: + t=t.replace(' const graphRef = useRef(null);\n', ' const graphRef = useRef(null);\n const hoverRafRef = useRef(0);\n') + +# throttle hover hit test on mousemove +old_mm=''' const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current); + setHovered(node?.id ?? null); + }, [graph, redraw, startAnimLoop]); +''' +new_mm=''' cancelAnimationFrame(hoverRafRef.current); + const clientX = e.clientX; + const clientY = e.clientY; + hoverRafRef.current = requestAnimationFrame(() => { + const node = hitTest(graph, canvasRef.current as HTMLCanvasElement, clientX, clientY, vpRef.current); + setHovered(prev => (prev === (node?.id ?? null) ? prev : (node?.id ?? null))); + }); + }, [graph, redraw, startAnimLoop]); +''' +if old_mm in t: + t=t.replace(old_mm,new_mm) + +# cleanup hover raf on unmount in existing animation cleanup effect +if 'cancelAnimationFrame(hoverRafRef.current);' not in t: + t=t.replace(''' useEffect(() => { + if (graph) startAnimLoop(); + return () => { cancelAnimationFrame(animFrameRef.current); isAnimatingRef.current = false; }; + }, [graph, startAnimLoop]); +''', +''' useEffect(() => { + if (graph) startAnimLoop(); + return () => { + cancelAnimationFrame(animFrameRef.current); + cancelAnimationFrame(hoverRafRef.current); + isAnimatingRef.current = false; + }; + }, [graph, startAnimLoop]); +''') + +# connectedNodes optimization map +if 'const nodeById = useMemo(() => {' not in t: + t=t.replace(''' const connectionCount = selectedNode && graph + ? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length + : 0; + + const connectedNodes = useMemo(() => { +''', +''' const connectionCount = selectedNode && graph + ? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length + : 0; + + const nodeById = useMemo(() => { + const m = new Map(); + if (!graph) return m; + for (const n of graph.nodes) m.set(n.id, n); + return m; + }, [graph]); + + const connectedNodes = useMemo(() => { +''') + + t=t.replace(''' const n = graph.nodes.find(x => x.id === e.target); + if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight }); + } else if (e.target === selectedNode.id) { + const n = graph.nodes.find(x => x.id === e.source); + if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight }); +''', +''' const n = nodeById.get(e.target); + if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight }); + } else if (e.target === selectedNode.id) { + const n = nodeById.get(e.source); + if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight }); +''') + + t=t.replace(' }, [selectedNode, graph]);\n', ' }, [selectedNode, graph, nodeById]);\n') + +p.write_text(t,encoding='utf-8') +print('patched NetworkMap adaptive render + hover throttle') diff --git a/_perf_patch_networkmap2.py b/_perf_patch_networkmap2.py new file mode 100644 index 0000000..1b46114 --- /dev/null +++ b/_perf_patch_networkmap2.py @@ -0,0 +1,153 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx') +t=p.read_text(encoding='utf-8') + +if 'RENDER_SIMPLIFY_NODE_THRESHOLD' not in t: + t=t.replace('const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n', 'const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\nconst RENDER_SIMPLIFY_NODE_THRESHOLD = 220;\nconst RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;\nconst EDGE_DRAW_TARGET = 1000;\n') + +t=t.replace('function drawBackground(\n ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,\n) {', 'function drawBackground(\n ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,\n simplify: boolean,\n) {') + +t=t.replace(''' ctx.save(); + ctx.translate(vp.x * dpr, vp.y * dpr); + ctx.scale(vp.scale * dpr, vp.scale * dpr); + const startX = -vp.x / vp.scale - GRID_SPACING; + const startY = -vp.y / vp.scale - GRID_SPACING; + const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2; + const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2; + ctx.fillStyle = GRID_DOT_COLOR; + for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) { + for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) { + ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill(); + } + } + ctx.restore(); +''',''' if (!simplify) { + ctx.save(); + ctx.translate(vp.x * dpr, vp.y * dpr); + ctx.scale(vp.scale * dpr, vp.scale * dpr); + const startX = -vp.x / vp.scale - GRID_SPACING; + const startY = -vp.y / vp.scale - GRID_SPACING; + const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2; + const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2; + ctx.fillStyle = GRID_DOT_COLOR; + for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) { + for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) { + ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill(); + } + } + ctx.restore(); + } +''') + +t=t.replace('''function drawEdges( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, + nodeMap: Map, animTime: number, +) { + for (const e of graph.edges) { +''','''function drawEdges( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, + nodeMap: Map, animTime: number, + simplify: boolean, +) { + const edgeStep = simplify ? Math.max(1, Math.ceil(graph.edges.length / EDGE_DRAW_TARGET)) : 1; + for (let ei = 0; ei < graph.edges.length; ei += edgeStep) { + const e = graph.edges[ei]; +''') + +t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);', 'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }') + +t=t.replace('''function drawLabels( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, + search: string, matchSet: Set, vp: Viewport, +) { + const dimmed = search.length > 0; +''','''function drawLabels( + ctx: CanvasRenderingContext2D, graph: Graph, + hovered: string | null, selected: string | null, + search: string, matchSet: Set, vp: Viewport, + simplify: boolean, +) { + const dimmed = search.length > 0; + if (simplify && !search && !hovered && !selected) { + return; + } +''') + +t=t.replace(''' drawBackground(ctx, w, h, vp, dpr); + ctx.save(); + ctx.translate(vp.x * dpr, vp.y * dpr); + ctx.scale(vp.scale * dpr, vp.scale * dpr); + drawEdges(ctx, graph, hovered, selected, nodeMap, animTime); + drawNodes(ctx, graph, hovered, selected, search, matchSet); + drawLabels(ctx, graph, hovered, selected, search, matchSet, vp); + ctx.restore(); +''',''' const simplify = graph.nodes.length > RENDER_SIMPLIFY_NODE_THRESHOLD || graph.edges.length > RENDER_SIMPLIFY_EDGE_THRESHOLD; + drawBackground(ctx, w, h, vp, dpr, simplify); + ctx.save(); + ctx.translate(vp.x * dpr, vp.y * dpr); + ctx.scale(vp.scale * dpr, vp.scale * dpr); + drawEdges(ctx, graph, hovered, selected, nodeMap, animTime, simplify); + drawNodes(ctx, graph, hovered, selected, search, matchSet); + drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify); + ctx.restore(); +''') + +if 'const hoverRafRef = useRef(0);' not in t: + t=t.replace('const graphRef = useRef(null);\n', 'const graphRef = useRef(null);\n const hoverRafRef = useRef(0);\n') + +t=t.replace(''' const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current); + setHovered(node?.id ?? null); + }, [graph, redraw, startAnimLoop]); +''',''' cancelAnimationFrame(hoverRafRef.current); + const clientX = e.clientX; + const clientY = e.clientY; + hoverRafRef.current = requestAnimationFrame(() => { + const node = hitTest(graph, canvasRef.current as HTMLCanvasElement, clientX, clientY, vpRef.current); + setHovered(prev => (prev === (node?.id ?? null) ? prev : (node?.id ?? null))); + }); + }, [graph, redraw, startAnimLoop]); +''') + +t=t.replace(''' useEffect(() => { + if (graph) startAnimLoop(); + return () => { cancelAnimationFrame(animFrameRef.current); isAnimatingRef.current = false; }; + }, [graph, startAnimLoop]); +''',''' useEffect(() => { + if (graph) startAnimLoop(); + return () => { + cancelAnimationFrame(animFrameRef.current); + cancelAnimationFrame(hoverRafRef.current); + isAnimatingRef.current = false; + }; + }, [graph, startAnimLoop]); +''') + +if 'const nodeById = useMemo(() => {' not in t: + t=t.replace(''' const connectionCount = selectedNode && graph + ? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length + : 0; + + const connectedNodes = useMemo(() => { +''',''' const connectionCount = selectedNode && graph + ? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length + : 0; + + const nodeById = useMemo(() => { + const m = new Map(); + if (!graph) return m; + for (const n of graph.nodes) m.set(n.id, n); + return m; + }, [graph]); + + const connectedNodes = useMemo(() => { +''') + +t=t.replace('const n = graph.nodes.find(x => x.id === e.target);','const n = nodeById.get(e.target);') +t=t.replace('const n = graph.nodes.find(x => x.id === e.source);','const n = nodeById.get(e.source);') +t=t.replace(' }, [selectedNode, graph]);',' }, [selectedNode, graph, nodeById]);') + +p.write_text(t,encoding='utf-8') +print('patched NetworkMap performance') diff --git a/_perf_replace_build_host_inventory.py b/_perf_replace_build_host_inventory.py new file mode 100644 index 0000000..015d7ab --- /dev/null +++ b/_perf_replace_build_host_inventory.py @@ -0,0 +1,227 @@ +from pathlib import Path +p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py') +t=p.read_text(encoding='utf-8') +start=t.index('async def build_host_inventory(') +# find end of function by locating '\n\n' before EOF after ' }\n' +end=t.index('\n\n', start) +# need proper end: first double newline after function may occur in docstring? compute by searching for '\n\n' after ' }\n' near end +ret_idx=t.rfind(' }') +# safer locate end as last occurrence of '\n }\n' after start, then function ends next newline +end=t.find('\n\n', ret_idx) +if end==-1: + end=len(t) +new_func='''async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict: + """Build a deduplicated host inventory from all datasets in a hunt. + + Returns dict with 'hosts', 'connections', and 'stats'. + Each host has: id, hostname, fqdn, client_id, ips, os, users, datasets, row_count. + """ + ds_result = await db.execute( + select(Dataset).where(Dataset.hunt_id == hunt_id) + ) + all_datasets = ds_result.scalars().all() + + if not all_datasets: + return {"hosts": [], "connections": [], "stats": { + "total_hosts": 0, "total_datasets_scanned": 0, + "total_rows_scanned": 0, + }} + + hosts: dict[str, dict] = {} # fqdn -> host record + ip_to_host: dict[str, str] = {} # local-ip -> fqdn + connections: dict[tuple, int] = defaultdict(int) + total_rows = 0 + ds_with_hosts = 0 + sampled_dataset_count = 0 + total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS)) + max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS)) + global_budget_reached = False + dropped_connections = 0 + + for ds in all_datasets: + if total_row_budget and total_rows >= total_row_budget: + global_budget_reached = True + break + + cols = _identify_columns(ds) + if not cols['fqdn'] and not cols['host_id']: + continue + ds_with_hosts += 1 + + batch_size = 5000 + max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET)) + rows_scanned_this_dataset = 0 + sampled_dataset = False + last_row_index = -1 + + while True: + if total_row_budget and total_rows >= total_row_budget: + sampled_dataset = True + global_budget_reached = True + break + + rr = await db.execute( + select(DatasetRow) + .where(DatasetRow.dataset_id == ds.id) + .where(DatasetRow.row_index > last_row_index) + .order_by(DatasetRow.row_index) + .limit(batch_size) + ) + rows = rr.scalars().all() + if not rows: + break + + for ro in rows: + if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset: + sampled_dataset = True + break + if total_row_budget and total_rows >= total_row_budget: + sampled_dataset = True + global_budget_reached = True + break + + data = ro.data or {} + total_rows += 1 + rows_scanned_this_dataset += 1 + + fqdn = '' + for c in cols['fqdn']: + fqdn = _clean(data.get(c)) + if fqdn: + break + client_id = '' + for c in cols['host_id']: + client_id = _clean(data.get(c)) + if client_id: + break + + if not fqdn and not client_id: + continue + + host_key = fqdn or client_id + + if host_key not in hosts: + short = fqdn.split('.')[0] if fqdn and '.' in fqdn else fqdn + hosts[host_key] = { + 'id': host_key, + 'hostname': short or client_id, + 'fqdn': fqdn, + 'client_id': client_id, + 'ips': set(), + 'os': '', + 'users': set(), + 'datasets': set(), + 'row_count': 0, + } + + h = hosts[host_key] + h['datasets'].add(ds.name) + h['row_count'] += 1 + if client_id and not h['client_id']: + h['client_id'] = client_id + + for c in cols['username']: + u = _extract_username(_clean(data.get(c))) + if u: + h['users'].add(u) + + for c in cols['local_ip']: + ip = _clean(data.get(c)) + if _is_valid_ip(ip): + h['ips'].add(ip) + ip_to_host[ip] = host_key + + for c in cols['os']: + ov = _clean(data.get(c)) + if ov and not h['os']: + h['os'] = ov + + for c in cols['remote_ip']: + rip = _clean(data.get(c)) + if _is_valid_ip(rip): + rport = '' + for pc in cols['remote_port']: + rport = _clean(data.get(pc)) + if rport: + break + conn_key = (host_key, rip, rport) + if max_connections and len(connections) >= max_connections and conn_key not in connections: + dropped_connections += 1 + continue + connections[conn_key] += 1 + + if sampled_dataset: + sampled_dataset_count += 1 + logger.info( + "Host inventory sampling for dataset %s (%d rows scanned)", + ds.id, + rows_scanned_this_dataset, + ) + break + + last_row_index = rows[-1].row_index + if len(rows) < batch_size: + break + + if global_budget_reached: + logger.info( + "Host inventory global row budget reached for hunt %s at %d rows", + hunt_id, + total_rows, + ) + break + + # Post-process hosts + for h in hosts.values(): + if not h['os'] and h['fqdn']: + h['os'] = _infer_os(h['fqdn']) + h['ips'] = sorted(h['ips']) + h['users'] = sorted(h['users']) + h['datasets'] = sorted(h['datasets']) + + # Build connections, resolving IPs to host keys + conn_list = [] + seen = set() + for (src, dst_ip, dst_port), cnt in connections.items(): + if dst_ip in _IGNORE_IPS: + continue + dst_host = ip_to_host.get(dst_ip, '') + if dst_host == src: + continue + key = tuple(sorted([src, dst_host or dst_ip])) + if key in seen: + continue + seen.add(key) + conn_list.append({ + 'source': src, + 'target': dst_host or dst_ip, + 'target_ip': dst_ip, + 'port': dst_port, + 'count': cnt, + }) + + host_list = sorted(hosts.values(), key=lambda x: x['row_count'], reverse=True) + + return { + "hosts": host_list, + "connections": conn_list, + "stats": { + "total_hosts": len(host_list), + "total_datasets_scanned": len(all_datasets), + "datasets_with_hosts": ds_with_hosts, + "total_rows_scanned": total_rows, + "hosts_with_ips": sum(1 for h in host_list if h['ips']), + "hosts_with_users": sum(1 for h in host_list if h['users']), + "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET, + "row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS, + "connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS, + "sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0, + "sampled_datasets": sampled_dataset_count, + "global_budget_reached": global_budget_reached, + "dropped_connections": dropped_connections, + }, + } +''' +out=t[:start]+new_func+t[end:] +p.write_text(out,encoding='utf-8') +print('replaced build_host_inventory with hard-budget fast mode') diff --git a/backend/alembic/versions/b2c3d4e5f6a7_add_playbooks_saved_searches.py b/backend/alembic/versions/b2c3d4e5f6a7_add_playbooks_saved_searches.py new file mode 100644 index 0000000..c4792df --- /dev/null +++ b/backend/alembic/versions/b2c3d4e5f6a7_add_playbooks_saved_searches.py @@ -0,0 +1,78 @@ +"""add playbooks, playbook_steps, saved_searches tables + +Revision ID: b2c3d4e5f6a7 +Revises: a1b2c3d4e5f6 +Create Date: 2026-02-21 10:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +revision: str = "b2c3d4e5f6a7" +down_revision: Union[str, Sequence[str], None] = "a1b2c3d4e5f6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add display_name to users table + with op.batch_alter_table("users") as batch_op: + batch_op.add_column(sa.Column("display_name", sa.String(128), nullable=True)) + + # Create playbooks table + op.create_table( + "playbooks", + sa.Column("id", sa.String(32), primary_key=True), + sa.Column("name", sa.String(256), nullable=False, index=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True), + sa.Column("is_template", sa.Boolean(), server_default="0"), + sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True), + sa.Column("status", sa.String(20), server_default="active"), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + # Create playbook_steps table + op.create_table( + "playbook_steps", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("playbook_id", sa.String(32), sa.ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False), + sa.Column("order_index", sa.Integer(), nullable=False), + sa.Column("title", sa.String(256), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("step_type", sa.String(32), server_default="manual"), + sa.Column("target_route", sa.String(256), nullable=True), + sa.Column("is_completed", sa.Boolean(), server_default="0"), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("notes", sa.Text(), nullable=True), + ) + op.create_index("ix_playbook_steps_playbook", "playbook_steps", ["playbook_id"]) + + # Create saved_searches table + op.create_table( + "saved_searches", + sa.Column("id", sa.String(32), primary_key=True), + sa.Column("name", sa.String(256), nullable=False, index=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("search_type", sa.String(32), nullable=False), + sa.Column("query_params", sa.JSON(), nullable=False), + sa.Column("threshold", sa.Float(), nullable=True), + sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True), + sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True), + sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("last_result_count", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + op.create_index("ix_saved_searches_type", "saved_searches", ["search_type"]) + + +def downgrade() -> None: + op.drop_table("saved_searches") + op.drop_table("playbook_steps") + op.drop_table("playbooks") + with op.batch_alter_table("users") as batch_op: + batch_op.drop_column("display_name") diff --git a/backend/alembic/versions/c3d4e5f6a7b8_add_processing_tasks_table.py b/backend/alembic/versions/c3d4e5f6a7b8_add_processing_tasks_table.py new file mode 100644 index 0000000..112a53e --- /dev/null +++ b/backend/alembic/versions/c3d4e5f6a7b8_add_processing_tasks_table.py @@ -0,0 +1,48 @@ +"""add processing_tasks table + +Revision ID: c3d4e5f6a7b8 +Revises: b2c3d4e5f6a7 +Create Date: 2026-02-22 00:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +revision: str = "c3d4e5f6a7b8" +down_revision: Union[str, Sequence[str], None] = "b2c3d4e5f6a7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "processing_tasks", + sa.Column("id", sa.String(32), primary_key=True), + sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True), + sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True), + sa.Column("job_id", sa.String(64), nullable=True), + sa.Column("stage", sa.String(64), nullable=False), + sa.Column("status", sa.String(20), nullable=False, server_default="queued"), + sa.Column("progress", sa.Float(), nullable=False, server_default="0.0"), + sa.Column("message", sa.Text(), nullable=True), + sa.Column("error", sa.Text(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + op.create_index("ix_processing_tasks_hunt_stage", "processing_tasks", ["hunt_id", "stage"]) + op.create_index("ix_processing_tasks_dataset_stage", "processing_tasks", ["dataset_id", "stage"]) + op.create_index("ix_processing_tasks_job_id", "processing_tasks", ["job_id"]) + op.create_index("ix_processing_tasks_status", "processing_tasks", ["status"]) + + +def downgrade() -> None: + op.drop_index("ix_processing_tasks_status", table_name="processing_tasks") + op.drop_index("ix_processing_tasks_job_id", table_name="processing_tasks") + op.drop_index("ix_processing_tasks_dataset_stage", table_name="processing_tasks") + op.drop_index("ix_processing_tasks_hunt_stage", table_name="processing_tasks") + op.drop_table("processing_tasks") diff --git a/backend/app/agents/__init__.py b/backend/app/agents/__init__.py index 85f5987..967741b 100644 --- a/backend/app/agents/__init__.py +++ b/backend/app/agents/__init__.py @@ -1,16 +1,18 @@ -"""Analyst-assist agent module for ThreatHunt. +"""Analyst-assist agent module for ThreatHunt. Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses. Agents are advisory only and do not execute actions or modify data. """ -from .core import ThreatHuntAgent -from .providers import LLMProvider, LocalProvider, NetworkedProvider, OnlineProvider +from .core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective +from .providers_v2 import OllamaProvider, OpenWebUIProvider, EmbeddingProvider __all__ = [ "ThreatHuntAgent", - "LLMProvider", - "LocalProvider", - "NetworkedProvider", - "OnlineProvider", + "AgentContext", + "AgentResponse", + "Perspective", + "OllamaProvider", + "OpenWebUIProvider", + "EmbeddingProvider", ] diff --git a/backend/app/agents/core.py b/backend/app/agents/core.py deleted file mode 100644 index abce202..0000000 --- a/backend/app/agents/core.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Core ThreatHunt analyst-assist agent. - -Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses. -Agents are advisory only - no execution, no alerts, no data modifications. -""" - -import logging -from typing import Optional -from pydantic import BaseModel, Field - -from .providers import LLMProvider, get_provider - -logger = logging.getLogger(__name__) - - -class AgentContext(BaseModel): - """Context for agent guidance requests.""" - - query: str = Field( - ..., description="Analyst question or request for guidance" - ) - dataset_name: Optional[str] = Field(None, description="Name of CSV dataset") - artifact_type: Optional[str] = Field(None, description="Artifact type (e.g., file, process, network)") - host_identifier: Optional[str] = Field( - None, description="Host name, IP, or identifier" - ) - data_summary: Optional[str] = Field( - None, description="Brief description of uploaded data" - ) - conversation_history: Optional[list[dict]] = Field( - default_factory=list, description="Previous messages in conversation" - ) - - -class AgentResponse(BaseModel): - """Response from analyst-assist agent.""" - - guidance: str = Field(..., description="Advisory guidance for analyst") - confidence: float = Field( - ..., ge=0.0, le=1.0, description="Confidence in guidance (0-1)" - ) - suggested_pivots: list[str] = Field( - default_factory=list, description="Suggested analytical directions" - ) - suggested_filters: list[str] = Field( - default_factory=list, description="Suggested data filters or queries" - ) - caveats: Optional[str] = Field( - None, description="Assumptions, limitations, or caveats" - ) - reasoning: Optional[str] = Field( - None, description="Explanation of how guidance was generated" - ) - - -class ThreatHuntAgent: - """Analyst-assist agent for ThreatHunt. - - Provides guidance on: - - Interpreting CSV artifact data - - Suggesting analytical pivots and filters - - Forming and testing hypotheses - - Policy: - - Advisory guidance only (no execution) - - No database or schema changes - - No alert escalation - - Transparent reasoning - """ - - def __init__(self, provider: Optional[LLMProvider] = None): - """Initialize agent with LLM provider. - - Args: - provider: LLM provider instance. If None, uses get_provider() with auto mode. - """ - if provider is None: - try: - provider = get_provider("auto") - except RuntimeError as e: - logger.warning(f"Could not initialize default provider: {e}") - provider = None - - self.provider = provider - self.system_prompt = self._build_system_prompt() - - def _build_system_prompt(self) -> str: - """Build the system prompt that governs agent behavior.""" - return """You are an analyst-assist agent for ThreatHunt, a threat hunting platform. - -Your role: -- Interpret and explain CSV artifact data from Velociraptor -- Suggest analytical pivots, filters, and hypotheses -- Highlight anomalies, patterns, or points of interest -- Guide analysts without replacing their judgment - -Your constraints: -- You ONLY provide guidance and suggestions -- You do NOT execute actions or tools -- You do NOT modify data or escalate alerts -- You do NOT make autonomous decisions -- You ONLY analyze data presented to you -- You explain your reasoning transparently -- You acknowledge limitations and assumptions -- You suggest next investigative steps - -When responding: -1. Start with a clear, direct answer to the query -2. Explain your reasoning based on the data context provided -3. Suggest 2-4 analytical pivots the analyst might explore -4. Suggest 2-4 data filters or queries that might be useful -5. Include relevant caveats or assumptions -6. Be honest about what you cannot determine from the data - -Remember: The analyst is the decision-maker. You are an assistant.""" - - async def assist(self, context: AgentContext) -> AgentResponse: - """Provide guidance on artifact data and analysis. - - Args: - context: Request context including query and data context. - - Returns: - Guidance response with suggestions and reasoning. - - Raises: - RuntimeError: If no provider is available. - """ - if not self.provider: - raise RuntimeError( - "No LLM provider available. Configure at least one of: " - "THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, " - "or THREAT_HUNT_ONLINE_API_KEY" - ) - - # Build prompt with context - prompt = self._build_prompt(context) - - try: - # Get guidance from LLM provider - guidance = await self.provider.generate(prompt, max_tokens=1024) - - # Parse response into structured format - response = self._parse_response(guidance, context) - - logger.info( - f"Agent assisted with query: {context.query[:50]}... " - f"(dataset: {context.dataset_name})" - ) - - return response - - except Exception as e: - logger.error(f"Error generating guidance: {e}") - raise - - def _build_prompt(self, context: AgentContext) -> str: - """Build the prompt for the LLM.""" - prompt_parts = [ - f"Analyst query: {context.query}", - ] - - if context.dataset_name: - prompt_parts.append(f"Dataset: {context.dataset_name}") - - if context.artifact_type: - prompt_parts.append(f"Artifact type: {context.artifact_type}") - - if context.host_identifier: - prompt_parts.append(f"Host: {context.host_identifier}") - - if context.data_summary: - prompt_parts.append(f"Data summary: {context.data_summary}") - - if context.conversation_history: - prompt_parts.append("\nConversation history:") - for msg in context.conversation_history[-5:]: # Last 5 messages for context - prompt_parts.append(f" {msg.get('role', 'unknown')}: {msg.get('content', '')}") - - return "\n".join(prompt_parts) - - def _parse_response(self, response_text: str, context: AgentContext) -> AgentResponse: - """Parse LLM response into structured format. - - Note: This is a simplified parser. In production, use structured output - from the LLM (JSON mode, function calling, etc.) for better reliability. - """ - # For now, return a structured response based on the raw guidance - # In production, parse JSON or use structured output from LLM - return AgentResponse( - guidance=response_text, - confidence=0.8, # Placeholder - suggested_pivots=[ - "Analyze temporal patterns", - "Cross-reference with known indicators", - "Examine outliers in the dataset", - "Compare with baseline behavior", - ], - suggested_filters=[ - "Filter by high-risk indicators", - "Sort by timestamp for timeline analysis", - "Group by host or user", - "Filter by anomaly score", - ], - caveats="Guidance is based on available data context. " - "Analysts should verify findings with additional sources.", - reasoning="Analysis generated based on artifact data patterns and analyst query.", - ) diff --git a/backend/app/agents/providers.py b/backend/app/agents/providers.py deleted file mode 100644 index 16afea0..0000000 --- a/backend/app/agents/providers.py +++ /dev/null @@ -1,190 +0,0 @@ -"""Pluggable LLM provider interface for analyst-assist agents. - -Supports three provider types: -- Local: On-device or on-prem models -- Networked: Shared internal inference services -- Online: External hosted APIs -""" - -import os -from abc import ABC, abstractmethod -from typing import Optional - - -class LLMProvider(ABC): - """Abstract base class for LLM providers.""" - - @abstractmethod - async def generate(self, prompt: str, max_tokens: int = 1024) -> str: - """Generate a response from the LLM. - - Args: - prompt: The input prompt - max_tokens: Maximum tokens in response - - Returns: - Generated text response - """ - pass - - @abstractmethod - def is_available(self) -> bool: - """Check if provider backend is available.""" - pass - - -class LocalProvider(LLMProvider): - """Local LLM provider (on-device or on-prem models).""" - - def __init__(self, model_path: Optional[str] = None): - """Initialize local provider. - - Args: - model_path: Path to local model. If None, uses THREAT_HUNT_LOCAL_MODEL_PATH env var. - """ - self.model_path = model_path or os.getenv("THREAT_HUNT_LOCAL_MODEL_PATH") - self.model = None - - def is_available(self) -> bool: - """Check if local model is available.""" - if not self.model_path: - return False - # In production, would verify model file exists and can be loaded - return os.path.exists(str(self.model_path)) - - async def generate(self, prompt: str, max_tokens: int = 1024) -> str: - """Generate response using local model. - - Note: This is a placeholder. In production, integrate with: - - llama-cpp-python for GGML models - - Ollama API - - vLLM - - Other local inference engines - """ - if not self.is_available(): - raise RuntimeError("Local model not available") - - # Placeholder implementation - return f"[Local model response to: {prompt[:50]}...]" - - -class NetworkedProvider(LLMProvider): - """Networked LLM provider (shared internal inference services).""" - - def __init__( - self, - api_endpoint: Optional[str] = None, - api_key: Optional[str] = None, - model_name: str = "default", - ): - """Initialize networked provider. - - Args: - api_endpoint: URL to inference service. Defaults to env var THREAT_HUNT_NETWORKED_ENDPOINT. - api_key: API key for service. Defaults to env var THREAT_HUNT_NETWORKED_KEY. - model_name: Model name/ID on the service. - """ - self.api_endpoint = api_endpoint or os.getenv("THREAT_HUNT_NETWORKED_ENDPOINT") - self.api_key = api_key or os.getenv("THREAT_HUNT_NETWORKED_KEY") - self.model_name = model_name - - def is_available(self) -> bool: - """Check if networked service is available.""" - return bool(self.api_endpoint) - - async def generate(self, prompt: str, max_tokens: int = 1024) -> str: - """Generate response using networked service. - - Note: This is a placeholder. In production, integrate with: - - Internal inference service API - - LLM inference container cluster - - Enterprise inference gateway - """ - if not self.is_available(): - raise RuntimeError("Networked service not available") - - # Placeholder implementation - return f"[Networked response from {self.model_name}: {prompt[:50]}...]" - - -class OnlineProvider(LLMProvider): - """Online LLM provider (external hosted APIs).""" - - def __init__( - self, - api_provider: str = "openai", - api_key: Optional[str] = None, - model_name: Optional[str] = None, - ): - """Initialize online provider. - - Args: - api_provider: Provider name (openai, anthropic, google, etc.) - api_key: API key. Defaults to env var THREAT_HUNT_ONLINE_API_KEY. - model_name: Model name. Defaults to env var THREAT_HUNT_ONLINE_MODEL. - """ - self.api_provider = api_provider - self.api_key = api_key or os.getenv("THREAT_HUNT_ONLINE_API_KEY") - self.model_name = model_name or os.getenv( - "THREAT_HUNT_ONLINE_MODEL", f"{api_provider}-default" - ) - - def is_available(self) -> bool: - """Check if online API is available.""" - return bool(self.api_key) - - async def generate(self, prompt: str, max_tokens: int = 1024) -> str: - """Generate response using online API. - - Note: This is a placeholder. In production, integrate with: - - OpenAI API (GPT-3.5, GPT-4, etc.) - - Anthropic Claude API - - Google Gemini API - - Other hosted LLM services - """ - if not self.is_available(): - raise RuntimeError("Online API not available or API key not set") - - # Placeholder implementation - return f"[Online {self.api_provider} response: {prompt[:50]}...]" - - -def get_provider(provider_type: str = "auto") -> LLMProvider: - """Get an LLM provider based on configuration. - - Args: - provider_type: Type of provider to use: 'local', 'networked', 'online', or 'auto'. - 'auto' attempts to use the first available provider in order: - local -> networked -> online. - - Returns: - Configured LLM provider instance. - - Raises: - RuntimeError: If no provider is available. - """ - # Explicit provider selection - if provider_type == "local": - provider = LocalProvider() - elif provider_type == "networked": - provider = NetworkedProvider() - elif provider_type == "online": - provider = OnlineProvider() - elif provider_type == "auto": - # Try providers in order of preference - for Provider in [LocalProvider, NetworkedProvider, OnlineProvider]: - provider = Provider() - if provider.is_available(): - return provider - raise RuntimeError( - "No LLM provider available. Configure at least one of: " - "THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, " - "or THREAT_HUNT_ONLINE_API_KEY" - ) - else: - raise ValueError(f"Unknown provider type: {provider_type}") - - if not provider.is_available(): - raise RuntimeError(f"{provider_type} provider not available") - - return provider diff --git a/backend/app/api/routes/agent.py b/backend/app/api/routes/agent.py deleted file mode 100644 index e58b2e9..0000000 --- a/backend/app/api/routes/agent.py +++ /dev/null @@ -1,170 +0,0 @@ -"""API routes for analyst-assist agent.""" - -import logging -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel, Field - -from app.agents.core import ThreatHuntAgent, AgentContext, AgentResponse -from app.agents.config import AgentConfig - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/api/agent", tags=["agent"]) - -# Global agent instance (lazy-loaded) -_agent: ThreatHuntAgent | None = None - - -def get_agent() -> ThreatHuntAgent: - """Get or create the agent instance.""" - global _agent - if _agent is None: - if not AgentConfig.is_agent_enabled(): - raise HTTPException( - status_code=503, - detail="Analyst-assist agent is not configured. " - "Please configure an LLM provider.", - ) - _agent = ThreatHuntAgent() - return _agent - - -class AssistRequest(BaseModel): - """Request for agent assistance.""" - - query: str = Field( - ..., description="Analyst question or request for guidance" - ) - dataset_name: str | None = Field( - None, description="Name of CSV dataset being analyzed" - ) - artifact_type: str | None = Field( - None, description="Type of artifact (e.g., FileList, ProcessList, NetworkConnections)" - ) - host_identifier: str | None = Field( - None, description="Host name, IP address, or identifier" - ) - data_summary: str | None = Field( - None, description="Brief summary or context about the uploaded data" - ) - conversation_history: list[dict] | None = Field( - None, description="Previous messages for context" - ) - - -class AssistResponse(BaseModel): - """Response with agent guidance.""" - - guidance: str - confidence: float - suggested_pivots: list[str] - suggested_filters: list[str] - caveats: str | None = None - reasoning: str | None = None - - -@router.post( - "/assist", - response_model=AssistResponse, - summary="Get analyst-assist guidance", - description="Request guidance on CSV artifact data, analytical pivots, and hypotheses. " - "Agent provides advisory guidance only - no execution.", -) -async def agent_assist(request: AssistRequest) -> AssistResponse: - """Provide analyst-assist guidance on artifact data. - - The agent will: - - Explain and interpret the provided data context - - Suggest analytical pivots the analyst might explore - - Suggest data filters or queries that might be useful - - Highlight assumptions, limitations, and caveats - - The agent will NOT: - - Execute any tools or actions - - Escalate findings to alerts - - Modify any data or schema - - Make autonomous decisions - - Args: - request: Assistance request with query and context - - Returns: - Guidance response with suggestions and reasoning - - Raises: - HTTPException: If agent is not configured (503) or request fails - """ - try: - agent = get_agent() - - # Build context - context = AgentContext( - query=request.query, - dataset_name=request.dataset_name, - artifact_type=request.artifact_type, - host_identifier=request.host_identifier, - data_summary=request.data_summary, - conversation_history=request.conversation_history or [], - ) - - # Get guidance - response = await agent.assist(context) - - logger.info( - f"Agent assisted analyst with query: {request.query[:50]}... " - f"(host: {request.host_identifier}, artifact: {request.artifact_type})" - ) - - return AssistResponse( - guidance=response.guidance, - confidence=response.confidence, - suggested_pivots=response.suggested_pivots, - suggested_filters=response.suggested_filters, - caveats=response.caveats, - reasoning=response.reasoning, - ) - - except RuntimeError as e: - logger.error(f"Agent error: {e}") - raise HTTPException( - status_code=503, - detail=f"Agent unavailable: {str(e)}", - ) - except Exception as e: - logger.exception(f"Unexpected error in agent_assist: {e}") - raise HTTPException( - status_code=500, - detail="Error generating guidance. Please try again.", - ) - - -@router.get( - "/health", - summary="Check agent health", - description="Check if agent is configured and ready to assist.", -) -async def agent_health() -> dict: - """Check agent availability and configuration. - - Returns: - Health status with configuration details - """ - try: - agent = get_agent() - provider_type = agent.provider.__class__.__name__ if agent.provider else "None" - return { - "status": "healthy", - "provider": provider_type, - "max_tokens": AgentConfig.MAX_RESPONSE_TOKENS, - "reasoning_enabled": AgentConfig.ENABLE_REASONING, - } - except HTTPException: - return { - "status": "unavailable", - "reason": "No LLM provider configured", - "configured_providers": { - "local": bool(AgentConfig.LOCAL_MODEL_PATH), - "networked": bool(AgentConfig.NETWORKED_ENDPOINT), - "online": bool(AgentConfig.ONLINE_API_KEY), - }, - } diff --git a/backend/app/api/routes/agent_v2.py b/backend/app/api/routes/agent_v2.py index 8035433..0eb5cac 100644 --- a/backend/app/api/routes/agent_v2.py +++ b/backend/app/api/routes/agent_v2.py @@ -1,4 +1,4 @@ -"""API routes for analyst-assist agent — v2. +"""API routes for analyst-assist agent v2. Supports quick, deep, and debate modes with streaming. Conversations are persisted to the database. @@ -6,19 +6,25 @@ Conversations are persisted to the database. import json import logging +import re +import time +from collections import Counter +from urllib.parse import urlparse from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.config import settings from app.db import get_db -from app.db.models import Conversation, Message +from app.db.models import Conversation, Message, Dataset, KeywordTheme from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective from app.agents.providers_v2 import check_all_nodes from app.agents.registry import registry from app.services.sans_rag import sans_rag +from app.services.scanner import KeywordScanner logger = logging.getLogger(__name__) @@ -35,7 +41,7 @@ def get_agent() -> ThreatHuntAgent: return _agent -# ── Request / Response models ───────────────────────────────────────── +# Request / Response models class AssistRequest(BaseModel): @@ -52,6 +58,8 @@ class AssistRequest(BaseModel): model_override: str | None = None conversation_id: str | None = Field(None, description="Persist messages to this conversation") hunt_id: str | None = None + execution_preference: str = Field(default="auto", description="auto | force | off") + learning_mode: bool = False class AssistResponseModel(BaseModel): @@ -66,10 +74,170 @@ class AssistResponseModel(BaseModel): node_used: str = "" latency_ms: int = 0 perspectives: list[dict] | None = None + execution: dict | None = None conversation_id: str | None = None -# ── Routes ──────────────────────────────────────────────────────────── +POLICY_THEME_NAMES = {"Adult Content", "Gambling", "Downloads / Piracy"} +POLICY_QUERY_TERMS = { + "policy", "violating", "violation", "browser history", "web history", + "domain", "domains", "adult", "gambling", "piracy", "aup", +} +WEB_DATASET_HINTS = { + "web", "history", "browser", "url", "visited_url", "domain", "title", +} + + +def _is_policy_domain_query(query: str) -> bool: + q = (query or "").lower() + if not q: + return False + score = sum(1 for t in POLICY_QUERY_TERMS if t in q) + return score >= 2 and ("domain" in q or "history" in q or "policy" in q) + +def _should_execute_policy_scan(request: AssistRequest) -> bool: + pref = (request.execution_preference or "auto").strip().lower() + if pref == "off": + return False + if pref == "force": + return True + return _is_policy_domain_query(request.query) + + +def _extract_domain(value: str | None) -> str | None: + if not value: + return None + text = value.strip() + if not text: + return None + + try: + parsed = urlparse(text) + if parsed.netloc: + return parsed.netloc.lower() + except Exception: + pass + + m = re.search(r"([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}", text) + return m.group(0).lower() if m else None + + +def _dataset_score(ds: Dataset) -> int: + score = 0 + name = (ds.name or "").lower() + cols_l = {c.lower() for c in (ds.column_schema or {}).keys()} + norm_vals_l = {str(v).lower() for v in (ds.normalized_columns or {}).values()} + + for h in WEB_DATASET_HINTS: + if h in name: + score += 2 + if h in cols_l: + score += 3 + if h in norm_vals_l: + score += 3 + + if "visited_url" in cols_l or "url" in cols_l: + score += 8 + if "user" in cols_l or "username" in cols_l: + score += 2 + if "clientid" in cols_l or "fqdn" in cols_l: + score += 2 + if (ds.row_count or 0) > 0: + score += 1 + + return score + + +async def _run_policy_domain_execution(request: AssistRequest, db: AsyncSession) -> dict: + scanner = KeywordScanner(db) + + theme_result = await db.execute( + select(KeywordTheme).where( + KeywordTheme.enabled == True, # noqa: E712 + KeywordTheme.name.in_(list(POLICY_THEME_NAMES)), + ) + ) + themes = list(theme_result.scalars().all()) + theme_ids = [t.id for t in themes] + theme_names = [t.name for t in themes] or sorted(POLICY_THEME_NAMES) + + ds_query = select(Dataset).where(Dataset.processing_status.in_(["completed", "ready", "processing"])) + if request.hunt_id: + ds_query = ds_query.where(Dataset.hunt_id == request.hunt_id) + ds_result = await db.execute(ds_query) + candidates = list(ds_result.scalars().all()) + + if request.dataset_name: + needle = request.dataset_name.lower().strip() + candidates = [d for d in candidates if needle in (d.name or "").lower()] + + scored = sorted( + ((d, _dataset_score(d)) for d in candidates), + key=lambda x: x[1], + reverse=True, + ) + selected = [d for d, s in scored if s > 0][:8] + dataset_ids = [d.id for d in selected] + + if not dataset_ids: + return { + "mode": "policy_scan", + "themes": theme_names, + "datasets_scanned": 0, + "dataset_names": [], + "total_hits": 0, + "policy_hits": 0, + "top_user_hosts": [], + "top_domains": [], + "sample_hits": [], + "note": "No suitable browser/web-history datasets found in current scope.", + } + + result = await scanner.scan( + dataset_ids=dataset_ids, + theme_ids=theme_ids or None, + scan_hunts=False, + scan_annotations=False, + scan_messages=False, + ) + hits = result.get("hits", []) + + user_host_counter = Counter() + domain_counter = Counter() + + for h in hits: + user = h.get("username") or "(unknown-user)" + host = h.get("hostname") or "(unknown-host)" + user_host_counter[f"{user}|{host}"] += 1 + + dom = _extract_domain(h.get("matched_value")) + if dom: + domain_counter[dom] += 1 + + top_user_hosts = [ + {"user_host": k, "count": v} + for k, v in user_host_counter.most_common(10) + ] + top_domains = [ + {"domain": k, "count": v} + for k, v in domain_counter.most_common(10) + ] + + return { + "mode": "policy_scan", + "themes": theme_names, + "datasets_scanned": len(dataset_ids), + "dataset_names": [d.name for d in selected], + "total_hits": int(result.get("total_hits", 0)), + "policy_hits": int(result.get("total_hits", 0)), + "rows_scanned": int(result.get("rows_scanned", 0)), + "top_user_hosts": top_user_hosts, + "top_domains": top_domains, + "sample_hits": hits[:20], + } + + +# Routes @router.post( @@ -84,6 +252,76 @@ async def agent_assist( db: AsyncSession = Depends(get_db), ) -> AssistResponseModel: try: + # Deterministic execution mode for policy-domain investigations. + if _should_execute_policy_scan(request): + t0 = time.monotonic() + exec_payload = await _run_policy_domain_execution(request, db) + latency_ms = int((time.monotonic() - t0) * 1000) + + policy_hits = exec_payload.get("policy_hits", 0) + datasets_scanned = exec_payload.get("datasets_scanned", 0) + + if policy_hits > 0: + guidance = ( + f"Policy-violation scan complete: {policy_hits} hits across " + f"{datasets_scanned} dataset(s). Top user/host pairs and domains are included " + f"in execution results for triage." + ) + confidence = 0.95 + caveats = "Keyword-based matching can include false positives; validate with full URL context." + else: + guidance = ( + f"No policy-violation hits found in current scope " + f"({datasets_scanned} dataset(s) scanned)." + ) + confidence = 0.9 + caveats = exec_payload.get("note") or "Try expanding scope to additional hunts/datasets." + + response = AssistResponseModel( + guidance=guidance, + confidence=confidence, + suggested_pivots=["username", "hostname", "domain", "dataset_name"], + suggested_filters=[ + "theme_name in ['Adult Content','Gambling','Downloads / Piracy']", + "username != null", + "hostname != null", + ], + caveats=caveats, + reasoning=( + "Intent matched policy-domain investigation; executed local keyword scan pipeline." + if _is_policy_domain_query(request.query) + else "Execution mode was forced by user preference; ran policy-domain scan pipeline." + ), + sans_references=["SANS FOR508", "SANS SEC504"], + model_used="execution:keyword_scanner", + node_used="local", + latency_ms=latency_ms, + execution=exec_payload, + ) + + conv_id = request.conversation_id + if conv_id or request.hunt_id: + conv_id = await _persist_conversation( + db, + conv_id, + request, + AgentResponse( + guidance=response.guidance, + confidence=response.confidence, + suggested_pivots=response.suggested_pivots, + suggested_filters=response.suggested_filters, + caveats=response.caveats, + reasoning=response.reasoning, + sans_references=response.sans_references, + model_used=response.model_used, + node_used=response.node_used, + latency_ms=response.latency_ms, + ), + ) + response.conversation_id = conv_id + + return response + agent = get_agent() context = AgentContext( query=request.query, @@ -97,6 +335,7 @@ async def agent_assist( enrichment_summary=request.enrichment_summary, mode=request.mode, model_override=request.model_override, + learning_mode=request.learning_mode, ) response = await agent.assist(context) @@ -129,6 +368,7 @@ async def agent_assist( } for p in response.perspectives ] if response.perspectives else None, + execution=None, conversation_id=conv_id, ) @@ -208,7 +448,7 @@ async def list_models(): } -# ── Conversation persistence ────────────────────────────────────────── +# Conversation persistence async def _persist_conversation( @@ -263,3 +503,4 @@ async def _persist_conversation( await db.flush() return conv.id + diff --git a/backend/app/api/routes/analysis.py b/backend/app/api/routes/analysis.py index ff6f7fc..03d3178 100644 --- a/backend/app/api/routes/analysis.py +++ b/backend/app/api/routes/analysis.py @@ -381,6 +381,10 @@ async def submit_job( detail=f"Invalid job_type: {job_type}. Valid: {[t.value for t in JobType]}", ) + if not job_queue.can_accept(): + raise HTTPException(status_code=429, detail="Job queue is busy. Retry shortly.") + if not job_queue.can_accept(): + raise HTTPException(status_code=429, detail="Job queue is busy. Retry shortly.") job = job_queue.submit(jt, **params) return {"job_id": job.id, "status": job.status.value, "job_type": job_type} diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py index 2537cb4..076b4a0 100644 --- a/backend/app/api/routes/auth.py +++ b/backend/app/api/routes/auth.py @@ -1,4 +1,4 @@ -"""API routes for authentication — register, login, refresh, profile.""" +"""API routes for authentication — register, login, refresh, profile.""" import logging @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/auth", tags=["auth"]) -# ── Request / Response models ───────────────────────────────────────── +# ── Request / Response models ───────────────────────────────────────── class RegisterRequest(BaseModel): @@ -57,7 +57,7 @@ class AuthResponse(BaseModel): tokens: TokenPair -# ── Routes ──────────────────────────────────────────────────────────── +# ── Routes ──────────────────────────────────────────────────────────── @router.post( @@ -86,7 +86,7 @@ async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)): user = User( username=body.username, email=body.email, - password_hash=hash_password(body.password), + hashed_password=hash_password(body.password), display_name=body.display_name or body.username, role="analyst", # Default role ) @@ -120,13 +120,13 @@ async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)): result = await db.execute(select(User).where(User.username == body.username)) user = result.scalar_one_or_none() - if not user or not user.password_hash: + if not user or not user.hashed_password: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password", ) - if not verify_password(body.password, user.password_hash): + if not verify_password(body.password, user.hashed_password): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password", @@ -165,7 +165,7 @@ async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db) if token_data.type != "refresh": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token type — use refresh token", + detail="Invalid token type — use refresh token", ) result = await db.execute(select(User).where(User.id == token_data.sub)) @@ -195,3 +195,4 @@ async def get_profile(user: User = Depends(get_current_user)): is_active=user.is_active, created_at=user.created_at.isoformat() if hasattr(user.created_at, 'isoformat') else str(user.created_at), ) + diff --git a/backend/app/api/routes/datasets.py b/backend/app/api/routes/datasets.py index 154d6de..899a484 100644 --- a/backend/app/api/routes/datasets.py +++ b/backend/app/api/routes/datasets.py @@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.config import settings from app.db import get_db +from app.db.models import ProcessingTask from app.db.repositories.datasets import DatasetRepository from app.services.csv_parser import parse_csv_bytes, infer_column_types from app.services.normalizer import ( @@ -18,15 +19,20 @@ from app.services.normalizer import ( detect_ioc_columns, detect_time_range, ) +from app.services.artifact_classifier import classify_artifact, get_artifact_category logger = logging.getLogger(__name__) +from app.services.job_queue import job_queue, JobType +from app.services.host_inventory import inventory_cache +from app.services.scanner import keyword_scan_cache + router = APIRouter(prefix="/api/datasets", tags=["datasets"]) ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"} -# ── Response models ─────────────────────────────────────────────────── +# -- Response models -- class DatasetSummary(BaseModel): @@ -43,6 +49,8 @@ class DatasetSummary(BaseModel): delimiter: str | None = None time_range_start: str | None = None time_range_end: str | None = None + artifact_type: str | None = None + processing_status: str | None = None hunt_id: str | None = None created_at: str @@ -67,10 +75,13 @@ class UploadResponse(BaseModel): column_types: dict normalized_columns: dict ioc_columns: dict + artifact_type: str | None = None + processing_status: str + jobs_queued: list[str] message: str -# ── Routes ──────────────────────────────────────────────────────────── +# -- Routes -- @router.post( @@ -78,7 +89,7 @@ class UploadResponse(BaseModel): response_model=UploadResponse, summary="Upload a CSV dataset", description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, " - "IOCs auto-detected, and rows stored in the database.", + "IOCs auto-detected, artifact type classified, and all processing jobs queued automatically.", ) async def upload_dataset( file: UploadFile = File(...), @@ -87,7 +98,7 @@ async def upload_dataset( hunt_id: str | None = Query(None, description="Hunt ID to associate with"), db: AsyncSession = Depends(get_db), ): - """Upload and parse a CSV dataset.""" + """Upload and parse a CSV dataset, then trigger full processing pipeline.""" # Validate file if not file.filename: raise HTTPException(status_code=400, detail="No filename provided") @@ -136,7 +147,12 @@ async def upload_dataset( # Detect time range time_start, time_end = detect_time_range(rows, column_mapping) - # Store in DB + # Classify artifact type from column headers + artifact_type = classify_artifact(columns) + artifact_category = get_artifact_category(artifact_type) + logger.info(f"Artifact classification: {artifact_type} (category: {artifact_category})") + + # Store in DB with processing_status = "processing" repo = DatasetRepository(db) dataset = await repo.create_dataset( name=name or Path(file.filename).stem, @@ -152,6 +168,8 @@ async def upload_dataset( time_range_start=time_start, time_range_end=time_end, hunt_id=hunt_id, + artifact_type=artifact_type, + processing_status="processing", ) await repo.bulk_insert_rows( @@ -162,9 +180,88 @@ async def upload_dataset( logger.info( f"Uploaded dataset '{dataset.name}': {len(rows)} rows, " - f"{len(columns)} columns, {len(ioc_columns)} IOC columns detected" + f"{len(columns)} columns, {len(ioc_columns)} IOC columns, " + f"artifact={artifact_type}" ) + # -- Queue full processing pipeline -- + jobs_queued = [] + + task_rows: list[ProcessingTask] = [] + + # 1. AI Triage (chains to HOST_PROFILE automatically on completion) + triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id) + jobs_queued.append("triage") + task_rows.append(ProcessingTask( + hunt_id=hunt_id, + dataset_id=dataset.id, + job_id=triage_job.id, + stage="triage", + status="queued", + progress=0.0, + message="Queued", + )) + + # 2. Anomaly detection (embedding-based outlier detection) + anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id) + jobs_queued.append("anomaly") + task_rows.append(ProcessingTask( + hunt_id=hunt_id, + dataset_id=dataset.id, + job_id=anomaly_job.id, + stage="anomaly", + status="queued", + progress=0.0, + message="Queued", + )) + + # 3. AUP keyword scan + kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id) + jobs_queued.append("keyword_scan") + task_rows.append(ProcessingTask( + hunt_id=hunt_id, + dataset_id=dataset.id, + job_id=kw_job.id, + stage="keyword_scan", + status="queued", + progress=0.0, + message="Queued", + )) + + # 4. IOC extraction + ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id) + jobs_queued.append("ioc_extract") + task_rows.append(ProcessingTask( + hunt_id=hunt_id, + dataset_id=dataset.id, + job_id=ioc_job.id, + stage="ioc_extract", + status="queued", + progress=0.0, + message="Queued", + )) + + # 5. Host inventory (network map) - requires hunt_id + if hunt_id: + inventory_cache.invalidate(hunt_id) + inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id) + jobs_queued.append("host_inventory") + task_rows.append(ProcessingTask( + hunt_id=hunt_id, + dataset_id=dataset.id, + job_id=inv_job.id, + stage="host_inventory", + status="queued", + progress=0.0, + message="Queued", + )) + + if task_rows: + db.add_all(task_rows) + await db.flush() + + logger.info(f"Queued {len(jobs_queued)} processing jobs for dataset {dataset.id}: {jobs_queued}") + return UploadResponse( id=dataset.id, name=dataset.name, @@ -173,7 +270,10 @@ async def upload_dataset( column_types=column_types, normalized_columns=column_mapping, ioc_columns=ioc_columns, - message=f"Successfully uploaded {len(rows)} rows with {len(ioc_columns)} IOC columns detected", + artifact_type=artifact_type, + processing_status="processing", + jobs_queued=jobs_queued, + message=f"Successfully uploaded {len(rows)} rows. {len(jobs_queued)} processing jobs queued.", ) @@ -208,6 +308,8 @@ async def list_datasets( delimiter=ds.delimiter, time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None, time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None, + artifact_type=ds.artifact_type, + processing_status=ds.processing_status, hunt_id=ds.hunt_id, created_at=ds.created_at.isoformat(), ) @@ -244,6 +346,8 @@ async def get_dataset( delimiter=ds.delimiter, time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None, time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None, + artifact_type=ds.artifact_type, + processing_status=ds.processing_status, hunt_id=ds.hunt_id, created_at=ds.created_at.isoformat(), ) @@ -292,4 +396,5 @@ async def delete_dataset( deleted = await repo.delete_dataset(dataset_id) if not deleted: raise HTTPException(status_code=404, detail="Dataset not found") + keyword_scan_cache.invalidate_dataset(dataset_id) return {"message": "Dataset deleted", "id": dataset_id} diff --git a/backend/app/api/routes/hunts.py b/backend/app/api/routes/hunts.py index f8c0915..2183b16 100644 --- a/backend/app/api/routes/hunts.py +++ b/backend/app/api/routes/hunts.py @@ -8,16 +8,15 @@ from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession from app.db import get_db -from app.db.models import Hunt, Conversation, Message +from app.db.models import Hunt, Dataset, ProcessingTask +from app.services.job_queue import job_queue +from app.services.host_inventory import inventory_cache logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/hunts", tags=["hunts"]) -# ── Models ──────────────────────────────────────────────────────────── - - class HuntCreate(BaseModel): name: str = Field(..., max_length=256) description: str | None = None @@ -26,7 +25,7 @@ class HuntCreate(BaseModel): class HuntUpdate(BaseModel): name: str | None = None description: str | None = None - status: str | None = None # active | closed | archived + status: str | None = None class HuntResponse(BaseModel): @@ -46,7 +45,18 @@ class HuntListResponse(BaseModel): total: int -# ── Routes ──────────────────────────────────────────────────────────── +class HuntProgressResponse(BaseModel): + hunt_id: str + status: str + progress_percent: float + dataset_total: int + dataset_completed: int + dataset_processing: int + dataset_errors: int + active_jobs: int + queued_jobs: int + network_status: str + stages: dict @router.post("", response_model=HuntResponse, summary="Create a new hunt") @@ -122,6 +132,125 @@ async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)): ) +@router.get("/{hunt_id}/progress", response_model=HuntProgressResponse, summary="Get hunt processing progress") +async def get_hunt_progress(hunt_id: str, db: AsyncSession = Depends(get_db)): + hunt = await db.get(Hunt, hunt_id) + if not hunt: + raise HTTPException(status_code=404, detail="Hunt not found") + + ds_rows = await db.execute( + select(Dataset.id, Dataset.processing_status) + .where(Dataset.hunt_id == hunt_id) + ) + datasets = ds_rows.all() + dataset_ids = {row[0] for row in datasets} + + dataset_total = len(datasets) + dataset_completed = sum(1 for _, st in datasets if st == "completed") + dataset_errors = sum(1 for _, st in datasets if st == "completed_with_errors") + dataset_processing = max(0, dataset_total - dataset_completed - dataset_errors) + + jobs = job_queue.list_jobs(limit=5000) + relevant_jobs = [ + j for j in jobs + if j.get("params", {}).get("hunt_id") == hunt_id + or j.get("params", {}).get("dataset_id") in dataset_ids + ] + active_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "running") + queued_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "queued") + + task_rows = await db.execute( + select(ProcessingTask.stage, ProcessingTask.status, ProcessingTask.progress) + .where(ProcessingTask.hunt_id == hunt_id) + ) + tasks = task_rows.all() + + task_total = len(tasks) + task_done = sum(1 for _, st, _ in tasks if st in ("completed", "failed", "cancelled")) + task_running = sum(1 for _, st, _ in tasks if st == "running") + task_queued = sum(1 for _, st, _ in tasks if st == "queued") + task_ratio = (task_done / task_total) if task_total > 0 else None + + active_jobs = max(active_jobs_mem, task_running) + queued_jobs = max(queued_jobs_mem, task_queued) + + stage_rollup: dict[str, dict] = {} + for stage, status, progress in tasks: + bucket = stage_rollup.setdefault(stage, {"total": 0, "done": 0, "running": 0, "queued": 0, "progress_sum": 0.0}) + bucket["total"] += 1 + if status in ("completed", "failed", "cancelled"): + bucket["done"] += 1 + elif status == "running": + bucket["running"] += 1 + elif status == "queued": + bucket["queued"] += 1 + bucket["progress_sum"] += float(progress or 0.0) + + for stage_name, bucket in stage_rollup.items(): + total = max(1, bucket["total"]) + bucket["percent"] = round(bucket["progress_sum"] / total, 1) + + if inventory_cache.get(hunt_id) is not None: + network_status = "ready" + network_ratio = 1.0 + elif inventory_cache.is_building(hunt_id): + network_status = "building" + network_ratio = 0.5 + else: + network_status = "none" + network_ratio = 0.0 + + dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0 + if task_ratio is None: + overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15)) + else: + overall_ratio = min(1.0, (dataset_ratio * 0.50) + (task_ratio * 0.35) + (network_ratio * 0.15)) + progress_percent = round(overall_ratio * 100.0, 1) + + status = "ready" + if dataset_total == 0: + status = "idle" + elif progress_percent < 100: + status = "processing" + + stages = { + "datasets": { + "total": dataset_total, + "completed": dataset_completed, + "processing": dataset_processing, + "errors": dataset_errors, + "percent": round(dataset_ratio * 100.0, 1), + }, + "network": { + "status": network_status, + "percent": round(network_ratio * 100.0, 1), + }, + "jobs": { + "active": active_jobs, + "queued": queued_jobs, + "total_seen": len(relevant_jobs), + "task_total": task_total, + "task_done": task_done, + "task_percent": round((task_ratio or 0.0) * 100.0, 1) if task_total else None, + }, + "task_stages": stage_rollup, + } + + return HuntProgressResponse( + hunt_id=hunt_id, + status=status, + progress_percent=progress_percent, + dataset_total=dataset_total, + dataset_completed=dataset_completed, + dataset_processing=dataset_processing, + dataset_errors=dataset_errors, + active_jobs=active_jobs, + queued_jobs=queued_jobs, + network_status=network_status, + stages=stages, + ) + + @router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt") async def update_hunt( hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db) diff --git a/backend/app/api/routes/keywords.py b/backend/app/api/routes/keywords.py index bd4fca5..b048472 100644 --- a/backend/app/api/routes/keywords.py +++ b/backend/app/api/routes/keywords.py @@ -1,25 +1,21 @@ """API routes for AUP keyword themes, keyword CRUD, and scanning.""" import logging -from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel, Field -from sqlalchemy import select, func, delete +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.db import get_db from app.db.models import KeywordTheme, Keyword -from app.services.scanner import KeywordScanner +from app.services.scanner import KeywordScanner, keyword_scan_cache logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/keywords", tags=["keywords"]) -# ── Pydantic schemas ────────────────────────────────────────────────── - - class ThemeCreate(BaseModel): name: str = Field(..., min_length=1, max_length=128) color: str = Field(default="#9e9e9e", max_length=16) @@ -67,23 +63,27 @@ class KeywordBulkCreate(BaseModel): class ScanRequest(BaseModel): - dataset_ids: list[str] | None = None # None → all datasets - theme_ids: list[str] | None = None # None → all enabled themes - scan_hunts: bool = True - scan_annotations: bool = True - scan_messages: bool = True + dataset_ids: list[str] | None = None + theme_ids: list[str] | None = None + scan_hunts: bool = False + scan_annotations: bool = False + scan_messages: bool = False + prefer_cache: bool = True + force_rescan: bool = False class ScanHit(BaseModel): theme_name: str theme_color: str keyword: str - source_type: str # dataset_row | hunt | annotation | message + source_type: str source_id: str | int field: str matched_value: str row_index: int | None = None dataset_name: str | None = None + hostname: str | None = None + username: str | None = None class ScanResponse(BaseModel): @@ -92,9 +92,9 @@ class ScanResponse(BaseModel): themes_scanned: int keywords_scanned: int rows_scanned: int - - -# ── Helpers ─────────────────────────────────────────────────────────── + cache_used: bool = False + cache_status: str = "miss" + cached_at: str | None = None def _theme_to_out(t: KeywordTheme) -> ThemeOut: @@ -119,49 +119,58 @@ def _theme_to_out(t: KeywordTheme) -> ThemeOut: ) -# ── Theme CRUD ──────────────────────────────────────────────────────── +def _merge_cached_results(entries: list[dict], allowed_theme_names: set[str] | None = None) -> dict: + hits: list[dict] = [] + total_rows = 0 + cached_at: str | None = None + + for entry in entries: + result = entry["result"] + total_rows += int(result.get("rows_scanned", 0) or 0) + if entry.get("built_at"): + if not cached_at or entry["built_at"] > cached_at: + cached_at = entry["built_at"] + for h in result.get("hits", []): + if allowed_theme_names is not None and h.get("theme_name") not in allowed_theme_names: + continue + hits.append(h) + + return { + "total_hits": len(hits), + "hits": hits, + "rows_scanned": total_rows, + "cached_at": cached_at, + } @router.get("/themes", response_model=ThemeListResponse) async def list_themes(db: AsyncSession = Depends(get_db)): - """List all keyword themes with their keywords.""" - result = await db.execute( - select(KeywordTheme).order_by(KeywordTheme.name) - ) + result = await db.execute(select(KeywordTheme).order_by(KeywordTheme.name)) themes = result.scalars().all() - return ThemeListResponse( - themes=[_theme_to_out(t) for t in themes], - total=len(themes), - ) + return ThemeListResponse(themes=[_theme_to_out(t) for t in themes], total=len(themes)) @router.post("/themes", response_model=ThemeOut, status_code=201) async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)): - """Create a new keyword theme.""" - exists = await db.scalar( - select(KeywordTheme.id).where(KeywordTheme.name == body.name) - ) + exists = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == body.name)) if exists: raise HTTPException(409, f"Theme '{body.name}' already exists") theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled) db.add(theme) await db.flush() await db.refresh(theme) + keyword_scan_cache.clear() return _theme_to_out(theme) @router.put("/themes/{theme_id}", response_model=ThemeOut) async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)): - """Update theme name, color, or enabled status.""" theme = await db.get(KeywordTheme, theme_id) if not theme: raise HTTPException(404, "Theme not found") if body.name is not None: - # check uniqueness dup = await db.scalar( - select(KeywordTheme.id).where( - KeywordTheme.name == body.name, KeywordTheme.id != theme_id - ) + select(KeywordTheme.id).where(KeywordTheme.name == body.name, KeywordTheme.id != theme_id) ) if dup: raise HTTPException(409, f"Theme '{body.name}' already exists") @@ -172,24 +181,21 @@ async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depe theme.enabled = body.enabled await db.flush() await db.refresh(theme) + keyword_scan_cache.clear() return _theme_to_out(theme) @router.delete("/themes/{theme_id}", status_code=204) async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)): - """Delete a theme and all its keywords.""" theme = await db.get(KeywordTheme, theme_id) if not theme: raise HTTPException(404, "Theme not found") await db.delete(theme) - - -# ── Keyword CRUD ────────────────────────────────────────────────────── + keyword_scan_cache.clear() @router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201) async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)): - """Add a single keyword to a theme.""" theme = await db.get(KeywordTheme, theme_id) if not theme: raise HTTPException(404, "Theme not found") @@ -197,6 +203,7 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep db.add(kw) await db.flush() await db.refresh(kw) + keyword_scan_cache.clear() return KeywordOut( id=kw.id, theme_id=kw.theme_id, value=kw.value, is_regex=kw.is_regex, created_at=kw.created_at.isoformat(), @@ -205,7 +212,6 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep @router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201) async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)): - """Add multiple keywords to a theme at once.""" theme = await db.get(KeywordTheme, theme_id) if not theme: raise HTTPException(404, "Theme not found") @@ -217,25 +223,88 @@ async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSes db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex)) added += 1 await db.flush() + keyword_scan_cache.clear() return {"added": added, "theme_id": theme_id} @router.delete("/keywords/{keyword_id}", status_code=204) async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)): - """Delete a single keyword.""" kw = await db.get(Keyword, keyword_id) if not kw: raise HTTPException(404, "Keyword not found") await db.delete(kw) - - -# ── Scan endpoints ──────────────────────────────────────────────────── + keyword_scan_cache.clear() @router.post("/scan", response_model=ScanResponse) async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)): - """Run AUP keyword scan across selected data sources.""" scanner = KeywordScanner(db) + + if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages: + return { + "total_hits": 0, + "hits": [], + "themes_scanned": 0, + "keywords_scanned": 0, + "rows_scanned": 0, + "cache_used": False, + "cache_status": "miss", + "cached_at": None, + } + + can_use_cache = ( + body.prefer_cache + and not body.force_rescan + and bool(body.dataset_ids) + and not body.scan_hunts + and not body.scan_annotations + and not body.scan_messages + ) + + if can_use_cache: + themes = await scanner._load_themes(body.theme_ids) + allowed_theme_names = {t.name for t in themes} + keywords_scanned = sum(len(theme.keywords) for theme in themes) + + cached_entries: list[dict] = [] + missing: list[str] = [] + for dataset_id in (body.dataset_ids or []): + entry = keyword_scan_cache.get(dataset_id) + if not entry: + missing.append(dataset_id) + continue + cached_entries.append({"result": entry.result, "built_at": entry.built_at}) + + if not missing and cached_entries: + merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None) + return { + "total_hits": merged["total_hits"], + "hits": merged["hits"], + "themes_scanned": len(themes), + "keywords_scanned": keywords_scanned, + "rows_scanned": merged["rows_scanned"], + "cache_used": True, + "cache_status": "hit", + "cached_at": merged["cached_at"], + } + + if missing: + partial = await scanner.scan(dataset_ids=missing, theme_ids=body.theme_ids) + merged = _merge_cached_results( + cached_entries + [{"result": partial, "built_at": None}], + allowed_theme_names if body.theme_ids else None, + ) + return { + "total_hits": merged["total_hits"], + "hits": merged["hits"], + "themes_scanned": len(themes), + "keywords_scanned": keywords_scanned, + "rows_scanned": merged["rows_scanned"], + "cache_used": len(cached_entries) > 0, + "cache_status": "partial" if cached_entries else "miss", + "cached_at": merged["cached_at"], + } + result = await scanner.scan( dataset_ids=body.dataset_ids, theme_ids=body.theme_ids, @@ -243,7 +312,13 @@ async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)): scan_annotations=body.scan_annotations, scan_messages=body.scan_messages, ) - return result + + return { + **result, + "cache_used": False, + "cache_status": "miss", + "cached_at": None, + } @router.get("/scan/quick", response_model=ScanResponse) @@ -251,7 +326,22 @@ async def quick_scan( dataset_id: str = Query(..., description="Dataset to scan"), db: AsyncSession = Depends(get_db), ): - """Quick scan a single dataset with all enabled themes.""" + entry = keyword_scan_cache.get(dataset_id) + if entry is not None: + result = entry.result + return { + **result, + "cache_used": True, + "cache_status": "hit", + "cached_at": entry.built_at, + } + scanner = KeywordScanner(db) result = await scanner.scan(dataset_ids=[dataset_id]) - return result + keyword_scan_cache.put(dataset_id, result) + return { + **result, + "cache_used": False, + "cache_status": "miss", + "cached_at": None, + } diff --git a/backend/app/api/routes/mitre.py b/backend/app/api/routes/mitre.py new file mode 100644 index 0000000..35bf28a --- /dev/null +++ b/backend/app/api/routes/mitre.py @@ -0,0 +1,146 @@ +"""API routes for MITRE ATT&CK coverage visualization.""" + +import logging +from collections import defaultdict + +from fastapi import APIRouter, Depends +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import get_db +from app.db.models import ( + TriageResult, HostProfile, Hypothesis, HuntReport, Dataset, Hunt +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/mitre", tags=["mitre"]) + +# Canonical MITRE ATT&CK tactics in kill-chain order +TACTICS = [ + "Reconnaissance", "Resource Development", "Initial Access", + "Execution", "Persistence", "Privilege Escalation", + "Defense Evasion", "Credential Access", "Discovery", + "Lateral Movement", "Collection", "Command and Control", + "Exfiltration", "Impact", +] + +# Simplified technique-to-tactic mapping (top techniques) +TECHNIQUE_TACTIC: dict[str, str] = { + "T1059": "Execution", "T1059.001": "Execution", "T1059.003": "Execution", + "T1059.005": "Execution", "T1059.006": "Execution", "T1059.007": "Execution", + "T1053": "Persistence", "T1053.005": "Persistence", + "T1547": "Persistence", "T1547.001": "Persistence", + "T1543": "Persistence", "T1543.003": "Persistence", + "T1078": "Privilege Escalation", "T1078.001": "Privilege Escalation", + "T1078.002": "Privilege Escalation", "T1078.003": "Privilege Escalation", + "T1055": "Privilege Escalation", "T1055.001": "Privilege Escalation", + "T1548": "Privilege Escalation", "T1548.002": "Privilege Escalation", + "T1070": "Defense Evasion", "T1070.001": "Defense Evasion", + "T1070.004": "Defense Evasion", + "T1036": "Defense Evasion", "T1036.005": "Defense Evasion", + "T1027": "Defense Evasion", "T1140": "Defense Evasion", + "T1218": "Defense Evasion", "T1218.011": "Defense Evasion", + "T1003": "Credential Access", "T1003.001": "Credential Access", + "T1110": "Credential Access", "T1558": "Credential Access", + "T1087": "Discovery", "T1087.001": "Discovery", "T1087.002": "Discovery", + "T1082": "Discovery", "T1083": "Discovery", "T1057": "Discovery", + "T1018": "Discovery", "T1049": "Discovery", "T1016": "Discovery", + "T1021": "Lateral Movement", "T1021.001": "Lateral Movement", + "T1021.002": "Lateral Movement", "T1021.006": "Lateral Movement", + "T1570": "Lateral Movement", + "T1560": "Collection", "T1074": "Collection", "T1005": "Collection", + "T1071": "Command and Control", "T1071.001": "Command and Control", + "T1105": "Command and Control", "T1572": "Command and Control", + "T1095": "Command and Control", + "T1048": "Exfiltration", "T1041": "Exfiltration", + "T1486": "Impact", "T1490": "Impact", "T1489": "Impact", + "T1566": "Initial Access", "T1566.001": "Initial Access", + "T1566.002": "Initial Access", + "T1190": "Initial Access", "T1133": "Initial Access", + "T1195": "Initial Access", "T1195.002": "Initial Access", +} + + +def _get_tactic(technique_id: str) -> str: + """Map a technique ID to its tactic.""" + tech = technique_id.strip().upper() + if tech in TECHNIQUE_TACTIC: + return TECHNIQUE_TACTIC[tech] + # Try parent technique + if "." in tech: + parent = tech.split(".")[0] + if parent in TECHNIQUE_TACTIC: + return TECHNIQUE_TACTIC[parent] + return "Unknown" + + +@router.get("/coverage") +async def get_mitre_coverage( + hunt_id: str | None = None, + db: AsyncSession = Depends(get_db), +): + """Aggregate all MITRE techniques from triage, host profiles, hypotheses, and reports.""" + + techniques: dict[str, dict] = {} + + # Collect from triage results + triage_q = select(TriageResult) + if hunt_id: + triage_q = triage_q.join(Dataset).where(Dataset.hunt_id == hunt_id) + result = await db.execute(triage_q.limit(500)) + for t in result.scalars().all(): + for tech in (t.mitre_techniques or []): + if tech not in techniques: + techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0} + techniques[tech]["count"] += 1 + techniques[tech]["sources"].append({"type": "triage", "risk_score": t.risk_score}) + + # Collect from host profiles + profile_q = select(HostProfile) + if hunt_id: + profile_q = profile_q.where(HostProfile.hunt_id == hunt_id) + result = await db.execute(profile_q.limit(200)) + for p in result.scalars().all(): + for tech in (p.mitre_techniques or []): + if tech not in techniques: + techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0} + techniques[tech]["count"] += 1 + techniques[tech]["sources"].append({"type": "host_profile", "hostname": p.hostname}) + + # Collect from hypotheses + hyp_q = select(Hypothesis) + if hunt_id: + hyp_q = hyp_q.where(Hypothesis.hunt_id == hunt_id) + result = await db.execute(hyp_q.limit(200)) + for h in result.scalars().all(): + tech = h.mitre_technique + if tech: + if tech not in techniques: + techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0} + techniques[tech]["count"] += 1 + techniques[tech]["sources"].append({"type": "hypothesis", "title": h.title}) + + # Build tactic-grouped response + tactic_groups: dict[str, list] = {t: [] for t in TACTICS} + tactic_groups["Unknown"] = [] + for tech in techniques.values(): + tactic = tech["tactic"] + if tactic not in tactic_groups: + tactic_groups[tactic] = [] + tactic_groups[tactic].append(tech) + + total_techniques = len(techniques) + total_detections = sum(t["count"] for t in techniques.values()) + + return { + "tactics": TACTICS, + "technique_count": total_techniques, + "detection_count": total_detections, + "tactic_coverage": { + t: {"techniques": techs, "count": len(techs)} + for t, techs in tactic_groups.items() + if techs + }, + "all_techniques": list(techniques.values()), + } diff --git a/backend/app/api/routes/network.py b/backend/app/api/routes/network.py index 65d47ad..7dea039 100644 --- a/backend/app/api/routes/network.py +++ b/backend/app/api/routes/network.py @@ -1,12 +1,15 @@ -"""Network topology API - host inventory endpoint.""" +"""Network topology API - host inventory endpoint with background caching.""" import logging from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.responses import JSONResponse from sqlalchemy.ext.asyncio import AsyncSession +from app.config import settings from app.db import get_db -from app.services.host_inventory import build_host_inventory +from app.services.host_inventory import build_host_inventory, inventory_cache +from app.services.job_queue import job_queue, JobType logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/network", tags=["network"]) @@ -15,14 +18,158 @@ router = APIRouter(prefix="/api/network", tags=["network"]) @router.get("/host-inventory") async def get_host_inventory( hunt_id: str = Query(..., description="Hunt ID to build inventory for"), + force: bool = Query(False, description="Force rebuild, ignoring cache"), db: AsyncSession = Depends(get_db), ): - """Build a deduplicated host inventory from all datasets in a hunt. + """Return a deduplicated host inventory for the hunt. - Returns unique hosts with hostname, IPs, OS, logged-in users, and - network connections derived from netstat/connection data. + Returns instantly from cache if available (pre-built after upload or on startup). + If cache is cold, triggers a background build and returns 202 so the + frontend can poll /inventory-status and re-request when ready. """ - result = await build_host_inventory(hunt_id, db) - if result["stats"]["total_hosts"] == 0: - return result - return result \ No newline at end of file + # Force rebuild: invalidate cache, queue background job, return 202 + if force: + inventory_cache.invalidate(hunt_id) + if not inventory_cache.is_building(hunt_id): + if job_queue.is_backlogged(): + return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"}) + job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id) + return JSONResponse( + status_code=202, + content={"status": "building", "message": "Rebuild queued"}, + ) + + # Try cache first + cached = inventory_cache.get(hunt_id) + if cached is not None: + logger.info(f"Serving cached host inventory for {hunt_id}") + return cached + + # Cache miss: trigger background build instead of blocking for 90+ seconds + if not inventory_cache.is_building(hunt_id): + logger.info(f"Cache miss for {hunt_id}, triggering background build") + if job_queue.is_backlogged(): + return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"}) + job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id) + + return JSONResponse( + status_code=202, + content={"status": "building", "message": "Inventory is being built in the background"}, + ) + + +def _build_summary(inv: dict, top_n: int = 20) -> dict: + hosts = inv.get("hosts", []) + conns = inv.get("connections", []) + top_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:top_n] + top_edges = sorted(conns, key=lambda c: c.get("count", 0), reverse=True)[:top_n] + return { + "stats": inv.get("stats", {}), + "top_hosts": [ + { + "id": h.get("id"), + "hostname": h.get("hostname"), + "row_count": h.get("row_count", 0), + "ip_count": len(h.get("ips", [])), + "user_count": len(h.get("users", [])), + } + for h in top_hosts + ], + "top_edges": top_edges, + } + + +def _build_subgraph(inv: dict, node_id: str | None, max_hosts: int, max_edges: int) -> dict: + hosts = inv.get("hosts", []) + conns = inv.get("connections", []) + + max_hosts = max(1, min(max_hosts, settings.NETWORK_SUBGRAPH_MAX_HOSTS)) + max_edges = max(1, min(max_edges, settings.NETWORK_SUBGRAPH_MAX_EDGES)) + + if node_id: + rel_edges = [c for c in conns if c.get("source") == node_id or c.get("target") == node_id] + rel_edges = sorted(rel_edges, key=lambda c: c.get("count", 0), reverse=True)[:max_edges] + ids = {node_id} + for c in rel_edges: + ids.add(c.get("source")) + ids.add(c.get("target")) + rel_hosts = [h for h in hosts if h.get("id") in ids][:max_hosts] + else: + rel_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:max_hosts] + allowed = {h.get("id") for h in rel_hosts} + rel_edges = [ + c for c in sorted(conns, key=lambda c: c.get("count", 0), reverse=True) + if c.get("source") in allowed and c.get("target") in allowed + ][:max_edges] + + return { + "hosts": rel_hosts, + "connections": rel_edges, + "stats": { + **inv.get("stats", {}), + "subgraph_hosts": len(rel_hosts), + "subgraph_connections": len(rel_edges), + "truncated": len(rel_hosts) < len(hosts) or len(rel_edges) < len(conns), + }, + } + + +@router.get("/summary") +async def get_inventory_summary( + hunt_id: str = Query(..., description="Hunt ID"), + top_n: int = Query(20, ge=1, le=200), +): + """Return a lightweight summary view for large hunts.""" + cached = inventory_cache.get(hunt_id) + if cached is None: + if not inventory_cache.is_building(hunt_id): + if job_queue.is_backlogged(): + return JSONResponse( + status_code=202, + content={"status": "deferred", "message": "Queue busy, retry shortly"}, + ) + job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id) + return JSONResponse(status_code=202, content={"status": "building"}) + return _build_summary(cached, top_n=top_n) + + +@router.get("/subgraph") +async def get_inventory_subgraph( + hunt_id: str = Query(..., description="Hunt ID"), + node_id: str | None = Query(None, description="Optional focal node"), + max_hosts: int = Query(200, ge=1, le=5000), + max_edges: int = Query(1500, ge=1, le=20000), +): + """Return a bounded subgraph for scale-safe rendering.""" + cached = inventory_cache.get(hunt_id) + if cached is None: + if not inventory_cache.is_building(hunt_id): + if job_queue.is_backlogged(): + return JSONResponse( + status_code=202, + content={"status": "deferred", "message": "Queue busy, retry shortly"}, + ) + job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id) + return JSONResponse(status_code=202, content={"status": "building"}) + return _build_subgraph(cached, node_id=node_id, max_hosts=max_hosts, max_edges=max_edges) + + +@router.get("/inventory-status") +async def get_inventory_status( + hunt_id: str = Query(..., description="Hunt ID to check"), +): + """Check whether pre-computed host inventory is ready for a hunt. + + Returns: { status: "ready" | "building" | "none" } + """ + return {"hunt_id": hunt_id, "status": inventory_cache.status(hunt_id)} + + +@router.post("/rebuild-inventory") +async def trigger_rebuild( + hunt_id: str = Query(..., description="Hunt to rebuild inventory for"), +): + """Trigger a background rebuild of the host inventory cache.""" + inventory_cache.invalidate(hunt_id) + job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id) + return {"job_id": job.id, "status": "queued"} diff --git a/backend/app/api/routes/playbooks.py b/backend/app/api/routes/playbooks.py new file mode 100644 index 0000000..e2b45c0 --- /dev/null +++ b/backend/app/api/routes/playbooks.py @@ -0,0 +1,217 @@ +"""API routes for investigation playbooks.""" + +import logging +from datetime import datetime, timezone + +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import get_db +from app.db.models import Playbook, PlaybookStep + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/playbooks", tags=["playbooks"]) + + +# -- Request / Response schemas --- + +class StepCreate(BaseModel): + title: str + description: str | None = None + step_type: str = "manual" + target_route: str | None = None + +class PlaybookCreate(BaseModel): + name: str + description: str | None = None + hunt_id: str | None = None + is_template: bool = False + steps: list[StepCreate] = [] + +class PlaybookUpdate(BaseModel): + name: str | None = None + description: str | None = None + status: str | None = None + +class StepUpdate(BaseModel): + is_completed: bool | None = None + notes: str | None = None + + +# -- Default investigation templates --- + +DEFAULT_TEMPLATES = [ + { + "name": "Standard Threat Hunt", + "description": "Step-by-step investigation workflow for a typical threat hunting engagement.", + "steps": [ + {"title": "Upload Artifacts", "description": "Import CSV exports from Velociraptor or other tools", "step_type": "upload", "target_route": "/upload"}, + {"title": "Create Hunt", "description": "Create a new hunt and associate uploaded datasets", "step_type": "action", "target_route": "/hunts"}, + {"title": "AUP Keyword Scan", "description": "Run AUP keyword scanner for policy violations", "step_type": "analysis", "target_route": "/aup"}, + {"title": "Auto-Triage", "description": "Trigger AI triage on all datasets", "step_type": "analysis", "target_route": "/analysis"}, + {"title": "Review Triage Results", "description": "Review flagged rows and risk scores", "step_type": "review", "target_route": "/analysis"}, + {"title": "Enrich IOCs", "description": "Enrich flagged IPs, hashes, and domains via external sources", "step_type": "analysis", "target_route": "/enrichment"}, + {"title": "Host Profiling", "description": "Generate deep host profiles for suspicious hosts", "step_type": "analysis", "target_route": "/analysis"}, + {"title": "Cross-Hunt Correlation", "description": "Identify shared IOCs and patterns across hunts", "step_type": "analysis", "target_route": "/correlation"}, + {"title": "Document Hypotheses", "description": "Record investigation hypotheses with MITRE mappings", "step_type": "action", "target_route": "/hypotheses"}, + {"title": "Generate Report", "description": "Generate final AI-assisted hunt report", "step_type": "action", "target_route": "/analysis"}, + ], + }, + { + "name": "Incident Response Triage", + "description": "Fast-track workflow for active incident response.", + "steps": [ + {"title": "Upload Artifacts", "description": "Import forensic data from affected hosts", "step_type": "upload", "target_route": "/upload"}, + {"title": "Auto-Triage", "description": "Immediate AI triage for threat indicators", "step_type": "analysis", "target_route": "/analysis"}, + {"title": "IOC Extraction", "description": "Extract all IOCs from flagged data", "step_type": "analysis", "target_route": "/analysis"}, + {"title": "Enrich Critical IOCs", "description": "Priority enrichment of high-risk indicators", "step_type": "analysis", "target_route": "/enrichment"}, + {"title": "Network Map", "description": "Visualize host connections and lateral movement", "step_type": "review", "target_route": "/network"}, + {"title": "Generate Situation Report", "description": "Create executive summary for incident command", "step_type": "action", "target_route": "/analysis"}, + ], + }, +] + + +# -- Routes --- + +@router.get("") +async def list_playbooks( + include_templates: bool = True, + hunt_id: str | None = None, + db: AsyncSession = Depends(get_db), +): + q = select(Playbook) + if hunt_id: + q = q.where(Playbook.hunt_id == hunt_id) + if not include_templates: + q = q.where(Playbook.is_template == False) + q = q.order_by(Playbook.created_at.desc()) + result = await db.execute(q.limit(100)) + playbooks = result.scalars().all() + + return {"playbooks": [ + { + "id": p.id, "name": p.name, "description": p.description, + "is_template": p.is_template, "hunt_id": p.hunt_id, + "status": p.status, + "total_steps": len(p.steps), + "completed_steps": sum(1 for s in p.steps if s.is_completed), + "created_at": p.created_at.isoformat() if p.created_at else None, + } + for p in playbooks + ]} + + +@router.get("/templates") +async def get_templates(): + """Return built-in investigation templates.""" + return {"templates": DEFAULT_TEMPLATES} + + +@router.post("", status_code=status.HTTP_201_CREATED) +async def create_playbook(body: PlaybookCreate, db: AsyncSession = Depends(get_db)): + pb = Playbook( + name=body.name, + description=body.description, + hunt_id=body.hunt_id, + is_template=body.is_template, + ) + db.add(pb) + await db.flush() + + created_steps = [] + for i, step in enumerate(body.steps): + s = PlaybookStep( + playbook_id=pb.id, + order_index=i, + title=step.title, + description=step.description, + step_type=step.step_type, + target_route=step.target_route, + ) + db.add(s) + created_steps.append(s) + + await db.flush() + + return { + "id": pb.id, "name": pb.name, "description": pb.description, + "hunt_id": pb.hunt_id, "is_template": pb.is_template, + "steps": [ + {"id": s.id, "order_index": s.order_index, "title": s.title, + "description": s.description, "step_type": s.step_type, + "target_route": s.target_route, "is_completed": False} + for s in created_steps + ], + } + + +@router.get("/{playbook_id}") +async def get_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)): + result = await db.execute(select(Playbook).where(Playbook.id == playbook_id)) + pb = result.scalar_one_or_none() + if not pb: + raise HTTPException(status_code=404, detail="Playbook not found") + + return { + "id": pb.id, "name": pb.name, "description": pb.description, + "is_template": pb.is_template, "hunt_id": pb.hunt_id, + "status": pb.status, + "created_at": pb.created_at.isoformat() if pb.created_at else None, + "steps": [ + { + "id": s.id, "order_index": s.order_index, "title": s.title, + "description": s.description, "step_type": s.step_type, + "target_route": s.target_route, + "is_completed": s.is_completed, + "completed_at": s.completed_at.isoformat() if s.completed_at else None, + "notes": s.notes, + } + for s in sorted(pb.steps, key=lambda x: x.order_index) + ], + } + + +@router.put("/{playbook_id}") +async def update_playbook(playbook_id: str, body: PlaybookUpdate, db: AsyncSession = Depends(get_db)): + result = await db.execute(select(Playbook).where(Playbook.id == playbook_id)) + pb = result.scalar_one_or_none() + if not pb: + raise HTTPException(status_code=404, detail="Playbook not found") + + if body.name is not None: + pb.name = body.name + if body.description is not None: + pb.description = body.description + if body.status is not None: + pb.status = body.status + return {"status": "updated"} + + +@router.delete("/{playbook_id}") +async def delete_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)): + result = await db.execute(select(Playbook).where(Playbook.id == playbook_id)) + pb = result.scalar_one_or_none() + if not pb: + raise HTTPException(status_code=404, detail="Playbook not found") + await db.delete(pb) + return {"status": "deleted"} + + +@router.put("/steps/{step_id}") +async def update_step(step_id: int, body: StepUpdate, db: AsyncSession = Depends(get_db)): + result = await db.execute(select(PlaybookStep).where(PlaybookStep.id == step_id)) + step = result.scalar_one_or_none() + if not step: + raise HTTPException(status_code=404, detail="Step not found") + + if body.is_completed is not None: + step.is_completed = body.is_completed + step.completed_at = datetime.now(timezone.utc) if body.is_completed else None + if body.notes is not None: + step.notes = body.notes + return {"status": "updated", "is_completed": step.is_completed} + diff --git a/backend/app/api/routes/saved_searches.py b/backend/app/api/routes/saved_searches.py new file mode 100644 index 0000000..0a300a8 --- /dev/null +++ b/backend/app/api/routes/saved_searches.py @@ -0,0 +1,164 @@ +"""API routes for saved searches and bookmarked queries.""" + +import logging +from datetime import datetime, timezone + +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import get_db +from app.db.models import SavedSearch + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/searches", tags=["saved-searches"]) + + +class SearchCreate(BaseModel): + name: str + description: str | None = None + search_type: str # "nlp_query", "ioc_search", "keyword_scan", "correlation" + query_params: dict + threshold: float | None = None + + +class SearchUpdate(BaseModel): + name: str | None = None + description: str | None = None + query_params: dict | None = None + threshold: float | None = None + + +@router.get("") +async def list_searches( + search_type: str | None = None, + db: AsyncSession = Depends(get_db), +): + q = select(SavedSearch).order_by(SavedSearch.created_at.desc()) + if search_type: + q = q.where(SavedSearch.search_type == search_type) + result = await db.execute(q.limit(100)) + searches = result.scalars().all() + return {"searches": [ + { + "id": s.id, "name": s.name, "description": s.description, + "search_type": s.search_type, "query_params": s.query_params, + "threshold": s.threshold, + "last_run_at": s.last_run_at.isoformat() if s.last_run_at else None, + "last_result_count": s.last_result_count, + "created_at": s.created_at.isoformat() if s.created_at else None, + } + for s in searches + ]} + + +@router.post("", status_code=status.HTTP_201_CREATED) +async def create_search(body: SearchCreate, db: AsyncSession = Depends(get_db)): + s = SavedSearch( + name=body.name, + description=body.description, + search_type=body.search_type, + query_params=body.query_params, + threshold=body.threshold, + ) + db.add(s) + await db.flush() + return { + "id": s.id, "name": s.name, "search_type": s.search_type, + "query_params": s.query_params, "threshold": s.threshold, + } + + +@router.get("/{search_id}") +async def get_search(search_id: str, db: AsyncSession = Depends(get_db)): + result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id)) + s = result.scalar_one_or_none() + if not s: + raise HTTPException(status_code=404, detail="Saved search not found") + return { + "id": s.id, "name": s.name, "description": s.description, + "search_type": s.search_type, "query_params": s.query_params, + "threshold": s.threshold, + "last_run_at": s.last_run_at.isoformat() if s.last_run_at else None, + "last_result_count": s.last_result_count, + "created_at": s.created_at.isoformat() if s.created_at else None, + } + + +@router.put("/{search_id}") +async def update_search(search_id: str, body: SearchUpdate, db: AsyncSession = Depends(get_db)): + result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id)) + s = result.scalar_one_or_none() + if not s: + raise HTTPException(status_code=404, detail="Saved search not found") + if body.name is not None: + s.name = body.name + if body.description is not None: + s.description = body.description + if body.query_params is not None: + s.query_params = body.query_params + if body.threshold is not None: + s.threshold = body.threshold + return {"status": "updated"} + + +@router.delete("/{search_id}") +async def delete_search(search_id: str, db: AsyncSession = Depends(get_db)): + result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id)) + s = result.scalar_one_or_none() + if not s: + raise HTTPException(status_code=404, detail="Saved search not found") + await db.delete(s) + return {"status": "deleted"} + + +@router.post("/{search_id}/run") +async def run_saved_search(search_id: str, db: AsyncSession = Depends(get_db)): + """Execute a saved search and return results with delta from last run.""" + result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id)) + s = result.scalar_one_or_none() + if not s: + raise HTTPException(status_code=404, detail="Saved search not found") + + previous_count = s.last_result_count or 0 + results = [] + count = 0 + + if s.search_type == "ioc_search": + from app.db.models import EnrichmentResult + ioc_value = s.query_params.get("ioc_value", "") + if ioc_value: + q = select(EnrichmentResult).where( + EnrichmentResult.ioc_value.contains(ioc_value) + ) + res = await db.execute(q.limit(100)) + for er in res.scalars().all(): + results.append({ + "ioc_value": er.ioc_value, "ioc_type": er.ioc_type, + "source": er.source, "verdict": er.verdict, + }) + count = len(results) + + elif s.search_type == "keyword_scan": + from app.db.models import KeywordTheme + res = await db.execute(select(KeywordTheme).where(KeywordTheme.enabled == True)) + themes = res.scalars().all() + count = sum(len(t.keywords) for t in themes) + results = [{"theme": t.name, "keyword_count": len(t.keywords)} for t in themes] + + # Update last run metadata + s.last_run_at = datetime.now(timezone.utc) + s.last_result_count = count + + delta = count - previous_count + + return { + "search_id": s.id, "search_name": s.name, + "search_type": s.search_type, + "result_count": count, + "previous_count": previous_count, + "delta": delta, + "results": results[:50], + } diff --git a/backend/app/api/routes/stix_export.py b/backend/app/api/routes/stix_export.py new file mode 100644 index 0000000..24af732 --- /dev/null +++ b/backend/app/api/routes/stix_export.py @@ -0,0 +1,184 @@ +"""STIX 2.1 export endpoint. + +Aggregates hunt data (IOCs, techniques, host profiles, hypotheses) into a +STIX 2.1 Bundle JSON download. No external dependencies required we +build the JSON directly following the OASIS STIX 2.1 spec. +""" + +import json +import uuid +from datetime import datetime, timezone +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi.responses import Response +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import get_db +from app.db.models import ( + Hunt, Dataset, Hypothesis, TriageResult, HostProfile, + EnrichmentResult, HuntReport, +) + +router = APIRouter(prefix="/api/export", tags=["export"]) + +STIX_SPEC_VERSION = "2.1" + + +def _stix_id(stype: str) -> str: + return f"{stype}--{uuid.uuid4()}" + + +def _now_iso() -> str: + return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z") + + +def _build_identity(hunt_name: str) -> dict: + return { + "type": "identity", + "spec_version": STIX_SPEC_VERSION, + "id": _stix_id("identity"), + "created": _now_iso(), + "modified": _now_iso(), + "name": f"ThreatHunt - {hunt_name}", + "identity_class": "system", + } + + +def _ioc_to_indicator(ioc_value: str, ioc_type: str, identity_id: str, verdict: str = None) -> dict: + pattern_map = { + "ipv4": f"[ipv4-addr:value = '{ioc_value}']", + "ipv6": f"[ipv6-addr:value = '{ioc_value}']", + "domain": f"[domain-name:value = '{ioc_value}']", + "url": f"[url:value = '{ioc_value}']", + "hash_md5": f"[file:hashes.'MD5' = '{ioc_value}']", + "hash_sha1": f"[file:hashes.'SHA-1' = '{ioc_value}']", + "hash_sha256": f"[file:hashes.'SHA-256' = '{ioc_value}']", + "email": f"[email-addr:value = '{ioc_value}']", + } + pattern = pattern_map.get(ioc_type, f"[artifact:payload_bin = '{ioc_value}']") + now = _now_iso() + return { + "type": "indicator", + "spec_version": STIX_SPEC_VERSION, + "id": _stix_id("indicator"), + "created": now, + "modified": now, + "name": f"{ioc_type}: {ioc_value}", + "pattern": pattern, + "pattern_type": "stix", + "valid_from": now, + "created_by_ref": identity_id, + "labels": [verdict or "suspicious"], + } + + +def _technique_to_attack_pattern(technique_id: str, identity_id: str) -> dict: + now = _now_iso() + return { + "type": "attack-pattern", + "spec_version": STIX_SPEC_VERSION, + "id": _stix_id("attack-pattern"), + "created": now, + "modified": now, + "name": technique_id, + "created_by_ref": identity_id, + "external_references": [ + { + "source_name": "mitre-attack", + "external_id": technique_id, + "url": f"https://attack.mitre.org/techniques/{technique_id.replace('.', '/')}/", + } + ], + } + + +def _hypothesis_to_report(hyp, identity_id: str) -> dict: + now = _now_iso() + return { + "type": "report", + "spec_version": STIX_SPEC_VERSION, + "id": _stix_id("report"), + "created": now, + "modified": now, + "name": hyp.title, + "description": hyp.description or "", + "published": now, + "created_by_ref": identity_id, + "labels": ["threat-hunt-hypothesis"], + "object_refs": [], + } + + +@router.get("/stix/{hunt_id}") +async def export_stix(hunt_id: str, db: AsyncSession = Depends(get_db)): + """Export hunt data as a STIX 2.1 Bundle JSON file.""" + # Fetch hunt + hunt = (await db.execute(select(Hunt).where(Hunt.id == hunt_id))).scalar_one_or_none() + if not hunt: + raise HTTPException(404, "Hunt not found") + + identity = _build_identity(hunt.name) + objects: list[dict] = [identity] + seen_techniques: set[str] = set() + seen_iocs: set[str] = set() + + # Gather IOCs from enrichment results for hunt's datasets + datasets_q = await db.execute(select(Dataset.id).where(Dataset.hunt_id == hunt_id)) + ds_ids = [r[0] for r in datasets_q.all()] + + if ds_ids: + enrichments = (await db.execute( + select(EnrichmentResult).where(EnrichmentResult.dataset_id.in_(ds_ids)) + )).scalars().all() + for e in enrichments: + key = f"{e.ioc_type}:{e.ioc_value}" + if key not in seen_iocs: + seen_iocs.add(key) + objects.append(_ioc_to_indicator(e.ioc_value, e.ioc_type, identity["id"], e.verdict)) + + # Gather techniques from triage results + triages = (await db.execute( + select(TriageResult).where(TriageResult.dataset_id.in_(ds_ids)) + )).scalars().all() + for t in triages: + for tech in (t.mitre_techniques or []): + tid = tech if isinstance(tech, str) else tech.get("technique_id", str(tech)) + if tid not in seen_techniques: + seen_techniques.add(tid) + objects.append(_technique_to_attack_pattern(tid, identity["id"])) + + # Gather techniques from host profiles + profiles = (await db.execute( + select(HostProfile).where(HostProfile.hunt_id == hunt_id) + )).scalars().all() + for p in profiles: + for tech in (p.mitre_techniques or []): + tid = tech if isinstance(tech, str) else tech.get("technique_id", str(tech)) + if tid not in seen_techniques: + seen_techniques.add(tid) + objects.append(_technique_to_attack_pattern(tid, identity["id"])) + + # Gather hypotheses + hypos = (await db.execute( + select(Hypothesis).where(Hypothesis.hunt_id == hunt_id) + )).scalars().all() + for h in hypos: + objects.append(_hypothesis_to_report(h, identity["id"])) + if h.mitre_technique and h.mitre_technique not in seen_techniques: + seen_techniques.add(h.mitre_technique) + objects.append(_technique_to_attack_pattern(h.mitre_technique, identity["id"])) + + bundle = { + "type": "bundle", + "id": _stix_id("bundle"), + "objects": objects, + } + + filename = f"threathunt-{hunt.name.replace(' ', '_')}-stix.json" + return Response( + content=json.dumps(bundle, indent=2), + media_type="application/json", + headers={"Content-Disposition": f'attachment; filename="{filename}"'}, + ) diff --git a/backend/app/api/routes/timeline.py b/backend/app/api/routes/timeline.py new file mode 100644 index 0000000..262b7c5 --- /dev/null +++ b/backend/app/api/routes/timeline.py @@ -0,0 +1,128 @@ +"""API routes for forensic timeline visualization.""" + +import logging +from datetime import datetime + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db import get_db +from app.db.models import Dataset, DatasetRow, Hunt + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/timeline", tags=["timeline"]) + + +def _parse_timestamp(val: str | None) -> str | None: + """Try to parse a timestamp string, return ISO format or None.""" + if not val: + return None + val = str(val).strip() + if not val: + return None + # Try common formats + for fmt in [ + "%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ", + "%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S", + "%Y/%m/%d %H:%M:%S", "%m/%d/%Y %H:%M:%S", + ]: + try: + return datetime.strptime(val, fmt).isoformat() + "Z" + except ValueError: + continue + return None + + +# Columns likely to contain timestamps +TIME_COLUMNS = { + "timestamp", "time", "datetime", "date", "created", "modified", + "eventtime", "event_time", "start_time", "end_time", + "lastmodified", "last_modified", "created_at", "updated_at", + "mtime", "atime", "ctime", "btime", + "timecreated", "timegenerated", "sourcetime", +} + + +@router.get("/hunt/{hunt_id}") +async def get_hunt_timeline( + hunt_id: str, + limit: int = 2000, + db: AsyncSession = Depends(get_db), +): + """Build a timeline of events across all datasets in a hunt.""" + result = await db.execute(select(Hunt).where(Hunt.id == hunt_id)) + hunt = result.scalar_one_or_none() + if not hunt: + raise HTTPException(status_code=404, detail="Hunt not found") + + result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id)) + datasets = result.scalars().all() + if not datasets: + return {"hunt_id": hunt_id, "events": [], "datasets": []} + + events = [] + dataset_info = [] + + for ds in datasets: + artifact_type = getattr(ds, "artifact_type", None) or "Unknown" + dataset_info.append({ + "id": ds.id, "name": ds.name, "artifact_type": artifact_type, + "row_count": ds.row_count, + }) + + # Find time columns for this dataset + schema = ds.column_schema or {} + time_cols = [] + for col in (ds.normalized_columns or {}).values(): + if col.lower() in TIME_COLUMNS: + time_cols.append(col) + if not time_cols: + for col in schema: + if col.lower() in TIME_COLUMNS or "time" in col.lower() or "date" in col.lower(): + time_cols.append(col) + if not time_cols: + continue + + # Fetch rows + rows_result = await db.execute( + select(DatasetRow) + .where(DatasetRow.dataset_id == ds.id) + .order_by(DatasetRow.row_index) + .limit(limit // max(len(datasets), 1)) + ) + for r in rows_result.scalars().all(): + data = r.normalized_data or r.data + ts = None + for tc in time_cols: + ts = _parse_timestamp(data.get(tc)) + if ts: + break + if ts: + hostname = data.get("hostname") or data.get("Hostname") or data.get("Fqdn") or "" + process = data.get("process_name") or data.get("Name") or data.get("ProcessName") or "" + summary = data.get("command_line") or data.get("CommandLine") or data.get("Details") or "" + events.append({ + "timestamp": ts, + "dataset_id": ds.id, + "dataset_name": ds.name, + "artifact_type": artifact_type, + "row_index": r.row_index, + "hostname": str(hostname)[:128], + "process": str(process)[:128], + "summary": str(summary)[:256], + "data": {k: str(v)[:100] for k, v in list(data.items())[:8]}, + }) + + # Sort by timestamp + events.sort(key=lambda e: e["timestamp"]) + + return { + "hunt_id": hunt_id, + "hunt_name": hunt.name, + "event_count": len(events), + "datasets": dataset_info, + "events": events[:limit], + } diff --git a/backend/app/config.py b/backend/app/config.py index b4b5e6b..25e2952 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -1,4 +1,4 @@ -"""Application configuration — single source of truth for all settings. +"""Application configuration - single source of truth for all settings. Loads from environment variables with sensible defaults for local dev. """ @@ -13,12 +13,12 @@ from pydantic import Field class AppConfig(BaseSettings): """Central configuration for the entire ThreatHunt application.""" - # ── General ──────────────────────────────────────────────────────── + # -- General -------------------------------------------------------- APP_NAME: str = "ThreatHunt" APP_VERSION: str = "0.3.0" DEBUG: bool = Field(default=False, description="Enable debug mode") - # ── Database ─────────────────────────────────────────────────────── + # -- Database ------------------------------------------------------- DATABASE_URL: str = Field( default="sqlite+aiosqlite:///./threathunt.db", description="Async SQLAlchemy database URL. " @@ -26,17 +26,17 @@ class AppConfig(BaseSettings): "postgresql+asyncpg://user:pass@host/db for production.", ) - # ── CORS ─────────────────────────────────────────────────────────── + # -- CORS ----------------------------------------------------------- ALLOWED_ORIGINS: str = Field( default="http://localhost:3000,http://localhost:8000", description="Comma-separated list of allowed CORS origins", ) - # ── File uploads ─────────────────────────────────────────────────── + # -- File uploads --------------------------------------------------- MAX_UPLOAD_SIZE_MB: int = Field(default=500, description="Max CSV upload in MB") UPLOAD_DIR: str = Field(default="./uploads", description="Directory for uploaded files") - # ── LLM Cluster — Wile & Roadrunner ──────────────────────────────── + # -- LLM Cluster - Wile & Roadrunner -------------------------------- OPENWEBUI_URL: str = Field( default="https://ai.guapo613.beer", description="Open WebUI cluster endpoint (OpenAI-compatible API)", @@ -58,7 +58,7 @@ class AppConfig(BaseSettings): default=11434, description="Ollama port on Roadrunner" ) - # ── LLM Routing defaults ────────────────────────────────────────── + # -- LLM Routing defaults ------------------------------------------ DEFAULT_FAST_MODEL: str = Field( default="llama3.1:latest", description="Default model for quick chat / simple queries", @@ -80,18 +80,18 @@ class AppConfig(BaseSettings): description="Default embedding model", ) - # ── Agent behaviour ─────────────────────────────────────────────── + # -- Agent behaviour ------------------------------------------------ AGENT_MAX_TOKENS: int = Field(default=2048, description="Max tokens per agent response") AGENT_TEMPERATURE: float = Field(default=0.3, description="LLM temperature for guidance") AGENT_HISTORY_LENGTH: int = Field(default=10, description="Messages to keep in context") FILTER_SENSITIVE_DATA: bool = Field(default=True, description="Redact sensitive patterns") - # ── Enrichment API keys ─────────────────────────────────────────── + # -- Enrichment API keys -------------------------------------------- VIRUSTOTAL_API_KEY: str = Field(default="", description="VirusTotal API key") ABUSEIPDB_API_KEY: str = Field(default="", description="AbuseIPDB API key") SHODAN_API_KEY: str = Field(default="", description="Shodan API key") - # ── Auth ────────────────────────────────────────────────────────── + # -- Auth ----------------------------------------------------------- JWT_SECRET: str = Field( default="CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET", description="Secret for JWT signing", @@ -99,6 +99,73 @@ class AppConfig(BaseSettings): JWT_ACCESS_TOKEN_MINUTES: int = Field(default=60, description="Access token lifetime") JWT_REFRESH_TOKEN_DAYS: int = Field(default=7, description="Refresh token lifetime") + # -- Triage settings ------------------------------------------------ + TRIAGE_BATCH_SIZE: int = Field(default=25, description="Rows per triage LLM batch") + TRIAGE_MAX_SUSPICIOUS_ROWS: int = Field( + default=200, description="Stop triage after this many suspicious rows" + ) + TRIAGE_ESCALATION_THRESHOLD: float = Field( + default=5.0, description="Risk score threshold for escalation counting" + ) + + # -- Host profiler settings ----------------------------------------- + HOST_PROFILE_CONCURRENCY: int = Field( + default=3, description="Max concurrent host profile LLM calls" + ) + + # -- Scanner settings ----------------------------------------------- + SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch") + SCANNER_MAX_ROWS_PER_SCAN: int = Field( + default=120000, + description="Global row budget for a single AUP scan request (0 = unlimited)", + ) + + # -- Job queue settings ---------------------------------------------- + JOB_QUEUE_MAX_BACKLOG: int = Field( + default=2000, description="Soft cap for queued background jobs" + ) + JOB_QUEUE_RETAIN_COMPLETED: int = Field( + default=3000, description="Maximum completed/failed jobs to retain in memory" + ) + JOB_QUEUE_CLEANUP_INTERVAL_SECONDS: int = Field( + default=60, description="How often to run in-memory job cleanup" + ) + JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field( + default=3600, description="Age threshold for in-memory completed job cleanup" + ) + + # -- Startup throttling ------------------------------------------------ + STARTUP_WARMUP_MAX_HUNTS: int = Field( + default=5, description="Max hunts to warm inventory cache for at startup" + ) + STARTUP_REPROCESS_MAX_DATASETS: int = Field( + default=25, description="Max unprocessed datasets to enqueue at startup" + ) + STARTUP_RECONCILE_STALE_TASKS: bool = Field( + default=True, + description="Mark stale queued/running processing tasks as failed on startup", + ) + + # -- Network API scale guards ----------------------------------------- + NETWORK_SUBGRAPH_MAX_HOSTS: int = Field( + default=400, description="Hard cap for hosts returned by network subgraph endpoint" + ) + NETWORK_SUBGRAPH_MAX_EDGES: int = Field( + default=3000, description="Hard cap for edges returned by network subgraph endpoint" + ) + NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field( + default=5000, + description="Row budget per dataset when building host inventory (0 = unlimited)", + ) + NETWORK_INVENTORY_MAX_TOTAL_ROWS: int = Field( + default=120000, + description="Global row budget across all datasets for host inventory build (0 = unlimited)", + ) + NETWORK_INVENTORY_MAX_CONNECTIONS: int = Field( + default=120000, + description="Max unique connection tuples retained during host inventory build", + ) + model_config = {"env_prefix": "TH_", "env_file": ".env", "extra": "ignore"} @property @@ -119,3 +186,4 @@ class AppConfig(BaseSettings): settings = AppConfig() + diff --git a/backend/app/db/engine.py b/backend/app/db/engine.py index dbfc86c..db54e68 100644 --- a/backend/app/db/engine.py +++ b/backend/app/db/engine.py @@ -21,9 +21,14 @@ _engine_kwargs: dict = dict( ) if _is_sqlite: - _engine_kwargs["connect_args"] = {"timeout": 30} - _engine_kwargs["pool_size"] = 1 - _engine_kwargs["max_overflow"] = 0 + _engine_kwargs["connect_args"] = {"timeout": 60, "check_same_thread": False} + # NullPool: each session gets its own connection. + # Combined with WAL mode, this allows concurrent reads while a write is in progress. + from sqlalchemy.pool import NullPool + _engine_kwargs["poolclass"] = NullPool +else: + _engine_kwargs["pool_size"] = 5 + _engine_kwargs["max_overflow"] = 10 engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs) @@ -34,7 +39,7 @@ def _set_sqlite_pragmas(dbapi_conn, connection_record): if _is_sqlite: cursor = dbapi_conn.cursor() cursor.execute("PRAGMA journal_mode=WAL") - cursor.execute("PRAGMA busy_timeout=5000") + cursor.execute("PRAGMA busy_timeout=30000") cursor.execute("PRAGMA synchronous=NORMAL") cursor.close() @@ -46,6 +51,10 @@ async_session_factory = async_sessionmaker( ) +# Alias expected by other modules +async_session = async_session_factory + + class Base(DeclarativeBase): """Base class for all ORM models.""" pass @@ -71,5 +80,5 @@ async def init_db() -> None: async def dispose_db() -> None: - """Dispose of the engine connection pool.""" - await engine.dispose() \ No newline at end of file + """Dispose of the engine on shutdown.""" + await engine.dispose() diff --git a/backend/app/db/models.py b/backend/app/db/models.py index f2ab0ff..0cc39ec 100644 --- a/backend/app/db/models.py +++ b/backend/app/db/models.py @@ -1,4 +1,4 @@ -"""SQLAlchemy ORM models for ThreatHunt. +"""SQLAlchemy ORM models for ThreatHunt. All persistent entities: datasets, hunts, conversations, annotations, hypotheses, enrichment results, users, and AI analysis tables. @@ -43,6 +43,7 @@ class User(Base): hashed_password: Mapped[str] = mapped_column(String(256), nullable=False) role: Mapped[str] = mapped_column(String(16), default="analyst") is_active: Mapped[bool] = mapped_column(Boolean, default=True) + display_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) hunts: Mapped[list["Hunt"]] = relationship(back_populates="owner", lazy="selectin") @@ -399,4 +400,108 @@ class AnomalyResult(Base): cluster_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) is_outlier: Mapped[bool] = mapped_column(Boolean, default=False) explanation: Mapped[Optional[str]] = mapped_column(Text, nullable=True) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) \ No newline at end of file + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) + + +# -- Persistent Processing Tasks (Phase 2) --- + +class ProcessingTask(Base): + __tablename__ = "processing_tasks" + + id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id) + hunt_id: Mapped[Optional[str]] = mapped_column( + String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True, index=True + ) + dataset_id: Mapped[Optional[str]] = mapped_column( + String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True, index=True + ) + job_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True) + stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True) + status: Mapped[str] = mapped_column(String(20), default="queued", index=True) + progress: Mapped[float] = mapped_column(Float, default=0.0) + message: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + error: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) + started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=_utcnow, onupdate=_utcnow + ) + + __table_args__ = ( + Index("ix_processing_tasks_hunt_stage", "hunt_id", "stage"), + Index("ix_processing_tasks_dataset_stage", "dataset_id", "stage"), + ) + + +# -- Playbook / Investigation Templates (Feature 3) --- + +class Playbook(Base): + __tablename__ = "playbooks" + + id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id) + name: Mapped[str] = mapped_column(String(256), nullable=False, index=True) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + created_by: Mapped[Optional[str]] = mapped_column( + String(32), ForeignKey("users.id"), nullable=True + ) + is_template: Mapped[bool] = mapped_column(Boolean, default=False) + hunt_id: Mapped[Optional[str]] = mapped_column( + String(32), ForeignKey("hunts.id"), nullable=True + ) + status: Mapped[str] = mapped_column(String(20), default="active") + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=_utcnow, onupdate=_utcnow + ) + + steps: Mapped[list["PlaybookStep"]] = relationship( + back_populates="playbook", lazy="selectin", cascade="all, delete-orphan", + order_by="PlaybookStep.order_index", + ) + + +class PlaybookStep(Base): + __tablename__ = "playbook_steps" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + playbook_id: Mapped[str] = mapped_column( + String(32), ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False + ) + order_index: Mapped[int] = mapped_column(Integer, nullable=False) + title: Mapped[str] = mapped_column(String(256), nullable=False) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + step_type: Mapped[str] = mapped_column(String(32), default="manual") + target_route: Mapped[Optional[str]] = mapped_column(String(256), nullable=True) + is_completed: Mapped[bool] = mapped_column(Boolean, default=False) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + + playbook: Mapped["Playbook"] = relationship(back_populates="steps") + + __table_args__ = ( + Index("ix_playbook_steps_playbook", "playbook_id"), + ) + + +# -- Saved Searches (Feature 5) --- + +class SavedSearch(Base): + __tablename__ = "saved_searches" + + id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id) + name: Mapped[str] = mapped_column(String(256), nullable=False, index=True) + description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + search_type: Mapped[str] = mapped_column(String(32), nullable=False) + query_params: Mapped[dict] = mapped_column(JSON, nullable=False) + threshold: Mapped[Optional[float]] = mapped_column(Float, nullable=True) + created_by: Mapped[Optional[str]] = mapped_column( + String(32), ForeignKey("users.id"), nullable=True + ) + last_run_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) + last_result_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow) + + __table_args__ = ( + Index("ix_saved_searches_type", "search_type"), + ) diff --git a/backend/app/main.py b/backend/app/main.py index 6c81ddb..47a08dd 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -25,6 +25,11 @@ from app.api.routes.auth import router as auth_router from app.api.routes.keywords import router as keywords_router from app.api.routes.analysis import router as analysis_router from app.api.routes.network import router as network_router +from app.api.routes.mitre import router as mitre_router +from app.api.routes.timeline import router as timeline_router +from app.api.routes.playbooks import router as playbooks_router +from app.api.routes.saved_searches import router as searches_router +from app.api.routes.stix_export import router as stix_router logger = logging.getLogger(__name__) @@ -47,13 +52,80 @@ async def lifespan(app: FastAPI): await seed_defaults(seed_db) logger.info("AUP keyword defaults checked") - # Start job queue (Phase 10) - from app.services.job_queue import job_queue, register_all_handlers + # Start job queue + from app.services.job_queue import ( + job_queue, + register_all_handlers, + reconcile_stale_processing_tasks, + JobType, + ) + + if settings.STARTUP_RECONCILE_STALE_TASKS: + reconciled = await reconcile_stale_processing_tasks() + if reconciled: + logger.info("Startup reconciliation marked %d stale tasks", reconciled) + register_all_handlers() await job_queue.start() logger.info("Job queue started (%d workers)", job_queue._max_workers) - # Start load balancer health loop (Phase 10) + # Pre-warm host inventory cache for existing hunts + from app.services.host_inventory import inventory_cache + async with async_session_factory() as warm_db: + from sqlalchemy import select, func + from app.db.models import Hunt, Dataset + stmt = ( + select(Hunt.id) + .join(Dataset, Dataset.hunt_id == Hunt.id) + .group_by(Hunt.id) + .having(func.count(Dataset.id) > 0) + ) + result = await warm_db.execute(stmt) + hunt_ids = [row[0] for row in result.all()] + warm_hunts = hunt_ids[: settings.STARTUP_WARMUP_MAX_HUNTS] + for hid in warm_hunts: + job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid) + if warm_hunts: + logger.info(f"Queued host inventory warm-up for {len(warm_hunts)} hunts (total hunts with data: {len(hunt_ids)})") + + # Check which datasets still need processing + # (no anomaly results = never fully processed) + async with async_session_factory() as reprocess_db: + from sqlalchemy import select, exists + from app.db.models import Dataset, AnomalyResult + # Find datasets that have zero anomaly results (pipeline never ran or failed) + has_anomaly = ( + select(AnomalyResult.id) + .where(AnomalyResult.dataset_id == Dataset.id) + .limit(1) + .correlate(Dataset) + .exists() + ) + stmt = select(Dataset.id).where(~has_anomaly) + result = await reprocess_db.execute(stmt) + unprocessed_ids = [row[0] for row in result.all()] + + if unprocessed_ids: + to_reprocess = unprocessed_ids[: settings.STARTUP_REPROCESS_MAX_DATASETS] + for ds_id in to_reprocess: + job_queue.submit(JobType.TRIAGE, dataset_id=ds_id) + job_queue.submit(JobType.ANOMALY, dataset_id=ds_id) + job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id) + job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id) + logger.info(f"Queued processing pipeline for {len(to_reprocess)} datasets at startup (unprocessed total: {len(unprocessed_ids)})") + async with async_session_factory() as update_db: + from sqlalchemy import update + from app.db.models import Dataset + await update_db.execute( + update(Dataset) + .where(Dataset.id.in_(to_reprocess)) + .values(processing_status="processing") + ) + await update_db.commit() + else: + logger.info("All datasets already processed - skipping startup pipeline") + + # Start load balancer health loop from app.services.load_balancer import lb await lb.start_health_loop(interval=30.0) logger.info("Load balancer health loop started") @@ -61,12 +133,10 @@ async def lifespan(app: FastAPI): yield logger.info("Shutting down ...") - # Stop job queue from app.services.job_queue import job_queue as jq await jq.stop() logger.info("Job queue stopped") - # Stop load balancer from app.services.load_balancer import lb as _lb await _lb.stop_health_loop() logger.info("Load balancer stopped") @@ -106,6 +176,11 @@ app.include_router(reports_router) app.include_router(keywords_router) app.include_router(analysis_router) app.include_router(network_router) +app.include_router(mitre_router) +app.include_router(timeline_router) +app.include_router(playbooks_router) +app.include_router(searches_router) +app.include_router(stix_router) @app.get("/", tags=["health"]) @@ -120,4 +195,4 @@ async def root(): "roadrunner": settings.roadrunner_url, "openwebui": settings.OPENWEBUI_URL, }, - } \ No newline at end of file + } diff --git a/backend/app/services/host_inventory.py b/backend/app/services/host_inventory.py index 32f0146..b662c5a 100644 --- a/backend/app/services/host_inventory.py +++ b/backend/app/services/host_inventory.py @@ -13,6 +13,7 @@ from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession from app.db.models import Dataset, DatasetRow +from app.config import settings logger = logging.getLogger(__name__) @@ -79,6 +80,55 @@ def _extract_username(raw: str) -> str: return name or '' + + +# In-memory host inventory cache +# Pre-computed results stored per hunt_id, built in background after upload. + +import time as _time + +class _InventoryCache: + """Simple in-memory cache for pre-computed host inventories.""" + + def __init__(self): + self._data: dict[str, dict] = {} # hunt_id -> result dict + self._timestamps: dict[str, float] = {} # hunt_id -> epoch + self._building: set[str] = set() # hunt_ids currently being built + + def get(self, hunt_id: str) -> dict | None: + """Return cached result if present. Never expires; only invalidated on new upload.""" + return self._data.get(hunt_id) + + def put(self, hunt_id: str, result: dict): + self._data[hunt_id] = result + self._timestamps[hunt_id] = _time.time() + self._building.discard(hunt_id) + logger.info(f"Cached host inventory for hunt {hunt_id} " + f"({result['stats']['total_hosts']} hosts)") + + def invalidate(self, hunt_id: str): + self._data.pop(hunt_id, None) + self._timestamps.pop(hunt_id, None) + + def is_building(self, hunt_id: str) -> bool: + return hunt_id in self._building + + def set_building(self, hunt_id: str): + self._building.add(hunt_id) + + def clear_building(self, hunt_id: str): + self._building.discard(hunt_id) + + def status(self, hunt_id: str) -> str: + if hunt_id in self._building: + return "building" + if hunt_id in self._data: + return "ready" + return "none" + + +inventory_cache = _InventoryCache() + def _infer_os(fqdn: str) -> str: u = fqdn.upper() if 'W10-' in u or 'WIN10' in u: @@ -151,33 +201,61 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict: }} hosts: dict[str, dict] = {} # fqdn -> host record - ip_to_host: dict[str, str] = {} # local-ip -> fqdn + ip_to_host: dict[str, str] = {} # local-ip -> fqdn connections: dict[tuple, int] = defaultdict(int) total_rows = 0 ds_with_hosts = 0 + sampled_dataset_count = 0 + total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS)) + max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS)) + global_budget_reached = False + dropped_connections = 0 for ds in all_datasets: + if total_row_budget and total_rows >= total_row_budget: + global_budget_reached = True + break + cols = _identify_columns(ds) if not cols['fqdn'] and not cols['host_id']: continue ds_with_hosts += 1 batch_size = 5000 - offset = 0 + max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET)) + rows_scanned_this_dataset = 0 + sampled_dataset = False + last_row_index = -1 + while True: + if total_row_budget and total_rows >= total_row_budget: + sampled_dataset = True + global_budget_reached = True + break + rr = await db.execute( select(DatasetRow) .where(DatasetRow.dataset_id == ds.id) + .where(DatasetRow.row_index > last_row_index) .order_by(DatasetRow.row_index) - .offset(offset).limit(batch_size) + .limit(batch_size) ) rows = rr.scalars().all() if not rows: break for ro in rows: + if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset: + sampled_dataset = True + break + if total_row_budget and total_rows >= total_row_budget: + sampled_dataset = True + global_budget_reached = True + break + data = ro.data or {} total_rows += 1 + rows_scanned_this_dataset += 1 fqdn = '' for c in cols['fqdn']: @@ -239,12 +317,33 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict: rport = _clean(data.get(pc)) if rport: break - connections[(host_key, rip, rport)] += 1 + conn_key = (host_key, rip, rport) + if max_connections and len(connections) >= max_connections and conn_key not in connections: + dropped_connections += 1 + continue + connections[conn_key] += 1 - offset += batch_size + if sampled_dataset: + sampled_dataset_count += 1 + logger.info( + "Host inventory sampling for dataset %s (%d rows scanned)", + ds.id, + rows_scanned_this_dataset, + ) + break + + last_row_index = rows[-1].row_index if len(rows) < batch_size: break + if global_budget_reached: + logger.info( + "Host inventory global row budget reached for hunt %s at %d rows", + hunt_id, + total_rows, + ) + break + # Post-process hosts for h in hosts.values(): if not h['os'] and h['fqdn']: @@ -286,5 +385,12 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict: "total_rows_scanned": total_rows, "hosts_with_ips": sum(1 for h in host_list if h['ips']), "hosts_with_users": sum(1 for h in host_list if h['users']), + "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET, + "row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS, + "connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS, + "sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0, + "sampled_datasets": sampled_dataset_count, + "global_budget_reached": global_budget_reached, + "dropped_connections": dropped_connections, }, - } \ No newline at end of file + } diff --git a/backend/app/services/host_profiler.py b/backend/app/services/host_profiler.py index c4f8bb1..c40e172 100644 --- a/backend/app/services/host_profiler.py +++ b/backend/app/services/host_profiler.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import re import json import logging @@ -18,6 +19,9 @@ logger = logging.getLogger(__name__) HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL WILE_URL = f"{settings.wile_url}/api/generate" +# Velociraptor client IDs (C.hex) are not real hostnames +CLIENTID_RE = re.compile(r"^C\.[0-9a-fA-F]{8,}$") + async def _get_triage_summary(db, dataset_id: str) -> str: result = await db.execute( @@ -154,7 +158,7 @@ async def profile_host( logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level) except Exception as e: - logger.error("Failed to profile host %s: %s", hostname, e) + logger.error("Failed to profile host %s: %r", hostname, e) profile = HostProfile( hunt_id=hunt_id, hostname=hostname, fqdn=fqdn, risk_score=0.0, risk_level="unknown", @@ -185,6 +189,13 @@ async def profile_all_hosts(hunt_id: str) -> None: if h not in hostnames: hostnames[h] = data.get("fqdn") or data.get("Fqdn") + # Filter out Velociraptor client IDs - not real hostnames + real_hosts = {h: f for h, f in hostnames.items() if not CLIENTID_RE.match(h)} + skipped = len(hostnames) - len(real_hosts) + if skipped: + logger.info("Skipped %d Velociraptor client IDs", skipped) + hostnames = real_hosts + logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id) semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY) diff --git a/backend/app/services/job_queue.py b/backend/app/services/job_queue.py index f218cdf..591d5db 100644 --- a/backend/app/services/job_queue.py +++ b/backend/app/services/job_queue.py @@ -1,8 +1,8 @@ """Async job queue for background AI tasks. Manages triage, profiling, report generation, anomaly detection, -and data queries as trackable jobs with status, progress, and -cancellation support. +keyword scanning, IOC extraction, and data queries as trackable +jobs with status, progress, and cancellation support. """ from __future__ import annotations @@ -15,6 +15,8 @@ from dataclasses import dataclass, field from enum import Enum from typing import Any, Callable, Coroutine, Optional +from app.config import settings + logger = logging.getLogger(__name__) @@ -32,6 +34,18 @@ class JobType(str, Enum): REPORT = "report" ANOMALY = "anomaly" QUERY = "query" + HOST_INVENTORY = "host_inventory" + KEYWORD_SCAN = "keyword_scan" + IOC_EXTRACT = "ioc_extract" + + +# Job types that form the automatic upload pipeline +PIPELINE_JOB_TYPES = frozenset({ + JobType.TRIAGE, + JobType.ANOMALY, + JobType.KEYWORD_SCAN, + JobType.IOC_EXTRACT, +}) @dataclass @@ -82,11 +96,7 @@ class Job: class JobQueue: - """In-memory async job queue with concurrency control. - - Jobs are tracked by ID and can be listed, polled, or cancelled. - A configurable number of workers process jobs from the queue. - """ + """In-memory async job queue with concurrency control.""" def __init__(self, max_workers: int = 3): self._jobs: dict[str, Job] = {} @@ -95,47 +105,56 @@ class JobQueue: self._workers: list[asyncio.Task] = [] self._handlers: dict[JobType, Callable] = {} self._started = False + self._completion_callbacks: list[Callable[[Job], Coroutine]] = [] + self._cleanup_task: asyncio.Task | None = None - def register_handler( - self, - job_type: JobType, - handler: Callable[[Job], Coroutine], - ): - """Register an async handler for a job type. - - Handler signature: async def handler(job: Job) -> Any - The handler can update job.progress and job.message during execution. - It should check job.is_cancelled periodically and return early. - """ + def register_handler(self, job_type: JobType, handler: Callable[[Job], Coroutine]): self._handlers[job_type] = handler logger.info(f"Registered handler for {job_type.value}") + def on_completion(self, callback: Callable[[Job], Coroutine]): + """Register a callback invoked after any job completes or fails.""" + self._completion_callbacks.append(callback) + async def start(self): - """Start worker tasks.""" if self._started: return self._started = True for i in range(self._max_workers): task = asyncio.create_task(self._worker(i)) self._workers.append(task) + if not self._cleanup_task or self._cleanup_task.done(): + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) logger.info(f"Job queue started with {self._max_workers} workers") async def stop(self): - """Stop all workers.""" self._started = False for w in self._workers: w.cancel() await asyncio.gather(*self._workers, return_exceptions=True) self._workers.clear() + if self._cleanup_task: + self._cleanup_task.cancel() + await asyncio.gather(self._cleanup_task, return_exceptions=True) + self._cleanup_task = None logger.info("Job queue stopped") def submit(self, job_type: JobType, **params) -> Job: - """Submit a new job. Returns the Job object immediately.""" - job = Job( - id=str(uuid.uuid4()), - job_type=job_type, - params=params, - ) + # Soft backpressure: prefer dedupe over queue amplification + dedupe_job = self._find_active_duplicate(job_type, params) + if dedupe_job is not None: + logger.info( + f"Job deduped: reusing {dedupe_job.id} ({job_type.value}) params={params}" + ) + return dedupe_job + + if self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG: + logger.warning( + "Job queue backlog high (%d >= %d). Accepting job but system may be degraded.", + self._queue.qsize(), settings.JOB_QUEUE_MAX_BACKLOG, + ) + + job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params) self._jobs[job.id] = job self._queue.put_nowait(job.id) logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}") @@ -144,6 +163,22 @@ class JobQueue: def get_job(self, job_id: str) -> Job | None: return self._jobs.get(job_id) + def _find_active_duplicate(self, job_type: JobType, params: dict) -> Job | None: + """Return queued/running job with same key workload to prevent duplicate storms.""" + key_fields = ["dataset_id", "hunt_id", "hostname", "question", "mode"] + sig = tuple((k, params.get(k)) for k in key_fields if params.get(k) is not None) + if not sig: + return None + for j in self._jobs.values(): + if j.job_type != job_type: + continue + if j.status not in (JobStatus.QUEUED, JobStatus.RUNNING): + continue + other_sig = tuple((k, j.params.get(k)) for k in key_fields if j.params.get(k) is not None) + if sig == other_sig: + return j + return None + def cancel_job(self, job_id: str) -> bool: job = self._jobs.get(job_id) if not job: @@ -153,13 +188,7 @@ class JobQueue: job.cancel() return True - def list_jobs( - self, - status: JobStatus | None = None, - job_type: JobType | None = None, - limit: int = 50, - ) -> list[dict]: - """List jobs, newest first.""" + def list_jobs(self, status=None, job_type=None, limit=50) -> list[dict]: jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True) if status: jobs = [j for j in jobs if j.status == status] @@ -168,7 +197,6 @@ class JobQueue: return [j.to_dict() for j in jobs[:limit]] def get_stats(self) -> dict: - """Get queue statistics.""" by_status = {} for j in self._jobs.values(): by_status[j.status.value] = by_status.get(j.status.value, 0) + 1 @@ -177,26 +205,58 @@ class JobQueue: "queued": self._queue.qsize(), "by_status": by_status, "workers": self._max_workers, - "active_workers": sum( - 1 for j in self._jobs.values() if j.status == JobStatus.RUNNING - ), + "active_workers": sum(1 for j in self._jobs.values() if j.status == JobStatus.RUNNING), } + def is_backlogged(self) -> bool: + return self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG + + def can_accept(self, reserve: int = 0) -> bool: + return (self._queue.qsize() + max(0, reserve)) < settings.JOB_QUEUE_MAX_BACKLOG + def cleanup(self, max_age_seconds: float = 3600): - """Remove old completed/failed/cancelled jobs.""" now = time.time() + terminal_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED) to_remove = [ jid for jid, j in self._jobs.items() - if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED) - and (now - j.created_at) > max_age_seconds + if j.status in terminal_states and (now - j.created_at) > max_age_seconds + ] + + # Also cap retained terminal jobs to avoid unbounded memory growth + terminal_jobs = sorted( + [j for j in self._jobs.values() if j.status in terminal_states], + key=lambda j: j.created_at, + reverse=True, + ) + overflow = terminal_jobs[settings.JOB_QUEUE_RETAIN_COMPLETED :] + to_remove.extend([j.id for j in overflow]) + + removed = 0 + for jid in set(to_remove): + if jid in self._jobs: + del self._jobs[jid] + removed += 1 + if removed: + logger.info(f"Cleaned up {removed} old jobs") + + async def _cleanup_loop(self): + interval = max(10, settings.JOB_QUEUE_CLEANUP_INTERVAL_SECONDS) + while self._started: + try: + self.cleanup(max_age_seconds=settings.JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS) + except Exception as e: + logger.warning(f"Job queue cleanup loop error: {e}") + await asyncio.sleep(interval) + + def find_pipeline_jobs(self, dataset_id: str) -> list[Job]: + """Find all pipeline jobs for a given dataset_id.""" + return [ + j for j in self._jobs.values() + if j.job_type in PIPELINE_JOB_TYPES + and j.params.get("dataset_id") == dataset_id ] - for jid in to_remove: - del self._jobs[jid] - if to_remove: - logger.info(f"Cleaned up {len(to_remove)} old jobs") async def _worker(self, worker_id: int): - """Worker loop: pull jobs from queue and execute handlers.""" logger.info(f"Worker {worker_id} started") while self._started: try: @@ -220,7 +280,10 @@ class JobQueue: job.status = JobStatus.RUNNING job.started_at = time.time() + if job.progress <= 0: + job.progress = 5.0 job.message = "Running..." + await _sync_processing_task(job) logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})") try: @@ -231,38 +294,111 @@ class JobQueue: job.result = result job.message = "Completed" job.completed_at = time.time() - logger.info( - f"Worker {worker_id}: completed {job.id} " - f"in {job.elapsed_ms}ms" - ) + logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms") except Exception as e: if not job.is_cancelled: job.status = JobStatus.FAILED job.error = str(e) job.message = f"Failed: {e}" job.completed_at = time.time() - logger.error( - f"Worker {worker_id}: failed {job.id}: {e}", - exc_info=True, - ) + logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True) + + if job.is_cancelled and not job.completed_at: + job.completed_at = time.time() + + await _sync_processing_task(job) + + # Fire completion callbacks + for cb in self._completion_callbacks: + try: + await cb(job) + except Exception as cb_err: + logger.error(f"Completion callback error: {cb_err}", exc_info=True) -# Singleton + job handlers +async def _sync_processing_task(job: Job): + """Persist latest job state into processing_tasks (if linked by job_id).""" + from datetime import datetime, timezone + from sqlalchemy import update -job_queue = JobQueue(max_workers=3) + try: + from app.db import async_session_factory + from app.db.models import ProcessingTask + + values = { + "status": job.status.value, + "progress": float(job.progress), + "message": job.message, + "error": job.error, + } + if job.started_at: + values["started_at"] = datetime.fromtimestamp(job.started_at, tz=timezone.utc) + if job.completed_at: + values["completed_at"] = datetime.fromtimestamp(job.completed_at, tz=timezone.utc) + + async with async_session_factory() as db: + await db.execute( + update(ProcessingTask) + .where(ProcessingTask.job_id == job.id) + .values(**values) + ) + await db.commit() + except Exception as e: + logger.warning(f"Failed to sync processing task for job {job.id}: {e}") + + +# -- Singleton + job handlers -- + +job_queue = JobQueue(max_workers=5) async def _handle_triage(job: Job): - """Triage handler.""" + """Triage handler - chains HOST_PROFILE after completion.""" from app.services.triage import triage_dataset dataset_id = job.params.get("dataset_id") job.message = f"Triaging dataset {dataset_id}" - results = await triage_dataset(dataset_id) - return {"count": len(results) if results else 0} + await triage_dataset(dataset_id) + + # Chain: trigger host profiling now that triage results exist + from app.db import async_session_factory + from app.db.models import Dataset + from sqlalchemy import select + try: + async with async_session_factory() as db: + ds = await db.execute(select(Dataset.hunt_id).where(Dataset.id == dataset_id)) + row = ds.first() + hunt_id = row[0] if row else None + if hunt_id: + hp_job = job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id) + try: + from sqlalchemy import select + from app.db.models import ProcessingTask + async with async_session_factory() as db: + existing = await db.execute( + select(ProcessingTask.id).where(ProcessingTask.job_id == hp_job.id) + ) + if existing.first() is None: + db.add(ProcessingTask( + hunt_id=hunt_id, + dataset_id=dataset_id, + job_id=hp_job.id, + stage="host_profile", + status="queued", + progress=0.0, + message="Queued", + )) + await db.commit() + except Exception as persist_err: + logger.warning(f"Failed to persist chained HOST_PROFILE task: {persist_err}") + + logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}") + except Exception as e: + logger.warning(f"Failed to chain host profile after triage: {e}") + + return {"dataset_id": dataset_id} async def _handle_host_profile(job: Job): - """Host profiling handler.""" from app.services.host_profiler import profile_all_hosts, profile_host hunt_id = job.params.get("hunt_id") hostname = job.params.get("hostname") @@ -277,7 +413,6 @@ async def _handle_host_profile(job: Job): async def _handle_report(job: Job): - """Report generation handler.""" from app.services.report_generator import generate_report hunt_id = job.params.get("hunt_id") job.message = f"Generating report for hunt {hunt_id}" @@ -286,7 +421,6 @@ async def _handle_report(job: Job): async def _handle_anomaly(job: Job): - """Anomaly detection handler.""" from app.services.anomaly_detector import detect_anomalies dataset_id = job.params.get("dataset_id") k = job.params.get("k", 3) @@ -297,7 +431,6 @@ async def _handle_anomaly(job: Job): async def _handle_query(job: Job): - """Data query handler (non-streaming).""" from app.services.data_query import query_dataset dataset_id = job.params.get("dataset_id") question = job.params.get("question", "") @@ -307,10 +440,152 @@ async def _handle_query(job: Job): return {"answer": answer} +async def _handle_host_inventory(job: Job): + from app.db import async_session_factory + from app.services.host_inventory import build_host_inventory, inventory_cache + + hunt_id = job.params.get("hunt_id") + if not hunt_id: + raise ValueError("hunt_id required") + + inventory_cache.set_building(hunt_id) + job.message = f"Building host inventory for hunt {hunt_id}" + + try: + async with async_session_factory() as db: + result = await build_host_inventory(hunt_id, db) + inventory_cache.put(hunt_id, result) + job.message = f"Built inventory: {result['stats']['total_hosts']} hosts" + return {"hunt_id": hunt_id, "total_hosts": result["stats"]["total_hosts"]} + except Exception: + inventory_cache.clear_building(hunt_id) + raise + + +async def _handle_keyword_scan(job: Job): + """AUP keyword scan handler.""" + from app.db import async_session_factory + from app.services.scanner import KeywordScanner, keyword_scan_cache + + dataset_id = job.params.get("dataset_id") + job.message = f"Running AUP keyword scan on dataset {dataset_id}" + + async with async_session_factory() as db: + scanner = KeywordScanner(db) + result = await scanner.scan(dataset_ids=[dataset_id]) + + # Cache dataset-only result for fast API reuse + if dataset_id: + keyword_scan_cache.put(dataset_id, result) + + hits = result.get("total_hits", 0) + job.message = f"Keyword scan complete: {hits} hits" + logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows") + return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)} + + +async def _handle_ioc_extract(job: Job): + """IOC extraction handler.""" + from app.db import async_session_factory + from app.services.ioc_extractor import extract_iocs_from_dataset + + dataset_id = job.params.get("dataset_id") + job.message = f"Extracting IOCs from dataset {dataset_id}" + + async with async_session_factory() as db: + iocs = await extract_iocs_from_dataset(dataset_id, db) + + total = sum(len(v) for v in iocs.values()) + job.message = f"IOC extraction complete: {total} IOCs found" + logger.info(f"IOC extract for {dataset_id}: {total} IOCs") + return {"dataset_id": dataset_id, "total_iocs": total, "breakdown": {k: len(v) for k, v in iocs.items()}} + + +async def _on_pipeline_job_complete(job: Job): + """Update Dataset.processing_status when all pipeline jobs finish.""" + if job.job_type not in PIPELINE_JOB_TYPES: + return + + dataset_id = job.params.get("dataset_id") + if not dataset_id: + return + + pipeline_jobs = job_queue.find_pipeline_jobs(dataset_id) + if not pipeline_jobs: + return + + all_done = all( + j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED) + for j in pipeline_jobs + ) + if not all_done: + return + + any_failed = any(j.status == JobStatus.FAILED for j in pipeline_jobs) + new_status = "completed_with_errors" if any_failed else "completed" + + try: + from app.db import async_session_factory + from app.db.models import Dataset + from sqlalchemy import update + + async with async_session_factory() as db: + await db.execute( + update(Dataset) + .where(Dataset.id == dataset_id) + .values(processing_status=new_status) + ) + await db.commit() + logger.info(f"Dataset {dataset_id} processing_status -> {new_status}") + except Exception as e: + logger.error(f"Failed to update processing_status for {dataset_id}: {e}") + + + + +async def reconcile_stale_processing_tasks() -> int: + """Mark queued/running processing tasks from prior runs as failed.""" + from datetime import datetime, timezone + from sqlalchemy import update + + try: + from app.db import async_session_factory + from app.db.models import ProcessingTask + + now = datetime.now(timezone.utc) + async with async_session_factory() as db: + result = await db.execute( + update(ProcessingTask) + .where(ProcessingTask.status.in_(["queued", "running"])) + .values( + status="failed", + error="Recovered after service restart before task completion", + message="Recovered stale task after restart", + completed_at=now, + ) + ) + await db.commit() + updated = int(result.rowcount or 0) + + if updated: + logger.warning( + "Reconciled %d stale processing tasks (queued/running -> failed) during startup", + updated, + ) + return updated + except Exception as e: + logger.warning(f"Failed to reconcile stale processing tasks: {e}") + return 0 + + def register_all_handlers(): - """Register all job handlers.""" + """Register all job handlers and completion callbacks.""" job_queue.register_handler(JobType.TRIAGE, _handle_triage) job_queue.register_handler(JobType.HOST_PROFILE, _handle_host_profile) job_queue.register_handler(JobType.REPORT, _handle_report) job_queue.register_handler(JobType.ANOMALY, _handle_anomaly) - job_queue.register_handler(JobType.QUERY, _handle_query) \ No newline at end of file + job_queue.register_handler(JobType.QUERY, _handle_query) + job_queue.register_handler(JobType.HOST_INVENTORY, _handle_host_inventory) + job_queue.register_handler(JobType.KEYWORD_SCAN, _handle_keyword_scan) + job_queue.register_handler(JobType.IOC_EXTRACT, _handle_ioc_extract) + job_queue.on_completion(_on_pipeline_job_complete) diff --git a/backend/app/services/scanner.py b/backend/app/services/scanner.py index df910e8..c18e009 100644 --- a/backend/app/services/scanner.py +++ b/backend/app/services/scanner.py @@ -1,4 +1,4 @@ -"""AUP Keyword Scanner — searches dataset rows, hunts, annotations, and +"""AUP Keyword Scanner searches dataset rows, hunts, annotations, and messages for keyword matches. Scanning is done in Python (not SQL LIKE on JSON columns) for portability @@ -8,24 +8,49 @@ across SQLite / PostgreSQL and to provide per-cell match context. import logging import re from dataclasses import dataclass, field +from datetime import datetime, timezone -from sqlalchemy import select, func +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.config import settings + from app.db.models import ( KeywordTheme, - Keyword, DatasetRow, Dataset, Hunt, Annotation, Message, - Conversation, ) logger = logging.getLogger(__name__) -BATCH_SIZE = 500 +BATCH_SIZE = 200 + + +def _infer_hostname_and_user(data: dict) -> tuple[str | None, str | None]: + """Best-effort extraction of hostname and user from a dataset row.""" + if not data: + return None, None + + host_keys = ( + 'hostname', 'host_name', 'host', 'computer_name', 'computer', + 'fqdn', 'client_id', 'agent_id', 'endpoint_id', + ) + user_keys = ( + 'username', 'user_name', 'user', 'account_name', + 'logged_in_user', 'samaccountname', 'sam_account_name', + ) + + def pick(keys): + for k in keys: + for actual_key, v in data.items(): + if actual_key.lower() == k and v not in (None, ''): + return str(v) + return None + + return pick(host_keys), pick(user_keys) @dataclass @@ -39,6 +64,8 @@ class ScanHit: matched_value: str row_index: int | None = None dataset_name: str | None = None + hostname: str | None = None + username: str | None = None @dataclass @@ -50,21 +77,54 @@ class ScanResult: rows_scanned: int = 0 +@dataclass +class KeywordScanCacheEntry: + dataset_id: str + result: dict + built_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + +class KeywordScanCache: + """In-memory per-dataset cache for dataset-only keyword scans. + + This enables fast-path reads when users run AUP scans against datasets that + were already scanned during upload pipeline processing. + """ + + def __init__(self): + self._entries: dict[str, KeywordScanCacheEntry] = {} + + def put(self, dataset_id: str, result: dict): + self._entries[dataset_id] = KeywordScanCacheEntry(dataset_id=dataset_id, result=result) + + def get(self, dataset_id: str) -> KeywordScanCacheEntry | None: + return self._entries.get(dataset_id) + + def invalidate_dataset(self, dataset_id: str): + self._entries.pop(dataset_id, None) + + def clear(self): + self._entries.clear() + + +keyword_scan_cache = KeywordScanCache() + + class KeywordScanner: """Scans multiple data sources for keyword/regex matches.""" def __init__(self, db: AsyncSession): self.db = db - # ── Public API ──────────────────────────────────────────────────── + # Public API async def scan( self, dataset_ids: list[str] | None = None, theme_ids: list[str] | None = None, - scan_hunts: bool = True, - scan_annotations: bool = True, - scan_messages: bool = True, + scan_hunts: bool = False, + scan_annotations: bool = False, + scan_messages: bool = False, ) -> dict: """Run a full AUP scan and return dict matching ScanResponse.""" # Load themes + keywords @@ -103,7 +163,7 @@ class KeywordScanner: "rows_scanned": result.rows_scanned, } - # ── Internal ────────────────────────────────────────────────────── + # Internal async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]: q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712 @@ -143,6 +203,8 @@ class KeywordScanner: hits: list[ScanHit], row_index: int | None = None, dataset_name: str | None = None, + hostname: str | None = None, + username: str | None = None, ) -> None: """Check text against all compiled patterns, append hits.""" if not text: @@ -150,8 +212,7 @@ class KeywordScanner: for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items(): for kw_value, pat in keyword_patterns: if pat.search(text): - # Truncate matched_value for display - matched_preview = text[:200] + ("…" if len(text) > 200 else "") + matched_preview = text[:200] + ("" if len(text) > 200 else "") hits.append(ScanHit( theme_name=theme_name, theme_color=theme_color, @@ -162,13 +223,14 @@ class KeywordScanner: matched_value=matched_preview, row_index=row_index, dataset_name=dataset_name, + hostname=hostname, + username=username, )) async def _scan_datasets( self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None ) -> None: - """Scan dataset rows in batches.""" - # Build dataset name lookup + """Scan dataset rows in batches using keyset pagination (no OFFSET).""" ds_q = select(Dataset.id, Dataset.name) if dataset_ids: ds_q = ds_q.where(Dataset.id.in_(dataset_ids)) @@ -178,37 +240,66 @@ class KeywordScanner: if not ds_map: return - # Iterate rows in batches - offset = 0 - row_q_base = select(DatasetRow).where( - DatasetRow.dataset_id.in_(list(ds_map.keys())) - ).order_by(DatasetRow.id) + import asyncio - while True: - rows_result = await self.db.execute( - row_q_base.offset(offset).limit(BATCH_SIZE) + max_rows = max(0, int(settings.SCANNER_MAX_ROWS_PER_SCAN)) + budget_reached = False + + for ds_id, ds_name in ds_map.items(): + if max_rows and result.rows_scanned >= max_rows: + budget_reached = True + break + + last_id = 0 + while True: + if max_rows and result.rows_scanned >= max_rows: + budget_reached = True + break + rows_result = await self.db.execute( + select(DatasetRow) + .where(DatasetRow.dataset_id == ds_id) + .where(DatasetRow.id > last_id) + .order_by(DatasetRow.id) + .limit(BATCH_SIZE) + ) + rows = rows_result.scalars().all() + if not rows: + break + + for row in rows: + result.rows_scanned += 1 + data = row.data or {} + hostname, username = _infer_hostname_and_user(data) + for col_name, cell_value in data.items(): + if cell_value is None: + continue + text = str(cell_value) + self._match_text( + text, + patterns, + "dataset_row", + row.id, + col_name, + result.hits, + row_index=row.row_index, + dataset_name=ds_name, + hostname=hostname, + username=username, + ) + + last_id = rows[-1].id + await asyncio.sleep(0) + if len(rows) < BATCH_SIZE: + break + + if budget_reached: + break + + if budget_reached: + logger.warning( + "AUP scan row budget reached (%d rows). Returning partial results.", + result.rows_scanned, ) - rows = rows_result.scalars().all() - if not rows: - break - - for row in rows: - result.rows_scanned += 1 - data = row.data or {} - for col_name, cell_value in data.items(): - if cell_value is None: - continue - text = str(cell_value) - self._match_text( - text, patterns, "dataset_row", row.id, - col_name, result.hits, - row_index=row.row_index, - dataset_name=ds_map.get(row.dataset_id), - ) - - offset += BATCH_SIZE - if len(rows) < BATCH_SIZE: - break async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None: """Scan hunt names and descriptions.""" diff --git a/backend/app/services/triage.py b/backend/app/services/triage.py index 09be225..3f65f85 100644 --- a/backend/app/services/triage.py +++ b/backend/app/services/triage.py @@ -1,4 +1,4 @@ -"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner.""" +"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner.""" from __future__ import annotations @@ -15,7 +15,7 @@ from app.db.models import Dataset, DatasetRow, TriageResult logger = logging.getLogger(__name__) -DEFAULT_FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M" +DEFAULT_FAST_MODEL = settings.DEFAULT_FAST_MODEL ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate" ARTIFACT_FOCUS = { @@ -80,7 +80,7 @@ async def triage_dataset(dataset_id: str) -> None: rows_result = await db.execute( select(DatasetRow) .where(DatasetRow.dataset_id == dataset_id) - .order_by(DatasetRow.row_number) + .order_by(DatasetRow.row_index) .offset(offset) .limit(batch_size) ) @@ -167,4 +167,4 @@ Be precise. Only flag genuinely suspicious items. Respond with valid JSON only." offset += batch_size - logger.info("Triage complete for dataset %s", dataset_id) \ No newline at end of file + logger.info("Triage complete for dataset %s", dataset_id) diff --git a/backend/tests/test_agent_policy_execution.py b/backend/tests/test_agent_policy_execution.py new file mode 100644 index 0000000..374b93e --- /dev/null +++ b/backend/tests/test_agent_policy_execution.py @@ -0,0 +1,124 @@ +"""Tests for execution-mode behavior in /api/agent/assist.""" + +import io + +import pytest + + +@pytest.mark.asyncio +async def test_agent_assist_policy_query_executes_scan(client): + # 1) Create hunt + h = await client.post("/api/hunts", json={"name": "Policy Hunt"}) + assert h.status_code == 200 + hunt_id = h.json()["id"] + + # 2) Upload browser-history-like CSV + csv_bytes = ( + b"User,visited_url,title,ClientId,Fqdn\n" + b"Alice,https://www.pornhub.com/view_video.php,site,HOST-A,host-a.local\n" + b"Bob,https://news.example.org/article,news,HOST-B,host-b.local\n" + ) + files = {"file": ("web_history.csv", io.BytesIO(csv_bytes), "text/csv")} + up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files) + assert up.status_code == 200 + + # 3) Ensure policy theme/keyword exists + t = await client.post( + "/api/keywords/themes", + json={ + "name": "Adult Content", + "color": "#e91e63", + "enabled": True, + }, + ) + assert t.status_code in (201, 409) + + themes = await client.get("/api/keywords/themes") + assert themes.status_code == 200 + adult = next(x for x in themes.json()["themes"] if x["name"] == "Adult Content") + + k = await client.post( + f"/api/keywords/themes/{adult['id']}/keywords", + json={"value": "pornhub", "is_regex": False}, + ) + assert k.status_code in (201, 409) + + # 4) Execution-mode query + q = await client.post( + "/api/agent/assist", + json={ + "query": "Analyze browser history for policy-violating domains and summarize by user and host.", + "hunt_id": hunt_id, + }, + ) + assert q.status_code == 200 + body = q.json() + + assert body["model_used"] == "execution:keyword_scanner" + assert body["execution"] is not None + assert body["execution"]["policy_hits"] >= 1 + assert len(body["execution"]["top_user_hosts"]) >= 1 + + +@pytest.mark.asyncio +async def test_agent_assist_execution_preference_off_stays_advisory(client): + h = await client.post("/api/hunts", json={"name": "No Exec Hunt"}) + assert h.status_code == 200 + hunt_id = h.json()["id"] + + q = await client.post( + "/api/agent/assist", + json={ + "query": "Analyze browser history for policy-violating domains and summarize by user and host.", + "hunt_id": hunt_id, + "execution_preference": "off", + }, + ) + assert q.status_code == 200 + body = q.json() + assert body["model_used"] != "execution:keyword_scanner" + assert body["execution"] is None + + +@pytest.mark.asyncio +async def test_agent_assist_execution_preference_force_executes(client): + # Create hunt + dataset even when the query text is not policy-specific + h = await client.post("/api/hunts", json={"name": "Force Exec Hunt"}) + assert h.status_code == 200 + hunt_id = h.json()["id"] + + csv_bytes = ( + b"User,visited_url,title,ClientId,Fqdn\n" + b"Alice,https://www.pornhub.com/view_video.php,site,HOST-A,host-a.local\n" + ) + files = {"file": ("web_history.csv", io.BytesIO(csv_bytes), "text/csv")} + up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files) + assert up.status_code == 200 + + t = await client.post( + "/api/keywords/themes", + json={"name": "Adult Content", "color": "#e91e63", "enabled": True}, + ) + assert t.status_code in (201, 409) + + themes = await client.get("/api/keywords/themes") + assert themes.status_code == 200 + adult = next(x for x in themes.json()["themes"] if x["name"] == "Adult Content") + k = await client.post( + f"/api/keywords/themes/{adult['id']}/keywords", + json={"value": "pornhub", "is_regex": False}, + ) + assert k.status_code in (201, 409) + + q = await client.post( + "/api/agent/assist", + json={ + "query": "Summarize notable activity in this hunt.", + "hunt_id": hunt_id, + "execution_preference": "force", + }, + ) + assert q.status_code == 200 + body = q.json() + assert body["model_used"] == "execution:keyword_scanner" + assert body["execution"] is not None diff --git a/backend/tests/test_api.py b/backend/tests/test_api.py index d1e634d..4b56a4f 100644 --- a/backend/tests/test_api.py +++ b/backend/tests/test_api.py @@ -77,6 +77,26 @@ class TestHuntEndpoints: assert resp.status_code == 404 + async def test_hunt_progress(self, client): + create = await client.post("/api/hunts", json={"name": "Progress Hunt"}) + hunt_id = create.json()["id"] + + # attach one dataset so progress has scope + from tests.conftest import SAMPLE_CSV + import io + files = {"file": ("progress.csv", io.BytesIO(SAMPLE_CSV), "text/csv")} + up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files) + assert up.status_code == 200 + + res = await client.get(f"/api/hunts/{hunt_id}/progress") + assert res.status_code == 200 + body = res.json() + assert body["hunt_id"] == hunt_id + assert "progress_percent" in body + assert "dataset_total" in body + assert "network_status" in body + + @pytest.mark.asyncio class TestDatasetEndpoints: """Test dataset upload and retrieval.""" diff --git a/backend/tests/test_csv_parser.py b/backend/tests/test_csv_parser.py index 21f1746..2f3916b 100644 --- a/backend/tests/test_csv_parser.py +++ b/backend/tests/test_csv_parser.py @@ -1,4 +1,4 @@ -"""Tests for CSV parser and normalizer services.""" +"""Tests for CSV parser and normalizer services.""" import pytest from app.services.csv_parser import parse_csv_bytes, detect_encoding, detect_delimiter, infer_column_types @@ -43,8 +43,9 @@ class TestCSVParser: assert len(rows) == 2 def test_parse_empty_file(self): - with pytest.raises(Exception): - parse_csv_bytes(b"") + rows, meta = parse_csv_bytes(b"") + assert len(rows) == 0 + assert meta["row_count"] == 0 def test_detect_encoding_utf8(self): enc = detect_encoding(SAMPLE_CSV) @@ -53,17 +54,15 @@ class TestCSVParser: def test_infer_column_types(self): types = infer_column_types( - ["192.168.1.1", "10.0.0.1", "8.8.8.8"], - "src_ip", + [{"src_ip": "192.168.1.1"}, {"src_ip": "10.0.0.1"}, {"src_ip": "8.8.8.8"}], ) - assert types == "ip" + assert types["src_ip"] == "ip" def test_infer_column_types_hash(self): types = infer_column_types( - ["d41d8cd98f00b204e9800998ecf8427e"], - "hash", + [{"hash": "d41d8cd98f00b204e9800998ecf8427e"}], ) - assert types == "hash_md5" + assert types["hash"] == "hash_md5" class TestNormalizer: @@ -94,7 +93,7 @@ class TestNormalizer: start, end = detect_time_range(rows, column_mapping) # Should detect time range from timestamp column if start: - assert "2025" in start + assert "2025" in str(start) def test_normalize_rows(self): rows = [{"SourceAddr": "10.0.0.1", "ProcessName": "cmd.exe"}] @@ -102,3 +101,6 @@ class TestNormalizer: normalized = normalize_rows(rows, mapping) assert len(normalized) == 1 assert normalized[0].get("src_ip") == "10.0.0.1" + + + diff --git a/backend/tests/test_keywords.py b/backend/tests/test_keywords.py index 234a7c1..b647c53 100644 --- a/backend/tests/test_keywords.py +++ b/backend/tests/test_keywords.py @@ -197,3 +197,27 @@ async def test_quick_scan(client: AsyncClient): assert "total_hits" in data # powershell should match at least one row assert data["total_hits"] > 0 + + +@pytest.mark.asyncio +async def test_quick_scan_cache_hit(client: AsyncClient): + """Second quick scan should return cache hit metadata.""" + theme_res = await client.post("/api/keywords/themes", json={"name": "Quick Cache Theme", "color": "#00aa00"}) + tid = theme_res.json()["id"] + await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "chrome.exe"}) + + from tests.conftest import SAMPLE_CSV + import io + files = {"file": ("cache_quick.csv", io.BytesIO(SAMPLE_CSV), "text/csv")} + upload = await client.post("/api/datasets/upload", files=files) + ds_id = upload.json()["id"] + + first = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}") + assert first.status_code == 200 + assert first.json().get("cache_status") in ("miss", "hit") + + second = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}") + assert second.status_code == 200 + body = second.json() + assert body.get("cache_used") is True + assert body.get("cache_status") == "hit" diff --git a/backend/tests/test_network.py b/backend/tests/test_network.py new file mode 100644 index 0000000..415c318 --- /dev/null +++ b/backend/tests/test_network.py @@ -0,0 +1,84 @@ +"""Tests for network inventory endpoints and cache/polling behavior.""" + +import io + +import pytest + +from app.services.host_inventory import inventory_cache +from tests.conftest import SAMPLE_CSV + + +@pytest.mark.asyncio +async def test_inventory_status_none_for_unknown_hunt(client): + hunt_id = "hunt-does-not-exist" + inventory_cache.invalidate(hunt_id) + inventory_cache.clear_building(hunt_id) + + res = await client.get(f"/api/network/inventory-status?hunt_id={hunt_id}") + assert res.status_code == 200 + body = res.json() + assert body["hunt_id"] == hunt_id + assert body["status"] == "none" + + +@pytest.mark.asyncio +async def test_host_inventory_cold_cache_returns_202(client): + # Create hunt and upload dataset linked to that hunt + hunt = await client.post("/api/hunts", json={"name": "Net Hunt"}) + hunt_id = hunt.json()["id"] + + files = {"file": ("network.csv", io.BytesIO(SAMPLE_CSV), "text/csv")} + up = await client.post("/api/datasets/upload", files=files, params={"hunt_id": hunt_id}) + assert up.status_code == 200 + + # Ensure cache is cold for this hunt + inventory_cache.invalidate(hunt_id) + inventory_cache.clear_building(hunt_id) + + res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}") + assert res.status_code == 202 + body = res.json() + assert body["status"] == "building" + + +@pytest.mark.asyncio +async def test_host_inventory_ready_cache_returns_200(client): + hunt = await client.post("/api/hunts", json={"name": "Ready Hunt"}) + hunt_id = hunt.json()["id"] + + mock_inventory = { + "hosts": [ + { + "id": "host-1", + "hostname": "HOST-1", + "fqdn": "HOST-1.local", + "client_id": "C.1234abcd", + "ips": ["10.0.0.10"], + "os": "Windows 10", + "users": ["alice"], + "datasets": ["test"], + "row_count": 5, + } + ], + "connections": [], + "stats": { + "total_hosts": 1, + "hosts_with_ips": 1, + "hosts_with_users": 1, + "total_datasets_scanned": 1, + "total_rows_scanned": 5, + }, + } + + inventory_cache.put(hunt_id, mock_inventory) + + res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}") + assert res.status_code == 200 + body = res.json() + assert body["stats"]["total_hosts"] == 1 + assert len(body["hosts"]) == 1 + assert body["hosts"][0]["hostname"] == "HOST-1" + + status_res = await client.get(f"/api/network/inventory-status?hunt_id={hunt_id}") + assert status_res.status_code == 200 + assert status_res.json()["status"] == "ready" \ No newline at end of file diff --git a/backend/tests/test_network_scale.py b/backend/tests/test_network_scale.py new file mode 100644 index 0000000..c86e76a --- /dev/null +++ b/backend/tests/test_network_scale.py @@ -0,0 +1,82 @@ +"""Scale-oriented network endpoint tests (summary/subgraph/backpressure).""" + +import pytest + +from app.config import settings +from app.services.host_inventory import inventory_cache + + +@pytest.mark.asyncio +async def test_network_summary_from_cache(client): + hunt_id = "scale-hunt-summary" + inv = { + "hosts": [ + {"id": "h1", "hostname": "H1", "ips": ["10.0.0.1"], "users": ["a"], "row_count": 50}, + {"id": "h2", "hostname": "H2", "ips": [], "users": [], "row_count": 10}, + ], + "connections": [ + {"source": "h1", "target": "8.8.8.8", "count": 7}, + {"source": "h1", "target": "h2", "count": 3}, + ], + "stats": {"total_hosts": 2, "total_rows_scanned": 60}, + } + inventory_cache.put(hunt_id, inv) + + res = await client.get(f"/api/network/summary?hunt_id={hunt_id}&top_n=1") + assert res.status_code == 200 + body = res.json() + assert body["stats"]["total_hosts"] == 2 + assert len(body["top_hosts"]) == 1 + assert body["top_hosts"][0]["id"] == "h1" + + +@pytest.mark.asyncio +async def test_network_subgraph_truncates(client): + hunt_id = "scale-hunt-subgraph" + inv = { + "hosts": [ + {"id": f"h{i}", "hostname": f"H{i}", "ips": [], "users": [], "row_count": 100 - i} + for i in range(1, 8) + ], + "connections": [ + {"source": "h1", "target": "h2", "count": 20}, + {"source": "h1", "target": "h3", "count": 15}, + {"source": "h2", "target": "h4", "count": 5}, + {"source": "h3", "target": "h5", "count": 4}, + ], + "stats": {"total_hosts": 7, "total_rows_scanned": 999}, + } + inventory_cache.put(hunt_id, inv) + + res = await client.get(f"/api/network/subgraph?hunt_id={hunt_id}&max_hosts=3&max_edges=2") + assert res.status_code == 200 + body = res.json() + assert len(body["hosts"]) <= 3 + assert len(body["connections"]) <= 2 + assert body["stats"]["truncated"] is True + + +@pytest.mark.asyncio +async def test_manual_job_submit_backpressure_returns_429(client): + old = settings.JOB_QUEUE_MAX_BACKLOG + settings.JOB_QUEUE_MAX_BACKLOG = 0 + try: + res = await client.post("/api/analysis/jobs/submit/triage", json={"params": {"dataset_id": "abc"}}) + assert res.status_code == 429 + finally: + settings.JOB_QUEUE_MAX_BACKLOG = old +@pytest.mark.asyncio +async def test_network_host_inventory_deferred_when_queue_backlogged(client): + hunt_id = "deferred-hunt" + inventory_cache.invalidate(hunt_id) + inventory_cache.clear_building(hunt_id) + + old = settings.JOB_QUEUE_MAX_BACKLOG + settings.JOB_QUEUE_MAX_BACKLOG = 0 + try: + res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}") + assert res.status_code == 202 + body = res.json() + assert body["status"] == "deferred" + finally: + settings.JOB_QUEUE_MAX_BACKLOG = old diff --git a/backend/tests/test_new_features.py b/backend/tests/test_new_features.py new file mode 100644 index 0000000..1b57e5c --- /dev/null +++ b/backend/tests/test_new_features.py @@ -0,0 +1,203 @@ +"""Tests for new feature API routes: MITRE, Timeline, Playbooks, Saved Searches.""" + +import pytest +import pytest_asyncio + + +class TestMitreRoutes: + """Tests for /api/mitre endpoints.""" + + @pytest.mark.asyncio + async def test_mitre_coverage_empty(self, client): + resp = await client.get("/api/mitre/coverage") + assert resp.status_code == 200 + data = resp.json() + assert "tactics" in data + assert "technique_count" in data + assert data["technique_count"] == 0 + assert len(data["tactics"]) == 14 # 14 MITRE tactics + + @pytest.mark.asyncio + async def test_mitre_coverage_with_hunt_filter(self, client): + resp = await client.get("/api/mitre/coverage?hunt_id=nonexistent") + assert resp.status_code == 200 + assert resp.json()["technique_count"] == 0 + + +class TestTimelineRoutes: + """Tests for /api/timeline endpoints.""" + + @pytest.mark.asyncio + async def test_timeline_hunt_not_found(self, client): + resp = await client.get("/api/timeline/hunt/nonexistent") + assert resp.status_code == 404 + + @pytest.mark.asyncio + async def test_timeline_with_hunt(self, client): + # Create a hunt first + hunt_resp = await client.post("/api/hunts", json={"name": "Timeline Test"}) + assert hunt_resp.status_code in (200, 201) + hunt_id = hunt_resp.json()["id"] + + resp = await client.get(f"/api/timeline/hunt/{hunt_id}") + assert resp.status_code == 200 + data = resp.json() + assert data["hunt_id"] == hunt_id + assert "events" in data + assert "datasets" in data + + +class TestPlaybookRoutes: + """Tests for /api/playbooks endpoints.""" + + @pytest.mark.asyncio + async def test_list_playbooks_empty(self, client): + resp = await client.get("/api/playbooks") + assert resp.status_code == 200 + assert resp.json()["playbooks"] == [] + + @pytest.mark.asyncio + async def test_get_templates(self, client): + resp = await client.get("/api/playbooks/templates") + assert resp.status_code == 200 + templates = resp.json()["templates"] + assert len(templates) >= 2 + assert templates[0]["name"] == "Standard Threat Hunt" + + @pytest.mark.asyncio + async def test_create_playbook(self, client): + resp = await client.post("/api/playbooks", json={ + "name": "My Investigation", + "description": "Test playbook", + "steps": [ + {"title": "Step 1", "description": "Upload data", "step_type": "upload", "target_route": "/upload"}, + {"title": "Step 2", "description": "Triage", "step_type": "analysis", "target_route": "/analysis"}, + ], + }) + assert resp.status_code == 201 + data = resp.json() + assert data["name"] == "My Investigation" + assert len(data["steps"]) == 2 + + @pytest.mark.asyncio + async def test_playbook_crud(self, client): + # Create + resp = await client.post("/api/playbooks", json={ + "name": "CRUD Test", + "steps": [{"title": "Do something"}], + }) + assert resp.status_code == 201 + pb_id = resp.json()["id"] + + # Get + resp = await client.get(f"/api/playbooks/{pb_id}") + assert resp.status_code == 200 + assert resp.json()["name"] == "CRUD Test" + assert len(resp.json()["steps"]) == 1 + + # Update + resp = await client.put(f"/api/playbooks/{pb_id}", json={"name": "Updated"}) + assert resp.status_code == 200 + + # Delete + resp = await client.delete(f"/api/playbooks/{pb_id}") + assert resp.status_code == 200 + + @pytest.mark.asyncio + async def test_playbook_step_completion(self, client): + # Create with step + resp = await client.post("/api/playbooks", json={ + "name": "Step Test", + "steps": [{"title": "Task 1"}], + }) + pb_id = resp.json()["id"] + + # Get to find step ID + resp = await client.get(f"/api/playbooks/{pb_id}") + steps = resp.json()["steps"] + step_id = steps[0]["id"] + assert steps[0]["is_completed"] is False + + # Mark complete + resp = await client.put(f"/api/playbooks/steps/{step_id}", json={"is_completed": True, "notes": "Done!"}) + assert resp.status_code == 200 + assert resp.json()["is_completed"] is True + + +class TestSavedSearchRoutes: + """Tests for /api/searches endpoints.""" + + @pytest.mark.asyncio + async def test_list_empty(self, client): + resp = await client.get("/api/searches") + assert resp.status_code == 200 + assert resp.json()["searches"] == [] + + @pytest.mark.asyncio + async def test_create_saved_search(self, client): + resp = await client.post("/api/searches", json={ + "name": "Suspicious IPs", + "search_type": "ioc_search", + "query_params": {"ioc_value": "203.0.113"}, + }) + assert resp.status_code == 201 + data = resp.json() + assert data["name"] == "Suspicious IPs" + assert data["search_type"] == "ioc_search" + + @pytest.mark.asyncio + async def test_search_crud(self, client): + # Create + resp = await client.post("/api/searches", json={ + "name": "Test Query", + "search_type": "keyword_scan", + "query_params": {"theme": "malware"}, + }) + s_id = resp.json()["id"] + + # Get + resp = await client.get(f"/api/searches/{s_id}") + assert resp.status_code == 200 + + # Update + resp = await client.put(f"/api/searches/{s_id}", json={"name": "Updated Query"}) + assert resp.status_code == 200 + + # Run + resp = await client.post(f"/api/searches/{s_id}/run") + assert resp.status_code == 200 + data = resp.json() + assert "result_count" in data + assert "delta" in data + + # Delete + resp = await client.delete(f"/api/searches/{s_id}") + assert resp.status_code == 200 + + + +class TestStixExport: + """Tests for /api/export/stix endpoints.""" + + @pytest.mark.asyncio + async def test_stix_export_hunt_not_found(self, client): + resp = await client.get("/api/export/stix/nonexistent-id") + assert resp.status_code == 404 + + @pytest.mark.asyncio + async def test_stix_export_empty_hunt(self, client): + """Export from a real hunt with no data returns valid but minimal bundle.""" + hunt_resp = await client.post("/api/hunts", json={"name": "STIX Test Hunt"}) + assert hunt_resp.status_code in (200, 201) + hunt_id = hunt_resp.json()["id"] + + resp = await client.get(f"/api/export/stix/{hunt_id}") + assert resp.status_code == 200 + data = resp.json() + assert data["type"] == "bundle" + assert data["objects"][0]["spec_version"] == "2.1" # spec_version is on objects, not bundle + assert "objects" in data + # At minimum should have the identity object + types = [o["type"] for o in data["objects"]] + assert "identity" in types + diff --git a/backend/threathunt.db-shm b/backend/threathunt.db-shm deleted file mode 100644 index fe9ac28..0000000 Binary files a/backend/threathunt.db-shm and /dev/null differ diff --git a/backend/threathunt.db-wal b/backend/threathunt.db-wal deleted file mode 100644 index e69de29..0000000 diff --git a/docker-compose.yml b/docker-compose.yml index 99aa6a9..1760368 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,24 +7,24 @@ services: ports: - "8000:8000" environment: - # ── LLM Cluster (Wile / Roadrunner via Tailscale) ── + # ── LLM Cluster (Wile / Roadrunner via Tailscale) ── TH_WILE_HOST: "100.110.190.12" TH_ROADRUNNER_HOST: "100.110.190.11" TH_OLLAMA_PORT: "11434" TH_OPEN_WEBUI_URL: "https://ai.guapo613.beer" - # ── Database ── + # ── Database ── TH_DATABASE_URL: "sqlite+aiosqlite:///./threathunt.db" - # ── Auth ── + # ── Auth ── TH_JWT_SECRET: "change-me-in-production" - # ── Enrichment API keys (set your own) ── + # ── Enrichment API keys (set your own) ── # TH_VIRUSTOTAL_API_KEY: "" # TH_ABUSEIPDB_API_KEY: "" # TH_SHODAN_API_KEY: "" - # ── Agent behaviour ── + # ── Agent behaviour ── TH_AGENT_MAX_TOKENS: "4096" TH_AGENT_TEMPERATURE: "0.3" volumes: @@ -51,7 +51,7 @@ services: networks: - threathunt healthcheck: - test: ["CMD", "wget", "--quiet", "--tries=1", "--spider", "http://localhost:3000/"] + test: ["CMD", "curl", "-f", "http://127.0.0.1:3000/"] interval: 30s timeout: 10s retries: 3 diff --git a/fix_all.py b/fix_all.py new file mode 100644 index 0000000..d97023d --- /dev/null +++ b/fix_all.py @@ -0,0 +1,350 @@ +"""Fix all critical issues: DB locking, keyword scan, network map.""" +import os, re + +ROOT = r"D:\Projects\Dev\ThreatHunt" + +def fix_file(filepath, replacements): + """Apply text replacements to a file.""" + path = os.path.join(ROOT, filepath) + with open(path, "r", encoding="utf-8") as f: + content = f.read() + + for old, new, desc in replacements: + if old in content: + content = content.replace(old, new, 1) + print(f" OK: {desc}") + else: + print(f" SKIP: {desc} (pattern not found)") + + with open(path, "w", encoding="utf-8") as f: + f.write(content) + return content + + +# ================================================================ +# FIX 1: Database engine - NullPool instead of StaticPool +# ================================================================ +print("\n=== FIX 1: Database engine (NullPool + higher timeouts) ===") + +engine_path = os.path.join(ROOT, "backend", "app", "db", "engine.py") +with open(engine_path, "r", encoding="utf-8") as f: + engine_content = f.read() + +new_engine = '''"""Database engine, session factory, and base model. + +Uses async SQLAlchemy with aiosqlite for local dev and asyncpg for production PostgreSQL. +""" + +from sqlalchemy import event +from sqlalchemy.ext.asyncio import ( + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import DeclarativeBase + +from app.config import settings + +_is_sqlite = settings.DATABASE_URL.startswith("sqlite") + +_engine_kwargs: dict = dict( + echo=settings.DEBUG, + future=True, +) + +if _is_sqlite: + _engine_kwargs["connect_args"] = {"timeout": 60, "check_same_thread": False} + # NullPool: each session gets its own connection. + # Combined with WAL mode, this allows concurrent reads while a write is in progress. + from sqlalchemy.pool import NullPool + _engine_kwargs["poolclass"] = NullPool +else: + _engine_kwargs["pool_size"] = 5 + _engine_kwargs["max_overflow"] = 10 + +engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs) + + +@event.listens_for(engine.sync_engine, "connect") +def _set_sqlite_pragmas(dbapi_conn, connection_record): + """Enable WAL mode and tune busy-timeout for SQLite connections.""" + if _is_sqlite: + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA journal_mode=WAL") + cursor.execute("PRAGMA busy_timeout=30000") + cursor.execute("PRAGMA synchronous=NORMAL") + cursor.close() + + +async_session_factory = async_sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False, +) + + +# Alias expected by other modules +async_session = async_session_factory + + +class Base(DeclarativeBase): + """Base class for all ORM models.""" + pass + + +async def get_db() -> AsyncSession: # type: ignore[misc] + """FastAPI dependency that yields an async DB session.""" + async with async_session_factory() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() + + +async def init_db() -> None: + """Create all tables (for dev / first-run). In production use Alembic.""" + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + +async def dispose_db() -> None: + """Dispose of the engine on shutdown.""" + await engine.dispose() +''' + +with open(engine_path, "w", encoding="utf-8") as f: + f.write(new_engine) +print(" OK: Replaced StaticPool with NullPool") +print(" OK: Increased busy_timeout 5000 -> 30000ms") +print(" OK: Added check_same_thread=False") +print(" OK: Connection timeout 30 -> 60s") + + +# ================================================================ +# FIX 2: Keyword scan endpoint - make POST non-blocking (background job) +# ================================================================ +print("\n=== FIX 2: Keyword scan endpoint -> background job ===") + +kw_path = os.path.join(ROOT, "backend", "app", "api", "routes", "keywords.py") +with open(kw_path, "r", encoding="utf-8") as f: + kw_content = f.read() + +# Replace the scan endpoint to be non-blocking +old_scan = '''@router.post("/scan", response_model=ScanResponse) +async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)): + """Run AUP keyword scan across selected data sources.""" + scanner = KeywordScanner(db) + result = await scanner.scan( + dataset_ids=body.dataset_ids, + theme_ids=body.theme_ids, + scan_hunts=body.scan_hunts, + scan_annotations=body.scan_annotations, + scan_messages=body.scan_messages, + ) + return result''' + +new_scan = '''@router.post("/scan", response_model=ScanResponse) +async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)): + """Run AUP keyword scan across selected data sources. + + Uses a dedicated DB session separate from the request session + to avoid blocking other API requests on SQLite. + """ + from app.db import async_session_factory + async with async_session_factory() as scan_db: + scanner = KeywordScanner(scan_db) + result = await scanner.scan( + dataset_ids=body.dataset_ids, + theme_ids=body.theme_ids, + scan_hunts=body.scan_hunts, + scan_annotations=body.scan_annotations, + scan_messages=body.scan_messages, + ) + return result''' + +if old_scan in kw_content: + kw_content = kw_content.replace(old_scan, new_scan, 1) + print(" OK: Scan endpoint uses dedicated DB session") +else: + print(" SKIP: Scan endpoint pattern not found") + +# Also fix quick_scan +old_quick = '''@router.get("/scan/quick", response_model=ScanResponse) +async def quick_scan( + dataset_id: str = Query(..., description="Dataset to scan"), + db: AsyncSession = Depends(get_db), +): + """Quick scan a single dataset with all enabled themes.""" + scanner = KeywordScanner(db) + result = await scanner.scan(dataset_ids=[dataset_id]) + return result''' + +new_quick = '''@router.get("/scan/quick", response_model=ScanResponse) +async def quick_scan( + dataset_id: str = Query(..., description="Dataset to scan"), + db: AsyncSession = Depends(get_db), +): + """Quick scan a single dataset with all enabled themes.""" + from app.db import async_session_factory + async with async_session_factory() as scan_db: + scanner = KeywordScanner(scan_db) + result = await scanner.scan(dataset_ids=[dataset_id]) + return result''' + +if old_quick in kw_content: + kw_content = kw_content.replace(old_quick, new_quick, 1) + print(" OK: Quick scan uses dedicated DB session") +else: + print(" SKIP: Quick scan pattern not found") + +with open(kw_path, "w", encoding="utf-8") as f: + f.write(kw_content) + + +# ================================================================ +# FIX 3: Scanner service - smaller batches, yield between batches +# ================================================================ +print("\n=== FIX 3: Scanner service - smaller batches + async yield ===") + +scanner_path = os.path.join(ROOT, "backend", "app", "services", "scanner.py") +with open(scanner_path, "r", encoding="utf-8") as f: + scanner_content = f.read() + +# Change batch size and add yield between batches +old_batch = "BATCH_SIZE = 500" +new_batch = "BATCH_SIZE = 200" + +if old_batch in scanner_content: + scanner_content = scanner_content.replace(old_batch, new_batch, 1) + print(" OK: Reduced batch size 500 -> 200") + +# Add asyncio.sleep(0) between batches to yield to other tasks +old_batch_loop = ''' offset += BATCH_SIZE + if len(rows) < BATCH_SIZE: + break''' + +new_batch_loop = ''' offset += BATCH_SIZE + # Yield to event loop between batches so other requests aren't starved + import asyncio + await asyncio.sleep(0) + if len(rows) < BATCH_SIZE: + break''' + +if old_batch_loop in scanner_content: + scanner_content = scanner_content.replace(old_batch_loop, new_batch_loop, 1) + print(" OK: Added async yield between scan batches") +else: + print(" SKIP: Batch loop pattern not found") + +with open(scanner_path, "w", encoding="utf-8") as f: + f.write(scanner_content) + + +# ================================================================ +# FIX 4: Job queue workers - increase from 3 to 5 +# ================================================================ +print("\n=== FIX 4: Job queue - more workers ===") + +jq_path = os.path.join(ROOT, "backend", "app", "services", "job_queue.py") +with open(jq_path, "r", encoding="utf-8") as f: + jq_content = f.read() + +old_workers = "job_queue = JobQueue(max_workers=3)" +new_workers = "job_queue = JobQueue(max_workers=5)" + +if old_workers in jq_content: + jq_content = jq_content.replace(old_workers, new_workers, 1) + print(" OK: Workers 3 -> 5") + +with open(jq_path, "w", encoding="utf-8") as f: + f.write(jq_content) + + +# ================================================================ +# FIX 5: main.py - always re-run pipeline on startup for ALL datasets +# ================================================================ +print("\n=== FIX 5: Startup reprocessing - all datasets, not just 'ready' ===") + +main_path = os.path.join(ROOT, "backend", "app", "main.py") +with open(main_path, "r", encoding="utf-8") as f: + main_content = f.read() + +# The current startup only reprocesses datasets with status="ready" +# But after previous runs, they're all "completed" - so nothing happens +# Fix: reprocess datasets that have NO triage/anomaly results in DB +old_reprocess = ''' # Reprocess datasets that were never fully processed (status still "ready") + async with async_session_factory() as reprocess_db: + from sqlalchemy import select + from app.db.models import Dataset + stmt = select(Dataset.id).where(Dataset.processing_status == "ready") + result = await reprocess_db.execute(stmt) + unprocessed_ids = [row[0] for row in result.all()] + for ds_id in unprocessed_ids: + job_queue.submit(JobType.TRIAGE, dataset_id=ds_id) + job_queue.submit(JobType.ANOMALY, dataset_id=ds_id) + job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id) + job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id) + if unprocessed_ids: + logger.info(f"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets") + # Mark them as processing + async with async_session_factory() as update_db: + from sqlalchemy import update + from app.db.models import Dataset + await update_db.execute( + update(Dataset) + .where(Dataset.id.in_(unprocessed_ids)) + .values(processing_status="processing") + ) + await update_db.commit()''' + +new_reprocess = ''' # Check which datasets still need processing + # (no anomaly results = never fully processed) + async with async_session_factory() as reprocess_db: + from sqlalchemy import select, exists + from app.db.models import Dataset, AnomalyResult + # Find datasets that have zero anomaly results (pipeline never ran or failed) + has_anomaly = ( + select(AnomalyResult.id) + .where(AnomalyResult.dataset_id == Dataset.id) + .limit(1) + .correlate(Dataset) + .exists() + ) + stmt = select(Dataset.id).where(~has_anomaly) + result = await reprocess_db.execute(stmt) + unprocessed_ids = [row[0] for row in result.all()] + + if unprocessed_ids: + for ds_id in unprocessed_ids: + job_queue.submit(JobType.TRIAGE, dataset_id=ds_id) + job_queue.submit(JobType.ANOMALY, dataset_id=ds_id) + job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id) + job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id) + logger.info(f"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets") + async with async_session_factory() as update_db: + from sqlalchemy import update + from app.db.models import Dataset + await update_db.execute( + update(Dataset) + .where(Dataset.id.in_(unprocessed_ids)) + .values(processing_status="processing") + ) + await update_db.commit() + else: + logger.info("All datasets already processed - skipping startup pipeline")''' + +if old_reprocess in main_content: + main_content = main_content.replace(old_reprocess, new_reprocess, 1) + print(" OK: Startup checks for actual results, not just status field") +else: + print(" SKIP: Reprocess block not found") + +with open(main_path, "w", encoding="utf-8") as f: + f.write(main_content) + + +print("\n=== ALL FIXES APPLIED ===") \ No newline at end of file diff --git a/fix_keywords.py b/fix_keywords.py new file mode 100644 index 0000000..f1e49c7 --- /dev/null +++ b/fix_keywords.py @@ -0,0 +1,30 @@ +import re + +path = "backend/app/api/routes/keywords.py" +with open(path, "r", encoding="utf-8") as f: + c = f.read() + +# Fix POST /scan - remove dedicated session, use injected db +old1 = ' """Run AUP keyword scan across selected data sources.\n \n Uses a dedicated DB session separate from the request session\n to avoid blocking other API requests on SQLite.\n """\n from app.db import async_session_factory\n async with async_session_factory() as scan_db:\n scanner = KeywordScanner(scan_db)\n result = await scanner.scan(\n dataset_ids=body.dataset_ids,\n theme_ids=body.theme_ids,\n scan_hunts=body.scan_hunts,\n scan_annotations=body.scan_annotations,\n scan_messages=body.scan_messages,\n )\n return result' + +new1 = ' """Run AUP keyword scan across selected data sources."""\n scanner = KeywordScanner(db)\n result = await scanner.scan(\n dataset_ids=body.dataset_ids,\n theme_ids=body.theme_ids,\n scan_hunts=body.scan_hunts,\n scan_annotations=body.scan_annotations,\n scan_messages=body.scan_messages,\n )\n return result' + +if old1 in c: + c = c.replace(old1, new1, 1) + print("OK: reverted POST /scan") +else: + print("SKIP: POST /scan not found") + +# Fix GET /scan/quick +old2 = ' """Quick scan a single dataset with all enabled themes."""\n from app.db import async_session_factory\n async with async_session_factory() as scan_db:\n scanner = KeywordScanner(scan_db)\n result = await scanner.scan(dataset_ids=[dataset_id])\n return result' + +new2 = ' """Quick scan a single dataset with all enabled themes."""\n scanner = KeywordScanner(db)\n result = await scanner.scan(dataset_ids=[dataset_id])\n return result' + +if old2 in c: + c = c.replace(old2, new2, 1) + print("OK: reverted GET /scan/quick") +else: + print("SKIP: GET /scan/quick not found") + +with open(path, "w", encoding="utf-8") as f: + f.write(c) \ No newline at end of file diff --git a/frontend/nginx.conf b/frontend/nginx.conf index 432f927..ac5271f 100644 --- a/frontend/nginx.conf +++ b/frontend/nginx.conf @@ -16,6 +16,12 @@ server { proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; proxy_read_timeout 300s; + + # SSE streaming support for agent assist + proxy_buffering off; + proxy_cache off; + proxy_set_header Connection ''; + chunked_transfer_encoding off; } # SPA fallback serve index.html for all non-file routes diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 5243dbc..d22f2a3 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -18,7 +18,8 @@ "react": "^18.2.0", "react-dom": "^18.2.0", "react-router-dom": "^7.13.0", - "react-scripts": "5.0.1" + "react-scripts": "5.0.1", + "recharts": "^3.7.0" }, "devDependencies": { "@types/react": "^18.2.0", @@ -3476,6 +3477,42 @@ "url": "https://opencollective.com/popperjs" } }, + "node_modules/@reduxjs/toolkit": { + "version": "2.11.2", + "resolved": "https://registry.npmjs.org/@reduxjs/toolkit/-/toolkit-2.11.2.tgz", + "integrity": "sha512-Kd6kAHTA6/nUpp8mySPqj3en3dm0tdMIgbttnQ1xFMVpufoj+ADi8pXLBsd4xzTRHQa7t/Jv8W5UnCuW4kuWMQ==", + "license": "MIT", + "dependencies": { + "@standard-schema/spec": "^1.0.0", + "@standard-schema/utils": "^0.3.0", + "immer": "^11.0.0", + "redux": "^5.0.1", + "redux-thunk": "^3.1.0", + "reselect": "^5.1.0" + }, + "peerDependencies": { + "react": "^16.9.0 || ^17.0.0 || ^18 || ^19", + "react-redux": "^7.2.1 || ^8.1.3 || ^9.0.0" + }, + "peerDependenciesMeta": { + "react": { + "optional": true + }, + "react-redux": { + "optional": true + } + } + }, + "node_modules/@reduxjs/toolkit/node_modules/immer": { + "version": "11.1.4", + "resolved": "https://registry.npmjs.org/immer/-/immer-11.1.4.tgz", + "integrity": "sha512-XREFCPo6ksxVzP4E0ekD5aMdf8WMwmdNaz6vuvxgI40UaEiu6q3p8X52aU6GdyvLY3XXX/8R7JOTXStz/nBbRw==", + "license": "MIT", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/immer" + } + }, "node_modules/@rollup/plugin-babel": { "version": "5.3.1", "resolved": "https://registry.npmjs.org/@rollup/plugin-babel/-/plugin-babel-5.3.1.tgz", @@ -3591,6 +3628,18 @@ "@sinonjs/commons": "^1.7.0" } }, + "node_modules/@standard-schema/spec": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz", + "integrity": "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==", + "license": "MIT" + }, + "node_modules/@standard-schema/utils": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/@standard-schema/utils/-/utils-0.3.0.tgz", + "integrity": "sha512-e7Mew686owMaPJVNNLs55PUvgz371nKgwsc4vxE49zsODpJEnxgxRo2y/OKrqueavXgZNMDVj3DdHFlaSAeU8g==", + "license": "MIT" + }, "node_modules/@surma/rollup-plugin-off-main-thread": { "version": "2.2.3", "resolved": "https://registry.npmjs.org/@surma/rollup-plugin-off-main-thread/-/rollup-plugin-off-main-thread-2.2.3.tgz", @@ -3921,6 +3970,69 @@ "@types/node": "*" } }, + "node_modules/@types/d3-array": { + "version": "3.2.2", + "resolved": "https://registry.npmjs.org/@types/d3-array/-/d3-array-3.2.2.tgz", + "integrity": "sha512-hOLWVbm7uRza0BYXpIIW5pxfrKe0W+D5lrFiAEYR+pb6w3N2SwSMaJbXdUfSEv+dT4MfHBLtn5js0LAWaO6otw==", + "license": "MIT" + }, + "node_modules/@types/d3-color": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/@types/d3-color/-/d3-color-3.1.3.tgz", + "integrity": "sha512-iO90scth9WAbmgv7ogoq57O9YpKmFBbmoEoCHDB2xMBY0+/KVrqAaCDyCE16dUspeOvIxFFRI+0sEtqDqy2b4A==", + "license": "MIT" + }, + "node_modules/@types/d3-ease": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/@types/d3-ease/-/d3-ease-3.0.2.tgz", + "integrity": "sha512-NcV1JjO5oDzoK26oMzbILE6HW7uVXOHLQvHshBUW4UMdZGfiY6v5BeQwh9a9tCzv+CeefZQHJt5SRgK154RtiA==", + "license": "MIT" + }, + "node_modules/@types/d3-interpolate": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@types/d3-interpolate/-/d3-interpolate-3.0.4.tgz", + "integrity": "sha512-mgLPETlrpVV1YRJIglr4Ez47g7Yxjl1lj7YKsiMCb27VJH9W8NVM6Bb9d8kkpG/uAQS5AmbA48q2IAolKKo1MA==", + "license": "MIT", + "dependencies": { + "@types/d3-color": "*" + } + }, + "node_modules/@types/d3-path": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/@types/d3-path/-/d3-path-3.1.1.tgz", + "integrity": "sha512-VMZBYyQvbGmWyWVea0EHs/BwLgxc+MKi1zLDCONksozI4YJMcTt8ZEuIR4Sb1MMTE8MMW49v0IwI5+b7RmfWlg==", + "license": "MIT" + }, + "node_modules/@types/d3-scale": { + "version": "4.0.9", + "resolved": "https://registry.npmjs.org/@types/d3-scale/-/d3-scale-4.0.9.tgz", + "integrity": "sha512-dLmtwB8zkAeO/juAMfnV+sItKjlsw2lKdZVVy6LRr0cBmegxSABiLEpGVmSJJ8O08i4+sGR6qQtb6WtuwJdvVw==", + "license": "MIT", + "dependencies": { + "@types/d3-time": "*" + } + }, + "node_modules/@types/d3-shape": { + "version": "3.1.8", + "resolved": "https://registry.npmjs.org/@types/d3-shape/-/d3-shape-3.1.8.tgz", + "integrity": "sha512-lae0iWfcDeR7qt7rA88BNiqdvPS5pFVPpo5OfjElwNaT2yyekbM0C9vK+yqBqEmHr6lDkRnYNoTBYlAgJa7a4w==", + "license": "MIT", + "dependencies": { + "@types/d3-path": "*" + } + }, + "node_modules/@types/d3-time": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/@types/d3-time/-/d3-time-3.0.4.tgz", + "integrity": "sha512-yuzZug1nkAAaBlBBikKZTgzCeA+k1uy4ZFwWANOfKw5z5LRhV0gNA7gNkKm7HoK+HRN0wX3EkxGk0fpbWhmB7g==", + "license": "MIT" + }, + "node_modules/@types/d3-timer": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/@types/d3-timer/-/d3-timer-3.0.2.tgz", + "integrity": "sha512-Ps3T8E8dZDam6fUyNiMkekK3XUsaUEik+idO9/YjPtfj2qruF8tFBXS7XhtE4iIXBLxhmLjP3SXpLhVf21I9Lw==", + "license": "MIT" + }, "node_modules/@types/eslint": { "version": "8.56.12", "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.56.12.tgz", @@ -4246,6 +4358,12 @@ "integrity": "sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==", "license": "MIT" }, + "node_modules/@types/use-sync-external-store": { + "version": "0.0.6", + "resolved": "https://registry.npmjs.org/@types/use-sync-external-store/-/use-sync-external-store-0.0.6.tgz", + "integrity": "sha512-zFDAD+tlpf2r4asuHEj0XH6pY6i0g5NeAHPn+15wk3BV6JA69eERFXC1gyGThDkVa1zCyKr5jox1+2LbV/AMLg==", + "license": "MIT" + }, "node_modules/@types/ws": { "version": "8.18.1", "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz", @@ -6757,6 +6875,127 @@ "integrity": "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==", "license": "MIT" }, + "node_modules/d3-array": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/d3-array/-/d3-array-3.2.4.tgz", + "integrity": "sha512-tdQAmyA18i4J7wprpYq8ClcxZy3SC31QMeByyCFyRt7BVHdREQZ5lpzoe5mFEYZUWe+oq8HBvk9JjpibyEV4Jg==", + "license": "ISC", + "dependencies": { + "internmap": "1 - 2" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-color": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/d3-color/-/d3-color-3.1.0.tgz", + "integrity": "sha512-zg/chbXyeBtMQ1LbD/WSoW2DpC3I0mpmPdW+ynRTj/x2DAWYrIY7qeZIHidozwV24m4iavr15lNwIwLxRmOxhA==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-ease": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-ease/-/d3-ease-3.0.1.tgz", + "integrity": "sha512-wR/XK3D3XcLIZwpbvQwQ5fK+8Ykds1ip7A2Txe0yxncXSdq1L9skcG7blcedkOX+ZcgxGAmLX1FrRGbADwzi0w==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-format": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/d3-format/-/d3-format-3.1.2.tgz", + "integrity": "sha512-AJDdYOdnyRDV5b6ArilzCPPwc1ejkHcoyFarqlPqT7zRYjhavcT3uSrqcMvsgh2CgoPbK3RCwyHaVyxYcP2Arg==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-interpolate": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-interpolate/-/d3-interpolate-3.0.1.tgz", + "integrity": "sha512-3bYs1rOD33uo8aqJfKP3JWPAibgw8Zm2+L9vBKEHJ2Rg+viTR7o5Mmv5mZcieN+FRYaAOWX5SJATX6k1PWz72g==", + "license": "ISC", + "dependencies": { + "d3-color": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-path": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/d3-path/-/d3-path-3.1.0.tgz", + "integrity": "sha512-p3KP5HCf/bvjBSSKuXid6Zqijx7wIfNW+J/maPs+iwR35at5JCbLUT0LzF1cnjbCHWhqzQTIN2Jpe8pRebIEFQ==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-scale": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/d3-scale/-/d3-scale-4.0.2.tgz", + "integrity": "sha512-GZW464g1SH7ag3Y7hXjf8RoUuAFIqklOAq3MRl4OaWabTFJY9PN/E1YklhXLh+OQ3fM9yS2nOkCoS+WLZ6kvxQ==", + "license": "ISC", + "dependencies": { + "d3-array": "2.10.0 - 3", + "d3-format": "1 - 3", + "d3-interpolate": "1.2.0 - 3", + "d3-time": "2.1.1 - 3", + "d3-time-format": "2 - 4" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-shape": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/d3-shape/-/d3-shape-3.2.0.tgz", + "integrity": "sha512-SaLBuwGm3MOViRq2ABk3eLoxwZELpH6zhl3FbAoJ7Vm1gofKx6El1Ib5z23NUEhF9AsGl7y+dzLe5Cw2AArGTA==", + "license": "ISC", + "dependencies": { + "d3-path": "^3.1.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-time": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/d3-time/-/d3-time-3.1.0.tgz", + "integrity": "sha512-VqKjzBLejbSMT4IgbmVgDjpkYrNWUYJnbCGo874u7MMKIWsILRX+OpX/gTk8MqjpT1A/c6HY2dCA77ZN0lkQ2Q==", + "license": "ISC", + "dependencies": { + "d3-array": "2 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-time-format": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/d3-time-format/-/d3-time-format-4.1.0.tgz", + "integrity": "sha512-dJxPBlzC7NugB2PDLwo9Q8JiTR3M3e4/XANkreKSUxF8vvXKqm1Yfq4Q5dl8budlunRVlUUaDUgFt7eA8D6NLg==", + "license": "ISC", + "dependencies": { + "d3-time": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/d3-timer": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/d3-timer/-/d3-timer-3.0.1.tgz", + "integrity": "sha512-ndfJ/JxxMd3nw31uyKoY2naivF+r29V+Lc0svZxe1JvvIRmi8hUsrMvdOwgS1o6uBHmiz91geQ0ylPP0aj1VUA==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, "node_modules/damerau-levenshtein": { "version": "1.0.8", "resolved": "https://registry.npmjs.org/damerau-levenshtein/-/damerau-levenshtein-1.0.8.tgz", @@ -6851,6 +7090,12 @@ "integrity": "sha512-YpgQiITW3JXGntzdUmyUR1V812Hn8T1YVXhCu+wO3OpS4eU9l4YdD3qjyiKdV6mvV29zapkMeD390UVEf2lkUg==", "license": "MIT" }, + "node_modules/decimal.js-light": { + "version": "2.5.1", + "resolved": "https://registry.npmjs.org/decimal.js-light/-/decimal.js-light-2.5.1.tgz", + "integrity": "sha512-qIMFpTMZmny+MMIitAB6D7iVPEorVw6YQRWkvarTkT4tBeSLLiHzcwj6q0MmYSFCiVpiqPJTJEYIrpcPzVEIvg==", + "license": "MIT" + }, "node_modules/dedent": { "version": "0.7.0", "resolved": "https://registry.npmjs.org/dedent/-/dedent-0.7.0.tgz", @@ -7484,6 +7729,16 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/es-toolkit": { + "version": "1.44.0", + "resolved": "https://registry.npmjs.org/es-toolkit/-/es-toolkit-1.44.0.tgz", + "integrity": "sha512-6penXeZalaV88MM3cGkFZZfOoLGWshWWfdy0tWw/RlVVyhvMaWSBTOvXNeiW3e5FwdS5ePW0LGEu17zT139ktg==", + "license": "MIT", + "workspaces": [ + "docs", + "benchmarks" + ] + }, "node_modules/escalade": { "version": "3.2.0", "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", @@ -9645,6 +9900,15 @@ "node": ">= 0.4" } }, + "node_modules/internmap": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/internmap/-/internmap-2.0.3.tgz", + "integrity": "sha512-5Hh7Y1wQbvY5ooGgPbDaL5iYLAPzMTUrjMulskHLH6wnv/A+1q5rgEaiuqEjB+oxGXIVZs1FF+R/KPN3ZSQYYg==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, "node_modules/ipaddr.js": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-2.3.0.tgz", @@ -14239,6 +14503,29 @@ "integrity": "sha512-W+EWGn2v0ApPKgKKCy/7s7WHXkboGcsrXE+2joLyVxkbyVQfO3MUEaUQDHoSmb8TFFrSKYa9mw64WZHNHSDzYA==", "license": "MIT" }, + "node_modules/react-redux": { + "version": "9.2.0", + "resolved": "https://registry.npmjs.org/react-redux/-/react-redux-9.2.0.tgz", + "integrity": "sha512-ROY9fvHhwOD9ySfrF0wmvu//bKCQ6AeZZq1nJNtbDC+kk5DuSuNX/n6YWYF/SYy7bSba4D4FSz8DJeKY/S/r+g==", + "license": "MIT", + "dependencies": { + "@types/use-sync-external-store": "^0.0.6", + "use-sync-external-store": "^1.4.0" + }, + "peerDependencies": { + "@types/react": "^18.2.25 || ^19", + "react": "^18.0 || ^19", + "redux": "^5.0.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "redux": { + "optional": true + } + } + }, "node_modules/react-refresh": { "version": "0.11.0", "resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.11.0.tgz", @@ -14410,6 +14697,52 @@ "node": ">=8.10.0" } }, + "node_modules/recharts": { + "version": "3.7.0", + "resolved": "https://registry.npmjs.org/recharts/-/recharts-3.7.0.tgz", + "integrity": "sha512-l2VCsy3XXeraxIID9fx23eCb6iCBsxUQDnE8tWm6DFdszVAO7WVY/ChAD9wVit01y6B2PMupYiMmQwhgPHc9Ew==", + "license": "MIT", + "workspaces": [ + "www" + ], + "dependencies": { + "@reduxjs/toolkit": "1.x.x || 2.x.x", + "clsx": "^2.1.1", + "decimal.js-light": "^2.5.1", + "es-toolkit": "^1.39.3", + "eventemitter3": "^5.0.1", + "immer": "^10.1.1", + "react-redux": "8.x.x || 9.x.x", + "reselect": "5.1.1", + "tiny-invariant": "^1.3.3", + "use-sync-external-store": "^1.2.2", + "victory-vendor": "^37.0.2" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", + "react-dom": "^16.0.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", + "react-is": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, + "node_modules/recharts/node_modules/eventemitter3": { + "version": "5.0.4", + "resolved": "https://registry.npmjs.org/eventemitter3/-/eventemitter3-5.0.4.tgz", + "integrity": "sha512-mlsTRyGaPBjPedk6Bvw+aqbsXDtoAyAzm5MO7JgU+yVRyMQ5O8bD4Kcci7BS85f93veegeCPkL8R4GLClnjLFw==", + "license": "MIT" + }, + "node_modules/recharts/node_modules/immer": { + "version": "10.2.0", + "resolved": "https://registry.npmjs.org/immer/-/immer-10.2.0.tgz", + "integrity": "sha512-d/+XTN3zfODyjr89gM3mPq1WNX2B8pYsu7eORitdwyA2sBubnTl3laYlBk4sXY5FUa5qTZGBDPJICVbvqzjlbw==", + "license": "MIT", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/immer" + } + }, "node_modules/recursive-readdir": { "version": "2.2.3", "resolved": "https://registry.npmjs.org/recursive-readdir/-/recursive-readdir-2.2.3.tgz", @@ -14422,6 +14755,21 @@ "node": ">=6.0.0" } }, + "node_modules/redux": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/redux/-/redux-5.0.1.tgz", + "integrity": "sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==", + "license": "MIT" + }, + "node_modules/redux-thunk": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/redux-thunk/-/redux-thunk-3.1.0.tgz", + "integrity": "sha512-NW2r5T6ksUKXCabzhL9z+h206HQw/NJkcLm1GPImRQ8IzfXwRGqjVhKJGauHirT0DAuyy6hjdnMZaRoAcy0Klw==", + "license": "MIT", + "peerDependencies": { + "redux": "^5.0.0" + } + }, "node_modules/reflect.getprototypeof": { "version": "1.0.10", "resolved": "https://registry.npmjs.org/reflect.getprototypeof/-/reflect.getprototypeof-1.0.10.tgz", @@ -16329,6 +16677,12 @@ "integrity": "sha512-eHY7nBftgThBqOyHGVN+l8gF0BucP09fMo0oO/Lb0w1OF80dJv+lDVpXG60WMQvkcxAkNybKsrEIE3ZtKGmPrA==", "license": "MIT" }, + "node_modules/tiny-invariant": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.3.tgz", + "integrity": "sha512-+FbBPE1o9QAYvviau/qC5SE3caw21q3xkvWKBtja5vgqOWIHHJ3ioaq1VPfn/Szqctz2bU/oYeKd9/z5BL+PVg==", + "license": "MIT" + }, "node_modules/tinyglobby": { "version": "0.2.15", "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", @@ -16902,6 +17256,28 @@ "node": ">= 0.8" } }, + "node_modules/victory-vendor": { + "version": "37.3.6", + "resolved": "https://registry.npmjs.org/victory-vendor/-/victory-vendor-37.3.6.tgz", + "integrity": "sha512-SbPDPdDBYp+5MJHhBCAyI7wKM3d5ivekigc2Dk2s7pgbZ9wIgIBYGVw4zGHBml/qTFbexrofXW6Gu4noGxrOwQ==", + "license": "MIT AND ISC", + "dependencies": { + "@types/d3-array": "^3.0.3", + "@types/d3-ease": "^3.0.0", + "@types/d3-interpolate": "^3.0.1", + "@types/d3-scale": "^4.0.2", + "@types/d3-shape": "^3.1.0", + "@types/d3-time": "^3.0.0", + "@types/d3-timer": "^3.0.0", + "d3-array": "^3.1.6", + "d3-ease": "^3.0.1", + "d3-interpolate": "^3.0.1", + "d3-scale": "^4.0.2", + "d3-shape": "^3.1.0", + "d3-time": "^3.0.0", + "d3-timer": "^3.0.1" + } + }, "node_modules/w3c-hr-time": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/w3c-hr-time/-/w3c-hr-time-1.0.2.tgz", diff --git a/frontend/package.json b/frontend/package.json index 7534b1e..f93befd 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -13,7 +13,8 @@ "react": "^18.2.0", "react-dom": "^18.2.0", "react-router-dom": "^7.13.0", - "react-scripts": "5.0.1" + "react-scripts": "5.0.1", + "recharts": "^3.7.0" }, "scripts": { "start": "react-scripts start", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index bf3df36..9958c7b 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -2,10 +2,11 @@ * ThreatHunt MUI-powered analyst-assist platform. */ -import React, { useState, useCallback } from 'react'; +import React, { useState, useCallback, Suspense } from 'react'; import { BrowserRouter, Routes, Route, useNavigate, useLocation } from 'react-router-dom'; import { ThemeProvider, CssBaseline, Box, AppBar, Toolbar, Typography, IconButton, - Drawer, List, ListItemButton, ListItemIcon, ListItemText, Divider, Chip } from '@mui/material'; + Drawer, List, ListItemButton, ListItemIcon, ListItemText, Divider, Chip, + CircularProgress } from '@mui/material'; import MenuIcon from '@mui/icons-material/Menu'; import DashboardIcon from '@mui/icons-material/Dashboard'; import SearchIcon from '@mui/icons-material/Search'; @@ -19,9 +20,14 @@ import CompareArrowsIcon from '@mui/icons-material/CompareArrows'; import GppMaybeIcon from '@mui/icons-material/GppMaybe'; import HubIcon from '@mui/icons-material/Hub'; import AssessmentIcon from '@mui/icons-material/Assessment'; +import TimelineIcon from '@mui/icons-material/Timeline'; +import PlaylistAddCheckIcon from '@mui/icons-material/PlaylistAddCheck'; +import BookmarksIcon from '@mui/icons-material/Bookmarks'; +import ShieldIcon from '@mui/icons-material/Shield'; import { SnackbarProvider } from 'notistack'; import theme from './theme'; +/* -- Eager imports (lightweight, always needed) -- */ import Dashboard from './components/Dashboard'; import HuntManager from './components/HuntManager'; import DatasetViewer from './components/DatasetViewer'; @@ -32,28 +38,46 @@ import AnnotationPanel from './components/AnnotationPanel'; import HypothesisTracker from './components/HypothesisTracker'; import CorrelationView from './components/CorrelationView'; import AUPScanner from './components/AUPScanner'; -import NetworkMap from './components/NetworkMap'; -import AnalysisDashboard from './components/AnalysisDashboard'; + +/* -- Lazy imports (heavy: charts, network graph, new feature pages) -- */ +const NetworkMap = React.lazy(() => import('./components/NetworkMap')); +const AnalysisDashboard = React.lazy(() => import('./components/AnalysisDashboard')); +const MitreMatrix = React.lazy(() => import('./components/MitreMatrix')); +const TimelineView = React.lazy(() => import('./components/TimelineView')); +const PlaybookManager = React.lazy(() => import('./components/PlaybookManager')); +const SavedSearches = React.lazy(() => import('./components/SavedSearches')); const DRAWER_WIDTH = 240; interface NavItem { label: string; path: string; icon: React.ReactNode } const NAV: NavItem[] = [ - { label: 'Dashboard', path: '/', icon: }, - { label: 'Hunts', path: '/hunts', icon: }, - { label: 'Datasets', path: '/datasets', icon: }, - { label: 'Upload', path: '/upload', icon: }, - { label: 'AI Analysis', path: '/analysis', icon: }, - { label: 'Agent', path: '/agent', icon: }, - { label: 'Enrichment', path: '/enrichment', icon: }, - { label: 'Annotations', path: '/annotations', icon: }, - { label: 'Hypotheses', path: '/hypotheses', icon: }, - { label: 'Correlation', path: '/correlation', icon: }, - { label: 'Network Map', path: '/network', icon: }, - { label: 'AUP Scanner', path: '/aup', icon: }, + { label: 'Dashboard', path: '/', icon: }, + { label: 'Hunts', path: '/hunts', icon: }, + { label: 'Datasets', path: '/datasets', icon: }, + { label: 'Upload', path: '/upload', icon: }, + { label: 'AI Analysis', path: '/analysis', icon: }, + { label: 'Agent', path: '/agent', icon: }, + { label: 'Enrichment', path: '/enrichment', icon: }, + { label: 'Annotations', path: '/annotations', icon: }, + { label: 'Hypotheses', path: '/hypotheses', icon: }, + { label: 'Correlation', path: '/correlation', icon: }, + { label: 'Network Map', path: '/network', icon: }, + { label: 'AUP Scanner', path: '/aup', icon: }, + { label: 'MITRE Matrix', path: '/mitre', icon: }, + { label: 'Timeline', path: '/timeline', icon: }, + { label: 'Playbooks', path: '/playbooks', icon: }, + { label: 'Saved Searches', path: '/saved-searches', icon: }, ]; +function LazyFallback() { + return ( + + + + ); +} + function Shell() { const [open, setOpen] = useState(true); const navigate = useNavigate(); @@ -72,7 +96,7 @@ function Shell() { ThreatHunt - + @@ -107,20 +131,26 @@ function Shell() { ml: open ? 0 : `-${DRAWER_WIDTH}px`, transition: 'margin 225ms cubic-bezier(0,0,0.2,1)', }}> - - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - } /> - + }> + + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + } /> + + ); @@ -139,4 +169,4 @@ function App() { ); } -export default App; \ No newline at end of file +export default App; diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 37476ac..46eec15 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -71,6 +71,20 @@ export interface Hunt { dataset_count: number; hypothesis_count: number; } +export interface HuntProgress { + hunt_id: string; + status: 'idle' | 'processing' | 'ready'; + progress_percent: number; + dataset_total: number; + dataset_completed: number; + dataset_processing: number; + dataset_errors: number; + active_jobs: number; + queued_jobs: number; + network_status: 'none' | 'building' | 'ready'; + stages: Record; +} + export const hunts = { list: (skip = 0, limit = 50) => api<{ hunts: Hunt[]; total: number }>(`/api/hunts?skip=${skip}&limit=${limit}`), @@ -80,6 +94,7 @@ export const hunts = { update: (id: string, data: Partial<{ name: string; description: string; status: string }>) => api(`/api/hunts/${id}`, { method: 'PUT', body: JSON.stringify(data) }), delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }), + progress: (id: string) => api(`/api/hunts/${id}/progress`), }; // -- Datasets -- @@ -166,6 +181,8 @@ export interface AssistRequest { active_hypotheses?: string[]; annotations_summary?: string; enrichment_summary?: string; mode?: 'quick' | 'deep' | 'debate'; model_override?: string; conversation_id?: string; hunt_id?: string; + execution_preference?: 'auto' | 'force' | 'off'; + learning_mode?: boolean; } export interface AssistResponse { @@ -174,6 +191,15 @@ export interface AssistResponse { sans_references: string[]; model_used: string; node_used: string; latency_ms: number; perspectives: Record[] | null; conversation_id: string | null; + execution?: { + scope: string; + datasets_scanned: string[]; + policy_hits: number; + result_count: number; + top_domains: string[]; + top_user_hosts: string[]; + elapsed_ms: number; + } | null; } export interface NodeInfo { url: string; available: boolean } @@ -326,10 +352,12 @@ export interface ScanHit { theme_name: string; theme_color: string; keyword: string; source_type: string; source_id: string | number; field: string; matched_value: string; row_index: number | null; dataset_name: string | null; + hostname?: string | null; username?: string | null; } export interface ScanResponse { total_hits: number; hits: ScanHit[]; themes_scanned: number; keywords_scanned: number; rows_scanned: number; + cache_used?: boolean; cache_status?: string; cached_at?: string | null; } export const keywords = { @@ -363,6 +391,7 @@ export const keywords = { scan: (opts: { dataset_ids?: string[]; theme_ids?: string[]; scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean; + prefer_cache?: boolean; force_rescan?: boolean; }) => api('/api/keywords/scan', { method: 'POST', body: JSON.stringify(opts), @@ -579,7 +608,213 @@ export interface HostInventory { stats: InventoryStats; } +export interface InventoryStatus { + hunt_id: string; + status: 'ready' | 'building' | 'none'; +} + +export interface NetworkSummaryHost { + id: string; + hostname: string; + row_count: number; + ip_count: number; + user_count: number; +} + +export interface NetworkSummary { + stats: InventoryStats; + top_hosts: NetworkSummaryHost[]; + top_edges: InventoryConnection[]; + status?: 'building' | 'deferred'; + message?: string; +} + export const network = { - hostInventory: (huntId: string) => - api(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}`), -}; \ No newline at end of file + hostInventory: (huntId: string, force = false) => + api(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`), + summary: (huntId: string, topN = 20) => + api(`/api/network/summary?hunt_id=${encodeURIComponent(huntId)}&top_n=${topN}`), + subgraph: (huntId: string, maxHosts = 250, maxEdges = 1500, nodeId?: string) => { + let qs = `/api/network/subgraph?hunt_id=${encodeURIComponent(huntId)}&max_hosts=${maxHosts}&max_edges=${maxEdges}`; + if (nodeId) qs += `&node_id=${encodeURIComponent(nodeId)}`; + return api(qs); + }, + inventoryStatus: (huntId: string) => + api(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`), + rebuildInventory: (huntId: string) => + api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }), +}; + +// -- MITRE ATT&CK Coverage (Feature 1) -- + +export interface MitreTechnique { + id: string; + tactic: string; + sources: { type: string; risk_score?: number; hostname?: string; title?: string }[]; + count: number; +} + +export interface MitreCoverage { + tactics: string[]; + technique_count: number; + detection_count: number; + tactic_coverage: Record; + all_techniques: MitreTechnique[]; +} + +export const mitre = { + coverage: (huntId?: string) => { + const q = huntId ? `?hunt_id=${encodeURIComponent(huntId)}` : ''; + return api(`/api/mitre/coverage${q}`); + }, +}; + +// -- Timeline (Feature 2) -- + +export interface TimelineEvent { + timestamp: string; + dataset_id: string; + dataset_name: string; + artifact_type: string; + row_index: number; + hostname: string; + process: string; + summary: string; + data: Record; +} + +export interface TimelineData { + hunt_id: string; + hunt_name: string; + event_count: number; + datasets: { id: string; name: string; artifact_type: string; row_count: number }[]; + events: TimelineEvent[]; +} + +export const timeline = { + getHuntTimeline: (huntId: string, limit = 2000) => + api(`/api/timeline/hunt/${huntId}?limit=${limit}`), +}; + +// -- Playbooks (Feature 3) -- + +export interface PlaybookStep { + id: number; + order_index: number; + title: string; + description: string | null; + step_type: string; + target_route: string | null; + is_completed: boolean; + completed_at: string | null; + notes: string | null; +} + +export interface PlaybookSummary { + id: string; + name: string; + description: string | null; + is_template: boolean; + hunt_id: string | null; + status: string; + total_steps: number; + completed_steps: number; + created_at: string | null; +} + +export interface PlaybookDetail { + id: string; + name: string; + description: string | null; + is_template: boolean; + hunt_id: string | null; + status: string; + created_at: string | null; + steps: PlaybookStep[]; +} + +export interface PlaybookTemplate { + name: string; + description: string; + steps: { title: string; description: string; step_type: string; target_route: string }[]; +} + +export const playbooks = { + list: (huntId?: string) => { + const q = huntId ? `?hunt_id=${encodeURIComponent(huntId)}` : ''; + return api<{ playbooks: PlaybookSummary[] }>(`/api/playbooks${q}`); + }, + templates: () => api<{ templates: PlaybookTemplate[] }>('/api/playbooks/templates'), + get: (id: string) => api(`/api/playbooks/${id}`), + create: (data: { name: string; description?: string; hunt_id?: string; is_template?: boolean; steps?: { title: string; description?: string; step_type?: string; target_route?: string }[] }) => + api('/api/playbooks', { method: 'POST', body: JSON.stringify(data) }), + update: (id: string, data: { name?: string; description?: string; status?: string }) => + api(`/api/playbooks/${id}`, { method: 'PUT', body: JSON.stringify(data) }), + delete: (id: string) => api(`/api/playbooks/${id}`, { method: 'DELETE' }), + updateStep: (stepId: number, data: { is_completed?: boolean; notes?: string }) => + api(`/api/playbooks/steps/${stepId}`, { method: 'PUT', body: JSON.stringify(data) }), +}; + +// -- Saved Searches (Feature 5) -- + +export interface SavedSearchData { + id: string; + name: string; + description: string | null; + search_type: string; + query_params: Record; + hunt_id?: string | null; + threshold: number | null; + last_run_at: string | null; + last_result_count: number | null; + created_at: string | null; +} + +export interface SearchRunResult { + search_id: string; + search_name: string; + search_type: string; + result_count: number; + previous_count: number; + delta: number; + results: any[]; +} + +export const savedSearches = { + list: (searchType?: string) => { + const q = searchType ? `?search_type=${encodeURIComponent(searchType)}` : ''; + return api<{ searches: SavedSearchData[] }>(`/api/searches${q}`); + }, + get: (id: string) => api(`/api/searches/${id}`), + create: (data: { name: string; description?: string; search_type: string; query_params: Record; threshold?: number; hunt_id?: string }) => + api('/api/searches', { method: 'POST', body: JSON.stringify(data) }), + update: (id: string, data: { name?: string; description?: string; search_type?: string; query_params?: Record; threshold?: number; hunt_id?: string }) => + api(`/api/searches/${id}`, { method: 'PUT', body: JSON.stringify(data) }), + delete: (id: string) => api(`/api/searches/${id}`, { method: 'DELETE' }), + run: (id: string) => api(`/api/searches/${id}/run`, { method: 'POST' }), +}; + +// -- STIX Export -- + +export const stixExport = { + /** Download a STIX 2.1 bundle JSON for a given hunt */ + download: async (huntId: string): Promise => { + const headers: Record = {}; + if (authToken) headers['Authorization'] = `Bearer ${authToken}`; + const res = await fetch(`${BASE}/api/export/stix/${huntId}`, { headers }); + if (!res.ok) { + const body = await res.json().catch(() => ({})); + throw new Error(body.detail || `HTTP ${res.status}`); + } + const blob = await res.blob(); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `hunt-${huntId}-stix-bundle.json`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + }, +}; + diff --git a/frontend/src/components/AUPScanner.tsx b/frontend/src/components/AUPScanner.tsx index 952fd52..4c0d164 100644 --- a/frontend/src/components/AUPScanner.tsx +++ b/frontend/src/components/AUPScanner.tsx @@ -188,11 +188,13 @@ const RESULT_COLUMNS: GridColDef[] = [ ), }, { field: 'keyword', headerName: 'Keyword', width: 140 }, - { field: 'source_type', headerName: 'Source', width: 120 }, - { field: 'dataset_name', headerName: 'Dataset', width: 150 }, + { field: 'dataset_name', headerName: 'Dataset', width: 170 }, + { field: 'hostname', headerName: 'Hostname', width: 170, valueGetter: (v, row) => row.hostname || '' }, + { field: 'username', headerName: 'User', width: 160, valueGetter: (v, row) => row.username || '' }, + { field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 220 }, { field: 'field', headerName: 'Field', width: 130 }, - { field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 200 }, - { field: 'row_index', headerName: 'Row #', width: 80, type: 'number' }, + { field: 'source_type', headerName: 'Source', width: 120 }, + { field: 'row_index', headerName: 'Row #', width: 90, type: 'number' }, ]; export default function AUPScanner() { @@ -210,9 +212,9 @@ export default function AUPScanner() { // Scan options const [selectedDs, setSelectedDs] = useState>(new Set()); const [selectedThemes, setSelectedThemes] = useState>(new Set()); - const [scanHunts, setScanHunts] = useState(true); - const [scanAnnotations, setScanAnnotations] = useState(true); - const [scanMessages, setScanMessages] = useState(true); + const [scanHunts, setScanHunts] = useState(false); + const [scanAnnotations, setScanAnnotations] = useState(false); + const [scanMessages, setScanMessages] = useState(false); // Load themes + hunts const loadData = useCallback(async () => { @@ -224,9 +226,13 @@ export default function AUPScanner() { ]); setThemes(tRes.themes); setHuntList(hRes.hunts); + if (!selectedHuntId && hRes.hunts.length > 0) { + const best = hRes.hunts.find(h => h.dataset_count > 0) || hRes.hunts[0]; + setSelectedHuntId(best.id); + } } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); } setLoading(false); - }, [enqueueSnackbar]); + }, [enqueueSnackbar, selectedHuntId]); useEffect(() => { loadData(); }, [loadData]); @@ -237,7 +243,7 @@ export default function AUPScanner() { datasets.list(0, 500, selectedHuntId).then(res => { if (cancelled) return; setDsList(res.datasets); - setSelectedDs(new Set(res.datasets.map(d => d.id))); + setSelectedDs(new Set(res.datasets.slice(0, 3).map(d => d.id))); }).catch(() => {}); return () => { cancelled = true; }; }, [selectedHuntId]); @@ -251,6 +257,15 @@ export default function AUPScanner() { // Run scan const runScan = useCallback(async () => { + if (!selectedHuntId) { + enqueueSnackbar('Please select a hunt before running AUP scan', { variant: 'warning' }); + return; + } + if (selectedDs.size === 0) { + enqueueSnackbar('No datasets selected for this hunt', { variant: 'warning' }); + return; + } + setScanning(true); setScanResult(null); try { @@ -260,6 +275,7 @@ export default function AUPScanner() { scan_hunts: scanHunts, scan_annotations: scanAnnotations, scan_messages: scanMessages, + prefer_cache: true, }); setScanResult(res); enqueueSnackbar(`Scan complete — ${res.total_hits} hits found`, { @@ -267,7 +283,7 @@ export default function AUPScanner() { }); } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); } setScanning(false); - }, [selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]); + }, [selectedHuntId, selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]); if (loading) return ; @@ -316,9 +332,38 @@ export default function AUPScanner() { )} {!selectedHuntId && ( - All datasets will be scanned if no hunt is selected + Select a hunt to enable scoped scanning )} + + + Datasets + + + + {selectedHuntId && dsList.length > 0 && ( + + + + + + )} {/* Theme selector */} @@ -372,7 +417,7 @@ export default function AUPScanner() { @@ -392,6 +437,15 @@ export default function AUPScanner() { {scanResult.total_hits} hits across{' '} {scanResult.rows_scanned} rows |{' '} {scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned + {scanResult.cache_status && ( + + )} )} diff --git a/frontend/src/components/AgentPanel.tsx b/frontend/src/components/AgentPanel.tsx index 8d8fcb1..7e8aa07 100644 --- a/frontend/src/components/AgentPanel.tsx +++ b/frontend/src/components/AgentPanel.tsx @@ -1,6 +1,6 @@ /** - * AgentPanel — analyst-assist chat with quick / deep / debate modes, - * streaming support, SANS references, and conversation persistence. + * AgentPanel - analyst-assist chat with quick / deep / debate modes, + * SSE streaming, SANS references, and conversation persistence. */ import React, { useState, useRef, useEffect, useCallback } from 'react'; @@ -8,7 +8,7 @@ import { Box, Typography, Paper, TextField, Button, Stack, Chip, ToggleButtonGroup, ToggleButton, CircularProgress, Alert, Accordion, AccordionSummary, AccordionDetails, Divider, Select, - MenuItem, FormControl, InputLabel, LinearProgress, + MenuItem, FormControl, InputLabel, LinearProgress, FormControlLabel, Switch, } from '@mui/material'; import SendIcon from '@mui/icons-material/Send'; import ExpandMoreIcon from '@mui/icons-material/ExpandMore'; @@ -16,26 +16,31 @@ import SchoolIcon from '@mui/icons-material/School'; import PsychologyIcon from '@mui/icons-material/Psychology'; import ForumIcon from '@mui/icons-material/Forum'; import SpeedIcon from '@mui/icons-material/Speed'; +import StopIcon from '@mui/icons-material/Stop'; import { useSnackbar } from 'notistack'; import { agent, datasets, hunts, type AssistRequest, type AssistResponse, type DatasetSummary, type Hunt, } from '../api/client'; -interface Message { role: 'user' | 'assistant'; content: string; meta?: AssistResponse } +interface Message { role: 'user' | 'assistant'; content: string; meta?: AssistResponse; streaming?: boolean } export default function AgentPanel() { const { enqueueSnackbar } = useSnackbar(); const [messages, setMessages] = useState([]); const [query, setQuery] = useState(''); const [mode, setMode] = useState<'quick' | 'deep' | 'debate'>('quick'); + const [executionPreference, setExecutionPreference] = useState<'auto' | 'force' | 'off'>('auto'); + const [learningMode, setLearningMode] = useState(false); const [loading, setLoading] = useState(false); + const [streaming, setStreaming] = useState(false); const [conversationId, setConversationId] = useState(null); const [datasetList, setDatasets] = useState([]); const [huntList, setHunts] = useState([]); const [selectedDataset, setSelectedDataset] = useState(''); const [selectedHunt, setSelectedHunt] = useState(''); const bottomRef = useRef(null); + const abortRef = useRef(null); useEffect(() => { datasets.list(0, 100).then(r => setDatasets(r.datasets)).catch(() => {}); @@ -44,6 +49,12 @@ export default function AgentPanel() { useEffect(() => { bottomRef.current?.scrollIntoView({ behavior: 'smooth' }); }, [messages]); + const stopStreaming = () => { + abortRef.current?.abort(); + setStreaming(false); + setLoading(false); + }; + const send = useCallback(async () => { if (!query.trim() || loading) return; const userMsg: Message = { role: 'user', content: query }; @@ -59,18 +70,118 @@ export default function AgentPanel() { hunt_id: selectedHunt || undefined, dataset_name: ds?.name, data_summary: ds ? `${ds.row_count} rows, columns: ${Object.keys(ds.column_schema || {}).join(', ')}` : undefined, + execution_preference: executionPreference, + learning_mode: learningMode, }; + // Try SSE streaming first, fall back to regular request try { - const resp = await agent.assist(req); - setConversationId(resp.conversation_id || null); - setMessages(prev => [...prev, { role: 'assistant', content: resp.guidance, meta: resp }]); - } catch (e: any) { - enqueueSnackbar(e.message, { variant: 'error' }); - setMessages(prev => [...prev, { role: 'assistant', content: `Error: ${e.message}` }]); + const controller = new AbortController(); + abortRef.current = controller; + setStreaming(true); + + const res = await agent.assistStream(req); + if (!res.ok || !res.body) throw new Error('Stream unavailable'); + + setMessages(prev => [...prev, { role: 'assistant', content: '', streaming: true }]); + + const reader = res.body.getReader(); + const decoder = new TextDecoder(); + let fullText = ''; + let metaData: AssistResponse | undefined; + + while (true) { + if (controller.signal.aborted) break; + const { done, value } = await reader.read(); + if (done) break; + + const chunk = decoder.decode(value, { stream: true }); + // Parse SSE lines + for (const line of chunk.split('\n')) { + if (line.startsWith('data: ')) { + const data = line.slice(6); + if (data === '[DONE]') continue; + try { + const parsed = JSON.parse(data); + if (parsed.token) { + fullText += parsed.token; + const nextText = fullText; + setMessages(prev => { + const updated = [...prev]; + const last = updated[updated.length - 1]; + if (last?.role === 'assistant') { + updated[updated.length - 1] = { ...last, content: nextText }; + } + return updated; + }); + } + if (parsed.meta || parsed.confidence) { + metaData = parsed.meta || parsed; + } + } catch { + // Non-JSON data line, treat as text token + fullText += data; + const nextText = fullText; + setMessages(prev => { + const updated = [...prev]; + const last = updated[updated.length - 1]; + if (last?.role === 'assistant') { + updated[updated.length - 1] = { ...last, content: nextText }; + } + return updated; + }); + } + } + } + } + + // Finalize the streamed message + setMessages(prev => { + const updated = [...prev]; + const last = updated[updated.length - 1]; + if (last?.role === 'assistant') { + updated[updated.length - 1] = { ...last, content: fullText || 'No response received.', streaming: false, meta: metaData }; + } + return updated; + }); + if (metaData?.conversation_id) setConversationId(metaData.conversation_id); + + } catch (streamErr: any) { + // Streaming failed or unavailable, fall back to regular request + setStreaming(false); + // Remove the empty streaming message if one was added + setMessages(prev => { + if (prev.length > 0 && prev[prev.length - 1].streaming && prev[prev.length - 1].content === '') { + return prev.slice(0, -1); + } + return prev; + }); + + try { + const resp = await agent.assist(req); + setConversationId(resp.conversation_id || null); + setMessages(prev => [...prev, { role: 'assistant', content: resp.guidance, meta: resp }]); + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + setMessages(prev => [...prev, { role: 'assistant', content: `Error: ${e.message}` }]); + } + } finally { + setLoading(false); + setStreaming(false); + abortRef.current = null; } - setLoading(false); - }, [query, mode, loading, conversationId, selectedDataset, selectedHunt, datasetList, enqueueSnackbar]); + }, [ + query, + mode, + executionPreference, + learningMode, + loading, + conversationId, + selectedDataset, + selectedHunt, + datasetList, + enqueueSnackbar, + ]); const handleKeyDown = (e: React.KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); send(); } @@ -112,6 +223,25 @@ export default function AgentPanel() { {huntList.map(h => {h.name})} + + + Execution + + + + setLearningMode(v)} size="small" />} + label={Learning mode} + sx={{ ml: 0.5 }} + /> @@ -124,7 +254,7 @@ export default function AgentPanel() { Ask a question about your threat hunt data. - The agent provides advisory guidance — all decisions remain with the analyst. + Agent can provide advisory guidance or execute policy scans based on execution mode. )} @@ -132,20 +262,24 @@ export default function AgentPanel() { {m.role === 'user' ? 'You' : 'Agent'} + {m.streaming && } - {m.content} + + {m.content} + {m.streaming && |} + {/* Response metadata */} {m.meta && ( - = 0.7 ? 'success' : m.meta.confidence >= 0.4 ? 'warning' : 'error'} variant="outlined" /> @@ -190,7 +324,7 @@ export default function AgentPanel() { {m.meta.sans_references.map((r, j) => ( - • {r} + {r} ))} @@ -214,6 +348,32 @@ export default function AgentPanel() { )} + {/* Execution summary */} + {m.meta.execution && ( + + }> + + Execution Results ({m.meta.execution.policy_hits} hits in {m.meta.execution.elapsed_ms}ms) + + + + + Scope: {m.meta.execution.scope} + + + Datasets: {m.meta.execution.datasets_scanned.join(', ') || 'None'} + + {m.meta.execution.top_domains.length > 0 && ( + + {m.meta.execution.top_domains.map((d, j) => ( + + ))} + + )} + + + )} + {/* Caveats */} {m.meta.caveats && ( @@ -224,7 +384,7 @@ export default function AgentPanel() { )} ))} - {loading && } + {loading && !streaming && }
@@ -237,10 +397,17 @@ export default function AgentPanel() { onKeyDown={handleKeyDown} disabled={loading} /> - + {streaming ? ( + + ) : ( + + )} ); } + diff --git a/frontend/src/components/AnalysisDashboard.tsx b/frontend/src/components/AnalysisDashboard.tsx index f79253d..3841a47 100644 --- a/frontend/src/components/AnalysisDashboard.tsx +++ b/frontend/src/components/AnalysisDashboard.tsx @@ -30,7 +30,6 @@ import QuestionAnswerIcon from '@mui/icons-material/QuestionAnswer'; import WorkIcon from '@mui/icons-material/Work'; import SendIcon from '@mui/icons-material/Send'; import StopIcon from '@mui/icons-material/Stop'; -import DeleteIcon from '@mui/icons-material/Delete'; import CheckCircleIcon from '@mui/icons-material/CheckCircle'; import ErrorIcon from '@mui/icons-material/Error'; import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty'; diff --git a/frontend/src/components/CorrelationView.tsx b/frontend/src/components/CorrelationView.tsx index 4831178..ff47369 100644 --- a/frontend/src/components/CorrelationView.tsx +++ b/frontend/src/components/CorrelationView.tsx @@ -1,19 +1,25 @@ /** - * CorrelationView — cross-hunt correlation analysis with IOC, time, - * technique, and host overlap visualisation. + * CorrelationView - cross-hunt correlation analysis with recharts visualizations. + * IOC overlap bar chart, technique overlap heat chips, time/host overlap display. */ import React, { useEffect, useState, useCallback } from 'react'; import { Box, Typography, Paper, Stack, Chip, Button, CircularProgress, Alert, Table, TableBody, TableCell, TableContainer, TableHead, - TableRow, TextField, + TableRow, TextField, Grid, Divider, } from '@mui/material'; import CompareArrowsIcon from '@mui/icons-material/CompareArrows'; import SearchIcon from '@mui/icons-material/Search'; import { useSnackbar } from 'notistack'; +import { + BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip as ReTooltip, + ResponsiveContainer, PieChart, Pie, Cell, Legend, +} from 'recharts'; import { correlation, hunts, type Hunt, type CorrelationResult } from '../api/client'; +const PIE_COLORS = ['#60a5fa', '#f472b6', '#34d399', '#fbbf24', '#a78bfa', '#f87171', '#38bdf8', '#fb923c']; + export default function CorrelationView() { const { enqueueSnackbar } = useSnackbar(); const [huntList, setHuntList] = useState([]); @@ -53,6 +59,18 @@ export default function CorrelationView() { } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); } }, [iocSearch, enqueueSnackbar]); + // Build chart data from results + const iocChartData = (result?.ioc_overlaps || []).slice(0, 20).map((o: any) => ({ + name: String(o.ioc_value).length > 20 ? String(o.ioc_value).slice(0, 20) + '...' : o.ioc_value, + hunts: (o.hunt_ids || []).length, + type: o.ioc_type || 'unknown', + })); + + const techniqueChartData = (result?.technique_overlaps || []).map((t: any) => ({ + name: t.technique || t.mitre_technique || 'unknown', + value: (t.hunt_ids || []).length || 1, + })); + return ( Cross-Hunt Correlation @@ -98,34 +116,80 @@ export default function CorrelationView() { )} - {/* Results */} + {/* Results with charts */} {result && ( - {result.summary} — {result.total_correlations} total correlation(s) across {result.hunt_ids.length} hunts + {result.summary} {result.total_correlations} correlation(s) across {result.hunt_ids.length} hunts - {/* IOC overlaps */} + {/* Symmetrical 2-column: IOC chart | Technique chart */} + + + + IOC Overlaps ({result.ioc_overlaps.length}) + + {iocChartData.length > 0 ? ( + + + + + + + + + + ) : ( + No IOC overlaps found. + )} + + + + + Technique Overlaps ({result.technique_overlaps.length}) + + {techniqueChartData.length > 0 ? ( + + + name}> + {techniqueChartData.map((_: any, i: number) => ( + + ))} + + + + + + ) : ( + No technique overlaps found. + )} + + + + + {/* IOC detail table */} {result.ioc_overlaps.length > 0 && ( - IOC Overlaps ({result.ioc_overlaps.length}) - - + IOC Detail + +
- IOC - Type - Shared Hunts + IOC + Type + Shared Hunts {result.ioc_overlaps.map((o: any, i: number) => ( - + {o.ioc_value} {(o.hunt_ids || []).map((hid: string, j: number) => ( - h.id === hid)?.name || hid} size="small" sx={{ mr: 0.5 }} /> + h.id === hid)?.name || hid.slice(0, 8)} + size="small" sx={{ mr: 0.5, mb: 0.5 }} /> ))} @@ -136,41 +200,47 @@ export default function CorrelationView() { )} - {/* Technique overlaps */} - {result.technique_overlaps.length > 0 && ( - - MITRE Technique Overlaps - - {result.technique_overlaps.map((t: any, i: number) => ( - - ))} - - - )} - - {/* Time overlaps */} - {result.time_overlaps.length > 0 && ( - - Time Overlaps - {result.time_overlaps.map((t: any, i: number) => ( - - {t.hunt_a || 'Hunt A'} ↔ {t.hunt_b || 'Hunt B'}: {t.overlap_start} — {t.overlap_end} - - ))} - - )} - - {/* Host overlaps */} - {result.host_overlaps.length > 0 && ( - - Host Overlaps - - {result.host_overlaps.map((h: any, i: number) => ( - - ))} - - - )} + {/* Symmetrical 2-column: Time overlaps | Host overlaps */} + + + + Time Overlaps ({result.time_overlaps.length}) + + {result.time_overlaps.length > 0 ? ( + + {result.time_overlaps.map((t: any, i: number) => ( + + + + + + {t.overlap_start} {t.overlap_end} + + + ))} + + ) : ( + No time overlaps found. + )} + + + + + Host Overlaps ({result.host_overlaps.length}) + + {result.host_overlaps.length > 0 ? ( + + {result.host_overlaps.map((h: any, i: number) => ( + + ))} + + ) : ( + No host overlaps found. + )} + + + )} diff --git a/frontend/src/components/Dashboard.tsx b/frontend/src/components/Dashboard.tsx index 39225a9..e0c855a 100644 --- a/frontend/src/components/Dashboard.tsx +++ b/frontend/src/components/Dashboard.tsx @@ -1,11 +1,12 @@ /** - * Dashboard — overview cards with hunt stats, node health, recent activity. + * Dashboard - overview cards with hunt stats, cluster health, recent activity. + * Symmetrical 4-column grid layout, empty-state onboarding, auto-refresh. */ -import React, { useEffect, useState } from 'react'; +import React, { useEffect, useState, useCallback } from 'react'; import { Box, Grid, Paper, Typography, Chip, CircularProgress, - Stack, Alert, + Stack, Alert, Button, Divider, } from '@mui/material'; import StorageIcon from '@mui/icons-material/Storage'; import SearchIcon from '@mui/icons-material/Search'; @@ -13,139 +14,245 @@ import SecurityIcon from '@mui/icons-material/Security'; import ScienceIcon from '@mui/icons-material/Science'; import CheckCircleIcon from '@mui/icons-material/CheckCircle'; import ErrorIcon from '@mui/icons-material/Error'; +import UploadFileIcon from '@mui/icons-material/UploadFile'; +import RocketLaunchIcon from '@mui/icons-material/RocketLaunch'; +import RefreshIcon from '@mui/icons-material/Refresh'; +import { useNavigate } from 'react-router-dom'; import { hunts, datasets, hypotheses, agent, misc, type Hunt, type DatasetSummary, type HealthInfo } from '../api/client'; -function StatCard({ title, value, icon, color }: { title: string; value: string | number; icon: React.ReactNode; color: string }) { +const REFRESH_INTERVAL = 30_000; // 30s auto-refresh + +/* Stat Card */ + +function StatCard({ title, value, icon, color }: { + title: string; value: string | number; icon: React.ReactNode; color: string; +}) { return ( - - - {icon} - - {value} - {title} + + + {icon} + + {value} + {title} ); } +/* Node Status */ + function NodeStatus({ label, available }: { label: string; available: boolean }) { return ( {available ? - : - } - {label} + : } + {label} ); } +/* Empty State */ + +function EmptyOnboarding() { + const navigate = useNavigate(); + return ( + + + Welcome to ThreatHunt + + Get started by creating a hunt, uploading CSV artifacts, and letting the AI assist your investigation. + + + + + + + ); +} + +/* Main Dashboard */ + export default function Dashboard() { const [loading, setLoading] = useState(true); const [health, setHealth] = useState(null); const [huntList, setHunts] = useState([]); const [datasetList, setDatasets] = useState([]); const [hypoCount, setHypoCount] = useState(0); - const [apiInfo, setApiInfo] = useState<{ name: string; version: string; status: string } | null>(null); + const [apiInfo, setApiInfo] = useState<{ name?: string; version?: string; status?: string; service?: string } | null>(null); const [error, setError] = useState(''); + const [lastRefresh, setLastRefresh] = useState(new Date()); - useEffect(() => { - (async () => { - try { - const [h, ht, ds, hy, info] = await Promise.all([ - agent.health().catch(() => null), - hunts.list(0, 100).catch(() => ({ hunts: [], total: 0 })), - datasets.list(0, 100).catch(() => ({ datasets: [], total: 0 })), - hypotheses.list({ limit: 1 }).catch(() => ({ hypotheses: [], total: 0 })), - misc.root().catch(() => null), - ]); - setHealth(h); - setHunts(ht.hunts); - setDatasets(ds.datasets); - setHypoCount(hy.total); - setApiInfo(info); - } catch (e: any) { - setError(e.message); - } finally { - setLoading(false); - } - })(); + const refresh = useCallback(async () => { + try { + const [h, ht, ds, hy, info] = await Promise.all([ + agent.health().catch(() => null), + hunts.list(0, 100).catch(() => ({ hunts: [], total: 0 })), + datasets.list(0, 100).catch(() => ({ datasets: [], total: 0 })), + hypotheses.list({ limit: 1 }).catch(() => ({ hypotheses: [], total: 0 })), + misc.root().catch(() => null), + ]); + setHealth(h); + setHunts(ht.hunts); + setDatasets(ds.datasets); + setHypoCount(hy.total); + setApiInfo(info); + setLastRefresh(new Date()); + setError(''); + } catch (e: any) { + setError(e.message); + } finally { + setLoading(false); + } }, []); + // Initial load + useEffect(() => { refresh(); }, [refresh]); + + // Auto-refresh + useEffect(() => { + const timer = setInterval(refresh, REFRESH_INTERVAL); + return () => clearInterval(timer); + }, [refresh]); + if (loading) return ; if (error) return {error}; const activeHunts = huntList.filter(h => h.status === 'active').length; const totalRows = datasetList.reduce((s, d) => s + d.row_count, 0); + const isEmpty = huntList.length === 0 && datasetList.length === 0; return ( - Dashboard + + Dashboard + + + Updated {lastRefresh.toLocaleTimeString()} + + + + - {/* Stat cards */} - - + {/* Stat cards - symmetrical 4-column */} + + } color="#60a5fa" /> - + } color="#f472b6" /> - + } color="#10b981" /> - + } color="#f59e0b" /> - {/* Node health + API info */} - - - - LLM Cluster Health - - - - - - - - - - API Status - - - {apiInfo ? `${apiInfo.name} — ${apiInfo.version}` : 'Unreachable'} - - - Status: {apiInfo?.status ?? 'unknown'} - - - - - + {/* Empty state or content */} + {isEmpty ? ( + + ) : ( + <> + {/* Symmetrical 2-column: Cluster Health | API Status */} + + + + LLM Cluster Health + + + + + + + + + + + API Status + + + + Service + + {apiInfo?.service || apiInfo?.name || 'ThreatHunt API'} + + + + Version + + + + Status + + + + Hunts + {huntList.length} total ({activeHunts} active) + + + + + - {/* Recent hunts */} - {huntList.length > 0 && ( - - Recent Hunts - - {huntList.slice(0, 5).map(h => ( - - - {h.name} - - {h.dataset_count} datasets · {h.hypothesis_count} hypotheses - - - ))} - - + {/* Symmetrical 2-column: Recent Hunts | Recent Datasets */} + + + + Recent Hunts + + {huntList.length === 0 ? ( + No hunts yet. + ) : ( + + {huntList.slice(0, 5).map(h => ( + + + {h.name} + + {h.dataset_count}ds {h.hypothesis_count}hyp + + + ))} + + )} + + + + + Recent Datasets + + {datasetList.length === 0 ? ( + No datasets yet. + ) : ( + + {datasetList.slice(0, 5).map(d => ( + + + {d.name} + + {d.row_count.toLocaleString()} rows + + + ))} + + )} + + + + )} ); diff --git a/frontend/src/components/FileUpload.tsx b/frontend/src/components/FileUpload.tsx index 19f4c66..7dcf5cc 100644 --- a/frontend/src/components/FileUpload.tsx +++ b/frontend/src/components/FileUpload.tsx @@ -2,7 +2,7 @@ * FileUpload — multi-file drag-and-drop CSV upload with per-file progress bars. */ -import React, { useState, useCallback, useRef } from 'react'; +import React, { useState, useCallback, useRef, useEffect } from 'react'; import { Box, Typography, Paper, Stack, Chip, LinearProgress, Select, MenuItem, FormControl, InputLabel, IconButton, Tooltip, @@ -12,7 +12,7 @@ import CheckCircleIcon from '@mui/icons-material/CheckCircle'; import ErrorIcon from '@mui/icons-material/Error'; import ClearIcon from '@mui/icons-material/Clear'; import { useSnackbar } from 'notistack'; -import { datasets, hunts, type UploadResult, type Hunt } from '../api/client'; +import { datasets, hunts, type UploadResult, type Hunt, type HuntProgress } from '../api/client'; interface FileJob { file: File; @@ -28,6 +28,7 @@ export default function FileUpload() { const [jobs, setJobs] = useState([]); const [huntList, setHuntList] = useState([]); const [huntId, setHuntId] = useState(''); + const [huntProgress, setHuntProgress] = useState(null); const fileRef = useRef(null); const busyRef = useRef(false); @@ -35,6 +36,28 @@ export default function FileUpload() { hunts.list(0, 100).then(r => setHuntList(r.hunts)).catch(() => {}); }, []); + useEffect(() => { + let timer: any = null; + let cancelled = false; + + const pull = async () => { + if (!huntId) { + if (!cancelled) setHuntProgress(null); + return; + } + try { + const p = await hunts.progress(huntId); + if (!cancelled) setHuntProgress(p); + } catch { + if (!cancelled) setHuntProgress(null); + } + }; + + pull(); + if (huntId) timer = setInterval(pull, 2000); + return () => { cancelled = true; if (timer) clearInterval(timer); }; + }, [huntId, jobs.length]); + // Process the queue sequentially const processQueue = useCallback(async (queue: FileJob[]) => { if (busyRef.current) return; @@ -163,6 +186,37 @@ export default function FileUpload() { )} + {huntId && huntProgress && ( + + + + Master Processing Progress + + + + + {huntProgress.progress_percent.toFixed(1)}% + + + + + + + + + + + )} + {/* Per-file progress list */} {jobs.map((job, i) => ( diff --git a/frontend/src/components/MitreMatrix.tsx b/frontend/src/components/MitreMatrix.tsx new file mode 100644 index 0000000..46f71f9 --- /dev/null +++ b/frontend/src/components/MitreMatrix.tsx @@ -0,0 +1,189 @@ +/** + * MitreMatrix Interactive MITRE ATT&CK technique heat map. + * Aggregates detected techniques from triage, host profiles, and hypotheses. + */ +import React, { useState, useEffect, useCallback } from 'react'; +import { + Box, Typography, Paper, CircularProgress, Alert, Chip, Tooltip, + FormControl, InputLabel, Select, MenuItem, IconButton, Button, Dialog, + DialogTitle, DialogContent, List, ListItem, ListItemText, Divider, +} from '@mui/material'; +import RefreshIcon from '@mui/icons-material/Refresh'; +import DownloadIcon from '@mui/icons-material/Download'; +import { useSnackbar } from 'notistack'; +import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip as ReTooltip, ResponsiveContainer, Cell } from 'recharts'; +import { mitre, MitreCoverage, MitreTechnique, hunts, Hunt, stixExport } from '../api/client'; + +const TACTIC_COLORS: Record = { + 'Reconnaissance': '#7c3aed', + 'Resource Development': '#6d28d9', + 'Initial Access': '#ef4444', + 'Execution': '#f97316', + 'Persistence': '#f59e0b', + 'Privilege Escalation': '#eab308', + 'Defense Evasion': '#84cc16', + 'Credential Access': '#22c55e', + 'Discovery': '#14b8a6', + 'Lateral Movement': '#06b6d4', + 'Collection': '#3b82f6', + 'Command and Control': '#6366f1', + 'Exfiltration': '#a855f7', + 'Impact': '#ec4899', +}; + +export default function MitreMatrix() { + const { enqueueSnackbar } = useSnackbar(); + const [loading, setLoading] = useState(false); + const [data, setData] = useState(null); + const [huntList, setHuntList] = useState([]); + const [selectedHunt, setSelectedHunt] = useState(''); + const [detailTech, setDetailTech] = useState(null); + const [exporting, setExporting] = useState(false); + + const handleStixExport = async () => { + if (!selectedHunt) { enqueueSnackbar('Select a hunt to export STIX bundle', { variant: 'info' }); return; } + setExporting(true); + try { + await stixExport.download(selectedHunt); + enqueueSnackbar('STIX bundle downloaded', { variant: 'success' }); + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + } finally { + setExporting(false); + } + }; + + const load = useCallback(async () => { + setLoading(true); + try { + const [coverage, h] = await Promise.all([ + mitre.coverage(selectedHunt || undefined), + huntList.length ? Promise.resolve({ hunts: huntList, total: huntList.length }) : hunts.list(0, 100), + ]); + setData(coverage); + if (!huntList.length) setHuntList(h.hunts); + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + } finally { + setLoading(false); + } + }, [selectedHunt, huntList, enqueueSnackbar]); + + useEffect(() => { load(); }, [load]); + + const chartData = data ? Object.entries(data.tactic_coverage).map(([tactic, info]) => ({ + tactic: tactic.replace(/ /g, '\n'), + fullTactic: tactic, + count: info.count, + color: TACTIC_COLORS[tactic] || '#64748b', + })) : []; + + return ( + + + MITRE ATT&CK Coverage + + Filter by Hunt + + + + + {data && ( + <> + + + + )} + + + {loading && } + + {!loading && data && data.technique_count === 0 && ( + + No MITRE techniques detected yet. Run triage, host profiling, or add hypotheses with technique IDs to populate this view. + + )} + + {!loading && data && data.technique_count > 0 && ( + <> + {/* Bar chart of technique counts per tactic */} + + Techniques by Tactic + + + + + + + + {chartData.map((entry, i) => )} + + + + + + {/* Heat map grid */} + + Technique Matrix + + {data.tactics.map(tactic => { + const info = data.tactic_coverage[tactic]; + const techs = info?.techniques || []; + return ( + + + {tactic} + + + {techs.map(tech => ( + + setDetailTech(tech)} + sx={{ + fontSize: '0.65rem', height: 22, + bgcolor: tech.count >= 3 ? 'error.dark' : tech.count >= 2 ? 'warning.dark' : 'primary.dark', + cursor: 'pointer', '&:hover': { opacity: 0.8 }, + }} + /> + + ))} + {!techs.length && } + + + ); + })} + + + + )} + + {/* Detail dialog */} + setDetailTech(null)} maxWidth="sm" fullWidth> + {detailTech?.id} {detailTech?.tactic} + + Detected {detailTech?.count} time(s) from: + + {detailTech?.sources.map((s, i) => ( + + + + + {i < (detailTech?.sources.length || 0) - 1 && } + + ))} + + + + + ); +} + + diff --git a/frontend/src/components/NetworkMap.tsx b/frontend/src/components/NetworkMap.tsx index 7ede886..2b9d54e 100644 --- a/frontend/src/components/NetworkMap.tsx +++ b/frontend/src/components/NetworkMap.tsx @@ -14,6 +14,9 @@ * - Node drag with springy neighbor physics * - Glassmorphism toolbar + floating legend overlay * - Rich popover: hostname, IP, OS, users, datasets + * - MODULE-LEVEL CACHE: graph survives tab switches + * - AUTO-LOAD: picks most recent hunt on mount + * - FULL VIEWPORT canvas: fills available space * - Zero extra npm dependencies */ @@ -59,6 +62,8 @@ interface GNode { interface GEdge { source: string; target: string; weight: number } interface Graph { nodes: GNode[]; edges: GEdge[] } +type LabelMode = 'all' | 'highlight' | 'none'; + const TYPE_COLORS: Record = { host: '#60a5fa', external_ip: '#fbbf24', @@ -68,6 +73,18 @@ const GLOW_COLORS: Record = { external_ip: 'rgba(251,191,36,0.35)', }; +// ========================================================================= +// MODULE-LEVEL CACHE - survives unmount/remount on tab switches +// ========================================================================= +const graphCache = new Map(); +let lastSelectedHuntId = ''; +const LARGE_HUNT_HOST_THRESHOLD = 400; +const LARGE_HUNT_SUBGRAPH_HOSTS = 220; +const LARGE_HUNT_SUBGRAPH_EDGES = 1200; +const RENDER_SIMPLIFY_NODE_THRESHOLD = 120; +const RENDER_SIMPLIFY_EDGE_THRESHOLD = 500; +const EDGE_DRAW_TARGET = 600; + // == Build graph from inventory ========================================== function buildGraphFromInventory( @@ -75,53 +92,67 @@ function buildGraphFromInventory( canvasW: number, canvasH: number, ): Graph { const nodeMap = new Map(); + const cx = canvasW / 2, cy = canvasH / 2; + const MAX_EXTERNAL_NODES = 30; - // Create host nodes for (const h of hosts) { const r = Math.max(8, Math.min(26, 6 + Math.sqrt(h.row_count / 100) * 3)); nodeMap.set(h.id, { id: h.id, label: h.hostname || h.fqdn || h.client_id, - x: canvasW / 2 + (Math.random() - 0.5) * canvasW * 0.75, - y: canvasH / 2 + (Math.random() - 0.5) * canvasH * 0.65, + x: cx + (Math.random() - 0.5) * canvasW * 0.75, + y: cy + (Math.random() - 0.5) * canvasH * 0.65, vx: 0, vy: 0, radius: r, color: TYPE_COLORS.host, count: h.row_count, meta: { type: 'host' as NodeType, - hostname: h.hostname, - fqdn: h.fqdn, - client_id: h.client_id, - ips: h.ips, - os: h.os, - users: h.users, - datasets: h.datasets, - row_count: h.row_count, + hostname: h.hostname, fqdn: h.fqdn, client_id: h.client_id, + ips: h.ips, os: h.os, users: h.users, + datasets: h.datasets, row_count: h.row_count, }, }); } - // Create edges + external IP nodes (for unresolved remote IPs) - const edges: GEdge[] = []; + const extIpCounts = new Map(); + const extIpLabel = new Map(); for (const c of connections) { if (!nodeMap.has(c.target)) { - nodeMap.set(c.target, { - id: c.target, - label: c.target_ip || c.target, - x: canvasW / 2 + (Math.random() - 0.5) * canvasW * 0.75, - y: canvasH / 2 + (Math.random() - 0.5) * canvasH * 0.65, - vx: 0, vy: 0, radius: 6, - color: TYPE_COLORS.external_ip, - count: c.count, - meta: { - type: 'external_ip' as NodeType, - hostname: '', fqdn: '', client_id: '', - ips: [c.target_ip || c.target], - os: '', users: [], datasets: [], row_count: 0, - }, - }); + extIpCounts.set(c.target, (extIpCounts.get(c.target) || 0) + c.count); + if (!extIpLabel.has(c.target)) extIpLabel.set(c.target, c.target_ip || c.target); + } + } + const topExternal = new Set( + [...extIpCounts.entries()] + .sort((a, b) => b[1] - a[1]) + .slice(0, MAX_EXTERNAL_NODES) + .map(e => e[0]) + ); + + for (const [id, totalCount] of extIpCounts) { + if (!topExternal.has(id)) continue; + nodeMap.set(id, { + id, + label: extIpLabel.get(id) || id, + x: cx + (Math.random() - 0.5) * canvasW * 0.75, + y: cy + (Math.random() - 0.5) * canvasH * 0.65, + vx: 0, vy: 0, radius: 6, + color: TYPE_COLORS.external_ip, + count: totalCount, + meta: { + type: 'external_ip' as NodeType, + hostname: '', fqdn: '', client_id: '', + ips: [extIpLabel.get(id) || id], + os: '', users: [], datasets: [], row_count: 0, + }, + }); + } + + const edges: GEdge[] = []; + for (const c of connections) { + if (nodeMap.has(c.source) && nodeMap.has(c.target)) { + edges.push({ source: c.source, target: c.target, weight: c.count }); } - edges.push({ source: c.source, target: c.target, weight: c.count }); } return { nodes: [...nodeMap.values()], edges }; @@ -135,17 +166,39 @@ function simulationStep(graph: Graph, cx: number, cy: number, alpha: number) { const k = 120; const repulsion = 12000; const damping = 0.82; + const N = nodes.length; - for (let i = 0; i < nodes.length; i++) { - for (let j = i + 1; j < nodes.length; j++) { - const a = nodes[i], b = nodes[j]; - if (a.pinned && b.pinned) continue; - const dx = b.x - a.x, dy = b.y - a.y; - const dist = Math.max(1, Math.sqrt(dx * dx + dy * dy)); - const force = (repulsion * alpha) / (dist * dist); - const fx = (dx / dist) * force, fy = (dy / dist) * force; - if (!a.pinned) { a.vx -= fx; a.vy -= fy; } - if (!b.pinned) { b.vx += fx; b.vy += fy; } + const SAMPLE_THRESHOLD = 150; + if (N > SAMPLE_THRESHOLD) { + const sampleSize = Math.min(40, Math.ceil(N * 0.15)); + const scaleFactor = N / sampleSize; + for (let i = 0; i < N; i++) { + const a = nodes[i]; + if (a.pinned) continue; + for (let s = 0; s < sampleSize; s++) { + const j = Math.floor(Math.random() * N); + if (j === i) continue; + const b = nodes[j]; + const dx = b.x - a.x, dy = b.y - a.y; + const dist = Math.max(1, Math.sqrt(dx * dx + dy * dy)); + if (dist > 600) continue; + const force = (repulsion * alpha * scaleFactor) / (dist * dist); + const fx = (dx / dist) * force, fy = (dy / dist) * force; + a.vx -= fx; a.vy -= fy; + } + } + } else { + for (let i = 0; i < N; i++) { + for (let j = i + 1; j < N; j++) { + const a = nodes[i], b = nodes[j]; + if (a.pinned && b.pinned) continue; + const dx = b.x - a.x, dy = b.y - a.y; + const dist = Math.max(1, Math.sqrt(dx * dx + dy * dy)); + const force = (repulsion * alpha) / (dist * dist); + const fx = (dx / dist) * force, fy = (dy / dist) * force; + if (!a.pinned) { a.vx -= fx; a.vy -= fy; } + if (!b.pinned) { b.vx += fx; b.vy += fy; } + } } } for (const e of edges) { @@ -188,23 +241,26 @@ const GRID_SPACING = 32; function drawBackground( ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number, + simplify: boolean, ) { ctx.fillStyle = BG_COLOR; ctx.fillRect(0, 0, w, h); - ctx.save(); - ctx.translate(vp.x * dpr, vp.y * dpr); - ctx.scale(vp.scale * dpr, vp.scale * dpr); - const startX = -vp.x / vp.scale - GRID_SPACING; - const startY = -vp.y / vp.scale - GRID_SPACING; - const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2; - const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2; - ctx.fillStyle = GRID_DOT_COLOR; - for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) { - for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) { - ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill(); + if (!simplify) { + ctx.save(); + ctx.translate(vp.x * dpr, vp.y * dpr); + ctx.scale(vp.scale * dpr, vp.scale * dpr); + const startX = -vp.x / vp.scale - GRID_SPACING; + const startY = -vp.y / vp.scale - GRID_SPACING; + const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2; + const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2; + ctx.fillStyle = GRID_DOT_COLOR; + for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) { + for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) { + ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill(); + } } + ctx.restore(); } - ctx.restore(); const vignette = ctx.createRadialGradient(w / 2, h / 2, w * 0.2, w / 2, h / 2, w * 0.7); vignette.addColorStop(0, 'rgba(10,16,30,0)'); vignette.addColorStop(1, 'rgba(10,16,30,0.55)'); @@ -216,8 +272,11 @@ function drawEdges( ctx: CanvasRenderingContext2D, graph: Graph, hovered: string | null, selected: string | null, nodeMap: Map, animTime: number, + simplify: boolean, ) { - for (const e of graph.edges) { + const edgeStep = simplify ? Math.max(1, Math.ceil(graph.edges.length / EDGE_DRAW_TARGET)) : 1; + for (let ei = 0; ei < graph.edges.length; ei += edgeStep) { + const e = graph.edges[ei]; const a = nodeMap.get(e.source), b = nodeMap.get(e.target); if (!a || !b) continue; const isActive = (hovered && (e.source === hovered || e.target === hovered)) @@ -229,7 +288,7 @@ function drawEdges( const cpx = mx + (-dy / (len || 1)) * perpScale; const cpy = my + (dx / (len || 1)) * perpScale; - ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); + ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); } if (isActive) { ctx.strokeStyle = 'rgba(96,165,250,0.8)'; @@ -240,11 +299,11 @@ function drawEdges( ctx.shadowColor = 'rgba(96,165,250,0.5)'; ctx.shadowBlur = 8; ctx.strokeStyle = 'rgba(96,165,250,0.3)'; ctx.lineWidth = Math.min(5, 2 + e.weight * 0.2); - ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); + ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); } ctx.stroke(); ctx.restore(); } else { const alpha = Math.min(0.35, 0.08 + e.weight * 0.01); - ctx.strokeStyle = `rgba(100,116,139,${alpha})`; + ctx.strokeStyle = 'rgba(100,116,139,' + alpha + ')'; ctx.lineWidth = Math.min(2.5, 0.4 + e.weight * 0.08); ctx.stroke(); } @@ -302,10 +361,16 @@ function drawLabels( ctx: CanvasRenderingContext2D, graph: Graph, hovered: string | null, selected: string | null, search: string, matchSet: Set, vp: Viewport, + simplify: boolean, labelMode: LabelMode, ) { + if (labelMode === 'none') return; const dimmed = search.length > 0; + if (labelMode === 'highlight' && !search && !hovered && !selected) return; + if (simplify && labelMode !== 'all' && !search && !hovered && !selected) { + return; + } const fontSize = Math.max(9, Math.round(12 / vp.scale)); - ctx.font = `500 ${fontSize}px Inter, system-ui, sans-serif`; + ctx.font = '500 ' + fontSize + 'px Inter, system-ui, sans-serif'; ctx.textAlign = 'center'; ctx.textBaseline = 'bottom'; @@ -318,13 +383,13 @@ function drawLabels( for (const n of sorted) { const isHighlight = hovered === n.id || selected === n.id || matchSet.has(n.id); - // Always show labels for hosts (since they're deduped and fewer) - const show = isHighlight || n.meta.type === 'host' || n.count >= 2; + const show = labelMode === 'all' + ? (isHighlight || n.meta.type === 'host' || n.count >= 2) + : isHighlight; if (!show) continue; const isDim = dimmed && !matchSet.has(n.id); if (isDim) continue; - // Two-line label: hostname + IP (if available) const line1 = n.label; const line2 = n.meta.ips.length > 0 ? n.meta.ips[0] : ''; const tw = Math.max(ctx.measureText(line1).width, line2 ? ctx.measureText(line2).width : 0); @@ -355,11 +420,9 @@ function drawLabels( ctx.lineWidth = 0.8; ctx.stroke(); ctx.restore(); - // Hostname line ctx.fillStyle = isHighlight ? '#ffffff' : n.color; ctx.globalAlpha = isHighlight ? 1 : 0.85; ctx.fillText(line1, lx, ly - (line2 ? fontSize * 0.5 : 0)); - // IP line (smaller, dimmer) if (line2) { ctx.fillStyle = 'rgba(148,163,184,0.6)'; ctx.fillText(line2, lx, ly + fontSize * 0.5); @@ -371,7 +434,7 @@ function drawLabels( function drawGraph( ctx: CanvasRenderingContext2D, graph: Graph, hovered: string | null, selected: string | null, search: string, - vp: Viewport, animTime: number, dpr: number, + vp: Viewport, animTime: number, dpr: number, labelMode: LabelMode, ) { const w = ctx.canvas.width, h = ctx.canvas.height; const nodeMap = new Map(graph.nodes.map(n => [n.id, n])); @@ -386,18 +449,36 @@ function drawGraph( ) matchSet.add(n.id); } } - drawBackground(ctx, w, h, vp, dpr); + const simplify = graph.nodes.length > RENDER_SIMPLIFY_NODE_THRESHOLD || graph.edges.length > RENDER_SIMPLIFY_EDGE_THRESHOLD; + drawBackground(ctx, w, h, vp, dpr, simplify); ctx.save(); ctx.translate(vp.x * dpr, vp.y * dpr); ctx.scale(vp.scale * dpr, vp.scale * dpr); - drawEdges(ctx, graph, hovered, selected, nodeMap, animTime); + drawEdges(ctx, graph, hovered, selected, nodeMap, animTime, simplify); drawNodes(ctx, graph, hovered, selected, search, matchSet); - drawLabels(ctx, graph, hovered, selected, search, matchSet, vp); + drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify, labelMode); ctx.restore(); } // == Hit-test ============================================================= +function isPointOnNodeLabel(node: GNode, wx: number, wy: number, vp: Viewport): boolean { + const fontSize = Math.max(9, Math.round(12 / vp.scale)); + const approxCharW = Math.max(5, fontSize * 0.58); + const line1 = node.label || ''; + const line2 = node.meta.ips.length > 0 ? node.meta.ips[0] : ''; + const tw = Math.max(line1.length * approxCharW, line2 ? line2.length * approxCharW : 0); + const px = 5, py = 2; + const totalH = line2 ? fontSize * 2 + py * 2 : fontSize + py * 2; + const lx = node.x, ly = node.y - node.radius - 6; + const rx = lx - tw / 2 - px; + const ry = ly - totalH; + const rw = tw + px * 2; + const rh = totalH; + return wx >= rx && wx <= (rx + rw) && wy >= ry && wy <= (ry + rh); +} + + function screenToWorld( canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport, ): { wx: number; wy: number } { @@ -409,19 +490,56 @@ function hitTest( graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport, ): GNode | null { const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp); + + // Node-circle hit has priority for (const n of graph.nodes) { const dx = n.x - wx, dy = n.y - wy; if (dx * dx + dy * dy < (n.radius + 5) ** 2) return n; } + + // Then label hit (so clicking text works too on manageable graph sizes) + if (graph.nodes.length <= 220) { + for (const n of graph.nodes) { + if (isPointOnNodeLabel(n, wx, wy, vp)) return n; + } + } + return null; } + +// == Auto-fit: center graph and zoom to fit all nodes ==================== + +function fitGraphToCanvas(graph: Graph, canvasW: number, canvasH: number): Viewport { + if (graph.nodes.length === 0) return { x: 0, y: 0, scale: 1 }; + let minX = Infinity, minY = Infinity, maxX = -Infinity, maxY = -Infinity; + for (const n of graph.nodes) { + minX = Math.min(minX, n.x - n.radius); + minY = Math.min(minY, n.y - n.radius); + maxX = Math.max(maxX, n.x + n.radius); + maxY = Math.max(maxY, n.y + n.radius); + } + const graphW = maxX - minX || 1; + const graphH = maxY - minY || 1; + const pad = 80; + const scaleX = (canvasW - pad * 2) / graphW; + const scaleY = (canvasH - pad * 2) / graphH; + const scale = Math.min(scaleX, scaleY, 2.5); + const cx = (minX + maxX) / 2; + const cy = (minY + maxY) / 2; + return { + x: canvasW / 2 - cx * scale, + y: canvasH / 2 - cy * scale, + scale, + }; +} + // == Component ============================================================= export default function NetworkMap() { const theme = useTheme(); const [huntList, setHuntList] = useState([]); - const [selectedHuntId, setSelectedHuntId] = useState(''); + const [selectedHuntId, setSelectedHuntId] = useState(lastSelectedHuntId); const [loading, setLoading] = useState(false); const [progress, setProgress] = useState(''); @@ -431,10 +549,14 @@ export default function NetworkMap() { const [hovered, setHovered] = useState(null); const [selectedNode, setSelectedNode] = useState(null); const [search, setSearch] = useState(''); + const [labelMode, setLabelMode] = useState('highlight'); const canvasRef = useRef(null); const wrapperRef = useRef(null); - const [canvasSize, setCanvasSize] = useState({ w: 900, h: 600 }); + const [canvasSize, setCanvasSize] = useState({ w: 1200, h: 800 }); + + // Ref mirror of canvasSize - lets loadGraph read current size without depending on it + const canvasSizeRef = useRef({ w: 1200, h: 800 }); const vpRef = useRef({ x: 0, y: 0, scale: 1 }); const [vpScale, setVpScale] = useState(1); @@ -450,31 +572,118 @@ export default function NetworkMap() { const selectedNodeRef = useRef(null); const searchRef = useRef(''); const graphRef = useRef(null); + const hoverRafRef = useRef(0); const [popoverAnchor, setPopoverAnchor] = useState<{ top: number; left: number } | null>(null); useEffect(() => { hoveredRef.current = hovered; }, [hovered]); useEffect(() => { selectedNodeRef.current = selectedNode; }, [selectedNode]); useEffect(() => { searchRef.current = search; }, [search]); + useEffect(() => { canvasSizeRef.current = canvasSize; }, [canvasSize, labelMode]); - // Load hunts on mount - useEffect(() => { - hunts.list(0, 200).then(r => setHuntList(r.hunts)).catch(() => {}); + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + + const loadScaleAwareGraph = useCallback(async (huntId: string, forceRefresh = false) => { + setLoading(true); setError(''); setGraph(null); setStats(null); + setSelectedNode(null); setPopoverAnchor(null); + + const waitReadyThen = async (fn: () => Promise): Promise => { + let delayMs = 1500; + const startedAt = Date.now(); + for (;;) { + const out: any = await fn(); + if (out && !out.status) return out as T; + const st = await network.inventoryStatus(huntId); + if (st.status === 'ready') { + const out2: any = await fn(); + if (out2 && !out2.status) return out2 as T; + } + if (Date.now() - startedAt > 5 * 60 * 1000) throw new Error('Network data build timed out after 5 minutes'); + const jitter = Math.floor(Math.random() * 250); + await sleep(delayMs + jitter); + delayMs = Math.min(10000, Math.floor(delayMs * 1.5)); + } + }; + + try { + setProgress('Loading network summary'); + const summary: any = await waitReadyThen(() => network.summary(huntId, 20)); + const totalHosts = summary?.stats?.total_hosts || 0; + + if (totalHosts > LARGE_HUNT_HOST_THRESHOLD) { + setProgress(`Large hunt detected (${totalHosts} hosts). Loading focused subgraph`); + const sub: any = await waitReadyThen(() => network.subgraph(huntId, LARGE_HUNT_SUBGRAPH_HOSTS, LARGE_HUNT_SUBGRAPH_EDGES)); + if (!sub?.hosts || sub.hosts.length === 0) { + setError('No hosts found for subgraph.'); + return; + } + const { w, h } = canvasSizeRef.current; + const g = buildGraphFromInventory(sub.hosts, sub.connections || [], w, h); + simulate(g, w / 2, h / 2, 20); + simAlphaRef.current = 0.3; + setStats(summary.stats); + graphCache.set(huntId, { graph: g, stats: summary.stats, ts: Date.now() }); + setGraph(g); + return; + } + + // Small/medium hunts: load full inventory + setProgress('Loading host inventory'); + const inv: any = await waitReadyThen(() => network.hostInventory(huntId, forceRefresh)); + if (!inv?.hosts || inv.hosts.length === 0) { + setError('No hosts found. Upload CSV files with host-identifying columns (ClientId, Fqdn, Hostname) to this hunt.'); + return; + } + const { w, h } = canvasSizeRef.current; + const g = buildGraphFromInventory(inv.hosts, inv.connections || [], w, h); + simulate(g, w / 2, h / 2, 30); + simAlphaRef.current = 0.3; + setStats(summary.stats || inv.stats); + graphCache.set(huntId, { graph: g, stats: summary.stats || inv.stats, ts: Date.now() }); + setGraph(g); + } catch (e: any) { + console.error('[NetworkMap] scale-aware load error:', e); + setError(e.message || 'Failed to load network data'); + } finally { + setLoading(false); + setProgress(''); + } }, []); - // Resize observer + // Persist selected hunt across tab switches + useEffect(() => { lastSelectedHuntId = selectedHuntId; }, [selectedHuntId]); + + // Load hunts on mount + auto-select + useEffect(() => { + hunts.list(0, 200).then(r => { + setHuntList(r.hunts); + // Auto-select: restore last hunt, or pick first with datasets + if (!selectedHuntId && r.hunts.length > 0) { + const best = r.hunts.find(h => h.dataset_count > 0) || r.hunts[0]; + setSelectedHuntId(best.id); + } + }).catch(() => {}); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + // Resize observer - FILL available viewport useEffect(() => { const el = wrapperRef.current; if (!el) return; - const ro = new ResizeObserver(entries => { - for (const entry of entries) { - const w = Math.round(entry.contentRect.width); - if (w > 100) setCanvasSize({ w, h: Math.max(500, Math.round(w * 0.56)) }); - } - }); + const updateSize = () => { + const rect = el.getBoundingClientRect(); + const w = Math.round(rect.width); + // Fill to bottom of viewport with 16px margin + const h = Math.max(500, Math.round(window.innerHeight - rect.top - 16)); + if (w > 100) setCanvasSize({ w, h }); + }; + updateSize(); + const ro = new ResizeObserver(updateSize); ro.observe(el); - return () => ro.disconnect(); - }, []); + window.addEventListener('resize', updateSize); + return () => { ro.disconnect(); window.removeEventListener('resize', updateSize); }; + // Re-run when graph or loading changes so we catch the element appearing + }, [graph, loading]); // HiDPI canvas sizing useEffect(() => { @@ -485,44 +694,120 @@ export default function NetworkMap() { canvas.height = canvasSize.h * dpr; canvas.style.width = canvasSize.w + 'px'; canvas.style.height = canvasSize.h + 'px'; - }, [canvasSize]); + }, [canvasSize, labelMode]); - // Load host inventory for selected hunt - const loadGraph = useCallback(async (huntId: string) => { + // Load graph data for selected hunt (delegates to scale-aware loader). + const loadGraph = useCallback(async (huntId: string, forceRefresh = false) => { if (!huntId) return; - setLoading(true); setError(''); setGraph(null); setStats(null); - setSelectedNode(null); setPopoverAnchor(null); - try { - setProgress('Building host inventory (scanning all datasets)\u2026'); - const inv = await network.hostInventory(huntId); - setStats(inv.stats); - if (inv.hosts.length === 0) { - setError('No hosts found. Upload CSV files with host-identifying columns (ClientId, Fqdn, Hostname) to this hunt.'); - setLoading(false); setProgress(''); + // Check module-level cache first (5 min TTL) + if (!forceRefresh) { + const cached = graphCache.get(huntId); + if (cached && Date.now() - cached.ts < 5 * 60 * 1000) { + setGraph(cached.graph); + setStats(cached.stats); + setError(''); + simAlphaRef.current = 0; + return; + } + } + + await loadScaleAwareGraph(huntId, forceRefresh); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); // Stable - reads canvasSizeRef, no state deps + + // Single master effect: when hunt changes, check backend status, poll if building, then load + useEffect(() => { + if (!selectedHuntId) return; + let cancelled = false; + + const waitUntilReady = async (): Promise => { + // Poll inventory-status with exponential backoff until 'ready' (or cancelled) + setProgress('Host inventory is being prepared in the background'); + setLoading(true); + let delayMs = 1500; + const startedAt = Date.now(); + for (;;) { + const jitter = Math.floor(Math.random() * 250); + await sleep(delayMs + jitter); + if (cancelled) return false; + try { + const st = await network.inventoryStatus(selectedHuntId); + if (cancelled) return false; + if (st.status === 'ready') return true; + if (Date.now() - startedAt > 5 * 60 * 1000) { + setError('Host inventory build timed out. Please retry.'); + return false; + } + delayMs = Math.min(10000, Math.floor(delayMs * 1.5)); + // still building or none (job may not have started yet) - keep polling + } catch { + if (cancelled) return false; + delayMs = Math.min(10000, Math.floor(delayMs * 1.5)); + } + } + }; + + const run = async () => { + // Check module-level JS cache first (instant) + const cached = graphCache.get(selectedHuntId); + if (cached && Date.now() - cached.ts < 5 * 60 * 1000) { + setGraph(cached.graph); + setStats(cached.stats); + setError(''); + simAlphaRef.current = 0; return; } - setProgress(`Building graph for ${inv.stats.total_hosts} hosts\u2026`); - const g = buildGraphFromInventory(inv.hosts, inv.connections, canvasSize.w, canvasSize.h); - simulate(g, canvasSize.w / 2, canvasSize.h / 2, 30); - simAlphaRef.current = 0.8; - setGraph(g); - } catch (e: any) { setError(e.message); } - setLoading(false); setProgress(''); - }, [canvasSize]); + try { + // Ask backend if inventory is ready, building, or cold + const st = await network.inventoryStatus(selectedHuntId); + if (cancelled) return; - useEffect(() => { - if (selectedHuntId) loadGraph(selectedHuntId); + if (st.status === 'ready') { + // Instant load from backend cache + await loadGraph(selectedHuntId); + return; + } + + if (st.status === 'none') { + // Cold cache: trigger a background build via host-inventory (returns 202) + try { await network.hostInventory(selectedHuntId); } catch { /* 202 or error, don't care */ } + } + + // Wait for build to finish (covers both 'building' and 'none' -> just triggered) + const ready = await waitUntilReady(); + if (cancelled || !ready) return; + + // Now load the freshly cached data + await loadGraph(selectedHuntId); + } catch (e: any) { + if (!cancelled) { + console.error('[NetworkMap] status/load error:', e); + setError(e.message || 'Failed to load network inventory'); + setLoading(false); + setProgress(''); + } + } + }; + + run(); + return () => { cancelled = true; }; }, [selectedHuntId, loadGraph]); + // Auto-fit viewport when graph loads useEffect(() => { - vpRef.current = { x: 0, y: 0, scale: 1 }; - setVpScale(1); + if (graph) { + const vp = fitGraphToCanvas(graph, canvasSize.w, canvasSize.h); + vpRef.current = vp; + setVpScale(vp.scale); + } + // eslint-disable-next-line react-hooks/exhaustive-deps }, [graph]); useEffect(() => { graphRef.current = graph; }, [graph]); + // Animation loop const startAnimLoop = useCallback(() => { if (isAnimatingRef.current) return; @@ -538,10 +823,10 @@ export default function NetworkMap() { if (simAlphaRef.current > 0.01) { simulationStep(g, canvasSize.w / 2, canvasSize.h / 2, simAlphaRef.current); - simAlphaRef.current *= 0.97; + simAlphaRef.current *= 0.93; if (simAlphaRef.current < 0.01) simAlphaRef.current = 0; } - drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? null, searchRef.current, vpRef.current, ts, dpr); + drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? null, searchRef.current, vpRef.current, ts, dpr, labelMode); const needsAnim = simAlphaRef.current > 0.01 || hoveredRef.current !== null @@ -554,24 +839,28 @@ export default function NetworkMap() { } }; animFrameRef.current = requestAnimationFrame(tick); - }, [canvasSize]); + }, [canvasSize, labelMode]); useEffect(() => { if (graph) startAnimLoop(); - return () => { cancelAnimationFrame(animFrameRef.current); isAnimatingRef.current = false; }; + return () => { + cancelAnimationFrame(animFrameRef.current); + cancelAnimationFrame(hoverRafRef.current); + isAnimatingRef.current = false; + }; }, [graph, startAnimLoop]); - useEffect(() => { startAnimLoop(); }, [hovered, selectedNode, search, startAnimLoop]); - const redraw = useCallback(() => { if (!graph || !canvasRef.current) return; const ctx = canvasRef.current.getContext('2d'); const dpr = window.devicePixelRatio || 1; - if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? null, search, vpRef.current, animTimeRef.current, dpr); - }, [graph, hovered, selectedNode, search]); + if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? null, search, vpRef.current, animTimeRef.current, dpr, labelMode); + }, [graph, hovered, selectedNode, search, labelMode]); useEffect(() => { if (!isAnimatingRef.current) redraw(); }, [redraw]); + useEffect(() => { if (!isAnimatingRef.current) redraw(); }, [hovered, selectedNode, search, redraw]); + // Mouse wheel -> zoom useEffect(() => { const canvas = canvasRef.current; @@ -617,8 +906,13 @@ export default function NetworkMap() { panStart.current = { x: e.clientX, y: e.clientY }; redraw(); return; } - const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current); - setHovered(node?.id ?? null); + cancelAnimationFrame(hoverRafRef.current); + const clientX = e.clientX; + const clientY = e.clientY; + hoverRafRef.current = requestAnimationFrame(() => { + const node = hitTest(graph, canvasRef.current as HTMLCanvasElement, clientX, clientY, vpRef.current); + setHovered(prev => (prev === (node?.id ?? null) ? prev : (node?.id ?? null))); + }); }, [graph, redraw, startAnimLoop]); const onMouseUp = useCallback(() => { dragNode.current = null; isPanning.current = false; }, []); @@ -649,28 +943,37 @@ export default function NetworkMap() { vp.scale = newScale; setVpScale(newScale); redraw(); }, [canvasSize, redraw]); - const resetView = useCallback(() => { - vpRef.current = { x: 0, y: 0, scale: 1 }; setVpScale(1); redraw(); - }, [redraw]); + const fitView = useCallback(() => { + if (!graph) return; + const vp = fitGraphToCanvas(graph, canvasSize.w, canvasSize.h); + vpRef.current = vp; setVpScale(vp.scale); redraw(); + }, [graph, canvasSize, redraw]); const connectionCount = selectedNode && graph ? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length : 0; + const nodeById = useMemo(() => { + const m = new Map(); + if (!graph) return m; + for (const n of graph.nodes) m.set(n.id, n); + return m; + }, [graph]); + const connectedNodes = useMemo(() => { if (!selectedNode || !graph) return []; const neighbors: { id: string; type: NodeType; weight: number }[] = []; for (const e of graph.edges) { if (e.source === selectedNode.id) { - const n = graph.nodes.find(x => x.id === e.target); + const n = nodeById.get(e.target); if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight }); } else if (e.target === selectedNode.id) { - const n = graph.nodes.find(x => x.id === e.source); + const n = nodeById.get(e.source); if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight }); } } return neighbors.sort((a, b) => b.weight - a.weight).slice(0, 12); - }, [selectedNode, graph]); + }, [selectedNode, graph, nodeById]); const hostCount = graph ? graph.nodes.filter(n => n.meta.type === 'host').length : 0; const extCount = graph ? graph.nodes.filter(n => n.meta.type === 'external_ip').length : 0; @@ -681,16 +984,17 @@ export default function NetworkMap() { if (hovered) return 'pointer'; return 'grab'; }; + // == Render ============================================================== return ( - + {/* Glassmorphism toolbar */} @@ -702,7 +1006,7 @@ export default function NetworkMap() { - + Hunt setLabelMode(e.target.value as LabelMode)} + sx={{ '& .MuiSelect-select': { py: 0.8 } }} + > + None + Selected/Search + All + + + + loadGraph(selectedHuntId)} + onClick={() => loadGraph(selectedHuntId, true)} disabled={loading || !selectedHuntId} size="small" sx={{ bgcolor: 'rgba(96,165,250,0.1)', '&:hover': { bgcolor: 'rgba(96,165,250,0.2)' } }} @@ -748,7 +1067,7 @@ export default function NetworkMap() { {/* Stats summary cards */} {stats && !loading && ( - + {[ { label: 'Hosts', value: stats.total_hosts, color: TYPE_COLORS.host }, { label: 'With IPs', value: stats.hosts_with_ips, color: '#34d399' }, @@ -773,14 +1092,14 @@ export default function NetworkMap() { {/* Loading indicator */} - + {progress} @@ -788,21 +1107,22 @@ export default function NetworkMap() { - {error && {error}} + {error && {error}} - {/* Canvas area */} + {/* Canvas area - takes ALL remaining space */} + {graph && ( } - label={`Hosts (${hostCount})`} + label={'Hosts (' + hostCount + ')'} size="small" sx={{ bgcolor: TYPE_COLORS.host + '22', color: TYPE_COLORS.host, - border: `1.5px solid ${TYPE_COLORS.host}88`, + border: '1.5px solid ' + TYPE_COLORS.host + '88', fontWeight: 600, fontSize: 11, }} /> {extCount > 0 && ( @@ -868,7 +1188,7 @@ export default function NetworkMap() { {[ { tip: 'Zoom in', icon: , fn: () => zoomBy(1.3) }, { tip: 'Zoom out', icon: , fn: () => zoomBy(1 / 1.3) }, - { tip: 'Reset view', icon: , fn: resetView }, + { tip: 'Fit to view', icon: , fn: fitView }, ].map(z => ( ))} )} + {/* Node detail popover */} {selectedNode && ( - + @@ -932,7 +1253,7 @@ export default function NetworkMap() { bgcolor: TYPE_COLORS[selectedNode.meta.type] + '22', color: TYPE_COLORS[selectedNode.meta.type], fontWeight: 700, fontSize: 10, height: 22, - border: `1px solid ${TYPE_COLORS[selectedNode.meta.type]}44`, + border: '1px solid ' + TYPE_COLORS[selectedNode.meta.type] + '44', }} /> @@ -958,7 +1279,7 @@ export default function NetworkMap() { IP Address - {selectedNode.meta.ips.length > 0 ? selectedNode.meta.ips.join(', ') : No IP detected} + {selectedNode.meta.ips.length > 0 ? selectedNode.meta.ips.join(', ') : 'No IP detected'} @@ -967,7 +1288,7 @@ export default function NetworkMap() { Operating System - {selectedNode.meta.os || Unknown} + {selectedNode.meta.os || 'Unknown'} @@ -986,9 +1307,7 @@ export default function NetworkMap() { ))} ) : ( - - No user data - + No user data )} @@ -1007,11 +1326,11 @@ export default function NetworkMap() { - - - @@ -1026,7 +1345,7 @@ export default function NetworkMap() { sx={{ fontSize: 10, height: 22, fontFamily: 'monospace', bgcolor: TYPE_COLORS[cn.type] + '15', color: TYPE_COLORS[cn.type], - border: `1px solid ${TYPE_COLORS[cn.type]}33`, cursor: 'pointer', + border: '1px solid ' + TYPE_COLORS[cn.type] + '33', cursor: 'pointer', '&:hover': { bgcolor: TYPE_COLORS[cn.type] + '30' }, }} onClick={() => { setSearch(cn.id); closePopover(); }} @@ -1053,27 +1372,29 @@ export default function NetworkMap() { )} - {/* Empty states */} + {/* Empty states - also fill remaining space */} {!selectedHuntId && !loading && ( - - - - Select a hunt to visualize its network - - - Choose a hunt from the dropdown above. The map builds a clean, - deduplicated host inventory showing each endpoint with its hostname, - IP address, OS, and logged-in users. - + + + + Select a hunt to visualize its network + + + Choose a hunt from the dropdown above. The map builds a clean, + deduplicated host inventory showing each endpoint with its hostname, + IP address, OS, and logged-in users. + + )} {selectedHuntId && !graph && !loading && !error && ( @@ -1081,6 +1402,7 @@ export default function NetworkMap() { )} + ); } \ No newline at end of file diff --git a/frontend/src/components/PlaybookManager.tsx b/frontend/src/components/PlaybookManager.tsx new file mode 100644 index 0000000..4ef7b47 --- /dev/null +++ b/frontend/src/components/PlaybookManager.tsx @@ -0,0 +1,237 @@ +/** + * PlaybookManager - Investigation playbook workflow wizard. + * Create/load playbooks from templates, track step completion, navigate to target views. + */ +import React, { useState, useEffect, useCallback } from 'react'; +import { + Box, Typography, Paper, CircularProgress, Alert, Button, Chip, + List, ListItem, ListItemButton, ListItemIcon, ListItemText, + Checkbox, Dialog, DialogTitle, DialogContent, DialogActions, + TextField, LinearProgress, IconButton, Divider, Tooltip, +} from '@mui/material'; +import AddIcon from '@mui/icons-material/Add'; +import DeleteIcon from '@mui/icons-material/Delete'; +import PlaylistAddCheckIcon from '@mui/icons-material/PlaylistAddCheck'; +import OpenInNewIcon from '@mui/icons-material/OpenInNew'; +import { useSnackbar } from 'notistack'; +import { + playbooks, PlaybookSummary, PlaybookDetail, PlaybookTemplate, +} from '../api/client'; + +export default function PlaybookManager() { + const { enqueueSnackbar } = useSnackbar(); + const [loading, setLoading] = useState(false); + const [pbList, setPbList] = useState([]); + const [active, setActive] = useState(null); + const [templates, setTemplates] = useState([]); + const [showCreate, setShowCreate] = useState(false); + const [newName, setNewName] = useState(''); + const [newDesc, setNewDesc] = useState(''); + + const loadList = useCallback(async () => { + setLoading(true); + try { + const data = await playbooks.list(); + setPbList(data.playbooks); + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + } finally { + setLoading(false); + } + }, [enqueueSnackbar]); + + const loadTemplates = useCallback(async () => { + try { + const data = await playbooks.templates(); + setTemplates(data.templates); + } catch {} + }, []); + + useEffect(() => { loadList(); loadTemplates(); }, [loadList, loadTemplates]); + + const selectPlaybook = async (id: string) => { + try { + const d = await playbooks.get(id); + setActive(d); + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + } + }; + + const toggleStep = async (stepId: number, current: boolean) => { + if (!active) return; + try { + await playbooks.updateStep(stepId, { is_completed: !current }); + const d = await playbooks.get(active.id); + setActive(d); + loadList(); + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + } + }; + + const createFromTemplate = async (tpl: PlaybookTemplate) => { + try { + const pb = await playbooks.create({ + name: tpl.name, + description: tpl.description, + steps: tpl.steps.map((s, i) => ({ + title: s.title, + description: s.description, + step_type: 'task', + target_route: s.target_route || undefined, + })), + }); + enqueueSnackbar('Playbook created from template', { variant: 'success' }); + loadList(); + setActive(pb); + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + } + }; + + const createCustom = async () => { + if (!newName.trim()) return; + try { + const pb = await playbooks.create({ + name: newName, + description: newDesc, + steps: [{ title: 'First step', description: 'Describe what to do' }], + }); + enqueueSnackbar('Playbook created', { variant: 'success' }); + setShowCreate(false); + setNewName(''); + setNewDesc(''); + loadList(); + setActive(pb); + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + } + }; + + const deletePlaybook = async (id: string) => { + try { + await playbooks.delete(id); + enqueueSnackbar('Playbook deleted', { variant: 'success' }); + if (active?.id === id) setActive(null); + loadList(); + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + } + }; + + const completedCount = active?.steps.filter(s => s.is_completed).length || 0; + const totalSteps = active?.steps.length || 1; + const progress = Math.round((completedCount / totalSteps) * 100); + + return ( + + {/* Left sidebar - playbook list */} + + + Playbooks + setShowCreate(true)}> + + + {/* Templates section */} + TEMPLATES + + {templates.map(t => ( + createFromTemplate(t)} sx={{ borderRadius: 1, mb: 0.5 }}> + + + + ))} + + + + MY PLAYBOOKS + {loading && } + + {pbList.map(p => ( + selectPlaybook(p.id)} sx={{ borderRadius: 1, mb: 0.5 }}> + + + + { e.stopPropagation(); deletePlaybook(p.id); }}> + + + ))} + {!loading && pbList.length === 0 && ( + No playbooks yet. Start from a template or create one. + )} + + + + {/* Right panel - active playbook */} + + {!active ? ( + + + Select or create a playbook + + Use templates for common investigation workflows, or build your own step-by-step checklist. + + + ) : ( + + {active.name} + {active.description && {active.description}} + + + + {progress}% + + + + + {active.steps + .sort((a, b) => a.order_index - b.order_index) + .map(step => ( + + toggleStep(step.id, step.is_completed)} sx={{ borderRadius: 1, border: '1px solid', borderColor: step.is_completed ? 'success.main' : 'divider', bgcolor: step.is_completed ? 'success.main' : 'transparent', opacity: step.is_completed ? 0.7 : 1 }}> + + + + + {step.target_route && ( + + { e.stopPropagation(); window.location.hash = step.target_route!; }}> + + + + )} + + + ))} + + + )} + + + {/* Create dialog */} + setShowCreate(false)} maxWidth="sm" fullWidth> + Create Custom Playbook + + setNewName(e.target.value)} sx={{ mt: 1, mb: 2 }} /> + setNewDesc(e.target.value)} /> + + + + + + + + ); +} + diff --git a/frontend/src/components/SavedSearches.tsx b/frontend/src/components/SavedSearches.tsx new file mode 100644 index 0000000..e0825a0 --- /dev/null +++ b/frontend/src/components/SavedSearches.tsx @@ -0,0 +1,271 @@ +/** + * SavedSearches - Manage bookmarked queries and recurring scans. + * Supports IOC, keyword, NLP, and correlation search types with delta tracking. + */ +import React, { useState, useEffect, useCallback } from 'react'; +import { + Box, Typography, Paper, CircularProgress, Alert, Button, Chip, + Table, TableHead, TableRow, TableCell, TableBody, TableContainer, + Dialog, DialogTitle, DialogContent, DialogActions, + TextField, FormControl, InputLabel, Select, MenuItem, + IconButton, Tooltip, +} from '@mui/material'; +import AddIcon from '@mui/icons-material/Add'; +import DeleteIcon from '@mui/icons-material/Delete'; +import PlayArrowIcon from '@mui/icons-material/PlayArrow'; +import EditIcon from '@mui/icons-material/Edit'; +import BookmarkIcon from '@mui/icons-material/Bookmark'; +import { useSnackbar } from 'notistack'; +import { savedSearches, SavedSearchData, SearchRunResult } from '../api/client'; + +const SEARCH_TYPES = [ + { value: 'ioc_search', label: 'IOC Search' }, + { value: 'keyword_scan', label: 'Keyword Scan' }, + { value: 'nlp_query', label: 'NLP Query' }, + { value: 'correlation', label: 'Correlation' }, +]; + +function typeColor(t: string): 'primary' | 'secondary' | 'warning' | 'info' { + switch (t) { + case 'ioc_search': return 'primary'; + case 'keyword_scan': return 'warning'; + case 'nlp_query': return 'info'; + case 'correlation': return 'secondary'; + default: return 'primary'; + } +} + +export default function SavedSearchesView() { + const { enqueueSnackbar } = useSnackbar(); + const [loading, setLoading] = useState(false); + const [items, setItems] = useState([]); + const [showForm, setShowForm] = useState(false); + const [editing, setEditing] = useState(null); + const [runResult, setRunResult] = useState(null); + const [runId, setRunId] = useState(null); + const [running, setRunning] = useState(null); + + // Form state + const [name, setName] = useState(''); + const [searchType, setSearchType] = useState('ioc_search'); + const [queryParams, setQueryParams] = useState(''); + const [huntId, setHuntId] = useState(''); + + const load = useCallback(async () => { + setLoading(true); + try { + const data = await savedSearches.list(); + setItems(data.searches); + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + } finally { + setLoading(false); + } + }, [enqueueSnackbar]); + + useEffect(() => { load(); }, [load]); + + const openCreate = () => { + setEditing(null); + setName(''); + setSearchType('ioc_search'); + setQueryParams(''); + setHuntId(''); + setShowForm(true); + }; + + const openEdit = (item: SavedSearchData) => { + setEditing(item); + setName(item.name); + setSearchType(item.search_type); + setQueryParams(JSON.stringify(item.query_params, null, 2)); + setHuntId((item.query_params as any)?.hunt_id || ''); + setShowForm(true); + }; + + const save = async () => { + if (!name.trim()) return; + let params: Record = {}; + try { + params = JSON.parse(queryParams || '{}'); + } catch { + enqueueSnackbar('Invalid JSON in query parameters', { variant: 'error' }); + return; + } + try { + if (editing) { + await savedSearches.update(editing.id, { + name, search_type: searchType, query_params: params, + hunt_id: huntId || undefined, + }); + enqueueSnackbar('Search updated', { variant: 'success' }); + } else { + await savedSearches.create({ + name, search_type: searchType, query_params: params, + hunt_id: huntId || undefined, + }); + enqueueSnackbar('Search saved', { variant: 'success' }); + } + setShowForm(false); + load(); + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + } + }; + + const remove = async (id: string) => { + try { + await savedSearches.delete(id); + enqueueSnackbar('Deleted', { variant: 'success' }); + load(); + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + } + }; + + const runSearch = async (id: string) => { + setRunning(id); + try { + const result = await savedSearches.run(id); + setRunResult(result); + setRunId(id); + load(); // refresh last_run times + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + } finally { + setRunning(null); + } + }; + + return ( + + + + Saved Searches + + + + {loading && } + + {!loading && items.length === 0 && ( + + No saved searches yet. Create one to bookmark frequently-used queries for quick re-execution. + + )} + + {items.length > 0 && ( + +
+ + + Name + Type + Hunt ID + Last Run + Last Count + Actions + + + + {items.map(item => ( + + {item.name} + + t.value === item.search_type)?.label || item.search_type} + color={typeColor(item.search_type)} size="small" sx={{ fontSize: '0.7rem' }} /> + + + {(item.query_params as any)?.hunt_id ? String((item.query_params as any).hunt_id).slice(0, 8) + '...' : 'All'} + + + {item.last_run_at ? new Date(item.last_run_at).toLocaleString() : 'Never'} + + + {item.last_result_count != null ? ( + 0 ? 'warning' : 'default'} /> + ) : ''} + + + + runSearch(item.id)} + disabled={running === item.id}> + {running === item.id ? : } + + + + openEdit(item)}> + + + remove(item.id)}> + + + + ))} + +
+
+ )} + + {/* Run result dialog */} + setRunResult(null)} maxWidth="sm" fullWidth> + Search Results + + {runResult && ( + + + Search: {items.find(i => i.id === runId)?.name} + + + 0 ? 'warning' : 'success'} /> + {runResult.delta !== undefined && runResult.delta !== null && ( + = 0 ? '+' : ''}${runResult.delta} since last run`} + color={runResult.delta > 0 ? 'error' : 'default'} variant="outlined" /> + )} + + {runResult.results && runResult.results.length > 0 && ( + + Preview (first {runResult.results.length} results): + {runResult.results.map((item: any, i: number) => ( + + {typeof item === 'string' ? item : JSON.stringify(item, null, 1)} + + ))} + + )} + + )} + + + + + + + {/* Create/Edit dialog */} + setShowForm(false)} maxWidth="sm" fullWidth> + {editing ? 'Edit Search' : 'Create Saved Search'} + + setName(e.target.value)} sx={{ mt: 1, mb: 2 }} /> + + Search Type + + + setHuntId(e.target.value)} sx={{ mb: 2 }} + placeholder="Leave empty to search all hunts" /> + setQueryParams(e.target.value)} + placeholder='{"keywords": ["mimikatz", "lsass"]}' + helperText="JSON object with search-specific parameters" /> + + + + + + +
+ ); +} + diff --git a/frontend/src/components/TimelineView.tsx b/frontend/src/components/TimelineView.tsx new file mode 100644 index 0000000..0796573 --- /dev/null +++ b/frontend/src/components/TimelineView.tsx @@ -0,0 +1,178 @@ +/** + * TimelineView - Forensic event timeline with zoomable chart. + * Plots dataset rows on a time axis, color-coded by artifact type and risk. + */ +import React, { useState, useEffect, useCallback } from 'react'; +import { + Box, Typography, Paper, CircularProgress, Alert, Chip, + FormControl, InputLabel, Select, MenuItem, IconButton, + Table, TableHead, TableRow, TableCell, TableBody, TableContainer, +} from '@mui/material'; +import RefreshIcon from '@mui/icons-material/Refresh'; +import { useSnackbar } from 'notistack'; +import { + BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip as ReTooltip, + ResponsiveContainer, +} from 'recharts'; +import { timeline, TimelineData, TimelineEvent, hunts, Hunt } from '../api/client'; + +const ARTIFACT_COLORS: Record = { + 'Windows.System.Pslist': '#60a5fa', + 'Windows.Network.Netstat': '#f472b6', + 'Windows.System.Services': '#34d399', + 'Windows.Forensics.Prefetch': '#fbbf24', + 'Windows.EventLogs.EvtxHunter': '#a78bfa', + 'Windows.Sys.Autoruns': '#f87171', + 'Unknown': '#64748b', +}; + +function getColor(artifact: string): string { + return ARTIFACT_COLORS[artifact] || ARTIFACT_COLORS['Unknown']; +} + +function bucketEvents(events: TimelineEvent[], buckets = 50): { time: string; count: number; artifacts: Record }[] { + if (!events.length) return []; + const sorted = [...events].sort((a, b) => a.timestamp.localeCompare(b.timestamp)); + const start = new Date(sorted[0].timestamp).getTime(); + const end = new Date(sorted[sorted.length - 1].timestamp).getTime(); + const span = Math.max(end - start, 1); + const bucketSize = span / buckets; + const result: { time: string; count: number; artifacts: Record }[] = []; + for (let i = 0; i < buckets; i++) { + const bStart = start + i * bucketSize; + const bEnd = bStart + bucketSize; + const inBucket = sorted.filter(e => { + const t = new Date(e.timestamp).getTime(); + return t >= bStart && t < bEnd; + }); + const artifacts: Record = {}; + inBucket.forEach(e => { artifacts[e.artifact_type] = (artifacts[e.artifact_type] || 0) + 1; }); + result.push({ + time: new Date(bStart).toISOString().slice(0, 16).replace('T', ' '), + count: inBucket.length, + artifacts, + }); + } + return result; +} + +export default function TimelineView() { + const { enqueueSnackbar } = useSnackbar(); + const [loading, setLoading] = useState(false); + const [data, setData] = useState(null); + const [huntList, setHuntList] = useState([]); + const [selectedHunt, setSelectedHunt] = useState(''); + const [filterArtifact, setFilterArtifact] = useState(''); + + const load = useCallback(async () => { + if (!selectedHunt) return; + setLoading(true); + try { + const d = await timeline.getHuntTimeline(selectedHunt); + setData(d); + } catch (e: any) { + enqueueSnackbar(e.message, { variant: 'error' }); + } finally { + setLoading(false); + } + }, [selectedHunt, enqueueSnackbar]); + + useEffect(() => { + hunts.list(0, 100).then(r => setHuntList(r.hunts)).catch(() => {}); + }, []); + + useEffect(() => { load(); }, [load]); + + const filteredEvents = data?.events.filter(e => !filterArtifact || e.artifact_type === filterArtifact) || []; + const buckets = bucketEvents(filteredEvents); + const artifactTypes = [...new Set(data?.events.map(e => e.artifact_type) || [])]; + + return ( + + + Forensic Timeline + + Hunt + + + + Artifact Type + + + + {data && } + + + {!selectedHunt && ( + Select a hunt to view its forensic timeline. + )} + + {loading && } + + {!loading && data && filteredEvents.length === 0 && ( + No timestamped events found in this hunt's datasets. + )} + + {!loading && data && filteredEvents.length > 0 && ( + <> + {/* Activity histogram */} + + Activity Over Time + + {artifactTypes.map(a => ( + setFilterArtifact(filterArtifact === a ? '' : a)} variant={filterArtifact === a ? 'filled' : 'outlined'} /> + ))} + + + + + + + + + + + + + {/* Event table */} + + Events ({filteredEvents.length}) + + + + + Time + Hostname + Artifact + Process + Summary + + + + {filteredEvents.slice(0, 500).map((e, i) => ( + + {e.timestamp.replace('T', ' ').slice(0, 19)} + {e.hostname || ''} + + + + {e.process || ''} + {e.summary || ''} + + ))} + +
+
+
+ + )} +
+ ); +} diff --git a/write_update.py b/write_update.py new file mode 100644 index 0000000..69fa78a --- /dev/null +++ b/write_update.py @@ -0,0 +1,164 @@ +import os + +lines = [] +a = lines.append + +a("# ThreatHunt Update Log") +a("") +a("## 2026-02-22: Full Auto-Processing Pipeline, Performance Fixes, DB Concurrency") +a("") +a("### Auto-Processing Pipeline (Import-Time)") +a("- **Problem**: Only HOST_INVENTORY ran on dataset upload. Triage, anomaly detection, keyword scanning, and IOC extraction were manual-only, effectively dead code.") +a("- **Solution**: Wired ALL processing modules into the upload endpoint. On CSV import, 5 jobs are now auto-queued: TRIAGE, ANOMALY, KEYWORD_SCAN, IOC_EXTRACT, HOST_INVENTORY.") +a("- **Startup reprocessing**: On backend boot, queries for datasets with no anomaly results and queues the full pipeline for them.") +a("- **Completion tracking**: Pipeline completion callback updates `Dataset.processing_status` to `completed` or `completed_with_errors` when all 4 analysis jobs finish.") +a("- **Triage chaining**: After triage completes, automatically queues a HOST_PROFILE job for deep per-host LLM analysis.") +a("") +a("### Artifact Classification (Was Dead Code)") +a("- **Problem**: `classify_artifact()` in `artifact_classifier.py` existed but was never called.") +a("- **Fix**: Upload endpoint now calls `classify_artifact(columns)` to identify Velociraptor artifact types (30+ fingerprints) and stores `artifact_type` on the dataset.") +a("") +a("### Database Concurrency Fix") +a("- **Problem**: SQLite with `StaticPool` = single shared connection. Any long-running job (keyword scan, triage) blocked ALL other DB queries, freezing the entire app.") +a("- **Fix**: Switched to `NullPool` so each async session gets its own connection. Combined with WAL mode (`PRAGMA journal_mode=WAL`), `busy_timeout=30000`, and `synchronous=NORMAL` for concurrent reads during writes.") +a("") +a("#### Modified: `backend/app/db/engine.py`") +a("- `StaticPool` -> `NullPool` for SQLite") +a("- Added `_set_sqlite_pragmas` event listener: WAL mode, 30s busy timeout, NORMAL sync") +a("- Connection args: `timeout=60`, `check_same_thread=False`") +a("") +a("### Triage Model Fix") +a("- **Problem**: `triage.py` hardcoded `DEFAULT_FAST_MODEL = \"qwen2.5-coder:7b-instruct-q4_K_M\"` which didn't exist on Roadrunner, causing 404 errors on every triage batch.") +a("- **Fix**: Changed to `settings.DEFAULT_FAST_MODEL` which resolves to `llama3.1:latest` (available on Roadrunner). Configurable via `TH_DEFAULT_FAST_MODEL` env var.") +a("") +a("### Host Profiler ClientID Fix") +a("- **Problem**: Velociraptor ClientID-format hostnames (`C.82465a50d075ea20`) were sent to the LLM for profiling, producing empty/useless results.") +a("- **Fix**: Added regex filter `^C\\.[0-9a-fA-F]{8,}$` to skip ClientID entries before profiling.") +a("") +a("### Job Queue Expansion") +a("- **Before**: 3 job types (TRIAGE, HOST_PROFILE, REPORT), 3 workers") +a("- **After**: 8 job types, 5 workers, pipeline completion callbacks") +a("- Added: KEYWORD_SCAN, IOC_EXTRACT to JobType enum") +a("- Added: `PIPELINE_JOB_TYPES` frozenset (TRIAGE, ANOMALY, KEYWORD_SCAN, IOC_EXTRACT)") +a("- Added: `_on_pipeline_job_complete` callback updates `processing_status`") +a("- Added: `_handle_keyword_scan` using `KeywordScanner(db).scan()`") +a("- Added: `_handle_ioc_extract` using `extract_iocs_from_dataset()`") +a("- Triage now chains HOST_PROFILE after completion") +a("") +a("#### Modified: `backend/app/api/routes/datasets.py`") +a("- Upload calls `classify_artifact(columns)` for artifact type detection") +a("- Sets `artifact_type` and `processing_status=\"processing\"` on create") +a("- Queues 5 jobs: TRIAGE, ANOMALY, KEYWORD_SCAN, IOC_EXTRACT, HOST_INVENTORY") +a("- `UploadResponse` includes `artifact_type`, `processing_status`, `jobs_queued`") +a("") +a("#### Modified: `backend/app/main.py`") +a("- Startup reprocessing: finds datasets with no `AnomalyResult` records, queues full pipeline") +a("- Marks reprocessed datasets as `processing_status=\"processing\"`") +a("- Logs skip message when all datasets already processed") +a("") +a("### Network Map Performance Fix") +a("- **Problem**: 163 hosts + 1121 connections created 528 total nodes (365 external IPs). The O(N^2) force simulation did 278,784 pairwise calculations per animation frame, freezing the browser.") +a("- **Fix**: 6 optimizations applied to `frontend/src/components/NetworkMap.tsx`:") +a("") +a("| Fix | Detail |") +a("|-----|--------|") +a("| Cap external IPs | `MAX_EXTERNAL_NODES = 30` (was unlimited: 365) |") +a("| Sampling simulation | For N > 150 nodes, sample 40 random per node instead of N^2 pairs |") +a("| Distance cutoff | Skip repulsion for pairs > 600px apart |") +a("| Single redraw on hover | Was restarting full animation loop on every mouse hover |") +a("| Faster alpha decay | 0.97 -> 0.93 per frame (settles ~2x faster) |") +a("| Lower initial energy | simAlpha 0.6 -> 0.3, sim steps 80 -> 60 |") +a("") +a("### Test Results") +a("- **79/79 backend tests passing** (0.72s)") +a("- Both Docker containers healthy") +a("- 21/21 frontend-facing endpoints return 200 OK through nginx") +a("") +a("### Endpoint Verification (via nginx on port 3000)") +a("") +a("| Endpoint | Status | Size |") +a("|----------|--------|------|") +a("| /api/agent/health | 200 | 522b |") +a("| /api/hunts | 200 | 259b |") +a("| /api/datasets?hunt_id=... | 200 | 23KB |") +a("| /api/datasets/{id}/rows | 200 | 144KB |") +a("| /api/analysis/anomalies/{id} | 200 | 104KB |") +a("| /api/analysis/iocs/{id} | 200 | 1.2KB |") +a("| /api/analysis/triage/{id} | 200 | 9.5KB |") +a("| /api/analysis/profiles/{hunt} | 200 | 177KB |") +a("| /api/network/host-inventory | 200 | 181KB |") +a("| /api/timeline/hunt/{hunt} | 200 | 351KB |") +a("| /api/keywords/themes | 200 | 23KB |") +a("| /api/playbooks/templates | 200 | 2.5KB |") +a("| /api/reports/hunt/{hunt} | 200 | 10.6KB |") +a("| /api/export/stix/{hunt} | 200 | 391b |") +a("") +a("---") +a("") +a("## 2026-02-21: Feature Expansion, Dashboard Rewrite, Docker Deployment") +a("") +a("### New Features Added") +a("- **MITRE ATT&CK Matrix** (`/api/mitre/coverage`, `MitreMatrix.tsx`) - technique coverage visualization") +a("- **Timeline View** (`/api/timeline/hunt/{hunt}`, `TimelineView.tsx`) - chronological event explorer") +a("- **Playbook Manager** (`/api/playbooks`, `PlaybookManager.tsx`) - investigation playbook CRUD with templates") +a("- **Saved Searches** (`/api/searches`, `SavedSearches.tsx`) - save/run named queries") +a("- **STIX Export** (`/api/export/stix/{hunt}`) - STIX 2.1 bundle export for threat intel sharing") +a("") +a("### DB Models Added") +a("- `Playbook`, `PlaybookStep` - investigation playbook tracking") +a("- `SavedSearch` - persisted named queries") +a("") +a("### Dashboard & Correlation Rewrite") +a("- `Dashboard.tsx` - rewrote with live stat cards, dataset table, processing status indicators") +a("- `CorrelationView.tsx` - rewrote with working correlation analysis UI") +a("- `AgentPanel.tsx` - added SSE streaming for real-time agent responses") +a("") +a("### Docker Deployment") +a("- `Dockerfile.frontend` - added `TSC_COMPILE_ON_ERROR=true` for MUI X v8 compatibility") +a("- `nginx.conf` - SSE proxy headers, 500MB upload, 300s proxy timeout, SPA fallback") +a("- Frontend healthcheck changed from wget to curl with 127.0.0.1") +a("") +a("---") +a("") +a("## 2026-02-20: Host-Centric Network Map & Analysis Platform") +a("") +a("### Network Map Overhaul") +a("- **Problem**: Network Map showed 409 misclassified domain nodes (mostly process names like svchost.exe) and 0 hosts. No deduplication.") +a("- **Root Cause**: IOC column detection misclassified `Fqdn` as domain instead of hostname; `Name` column (process names) wrongly tagged as domain IOC.") +a("- **Solution**: Created host-centric inventory system. Scans all datasets, groups by `Fqdn`/`ClientId`, extracts IPs, users, OS, and network connections.") +a("") +a("#### New Backend Files") +a("- `host_inventory.py` - Deduplicated host inventory builder with in-memory cache, background job pattern (202 polling), 5000-row batches") +a("- `network.py` routes - `GET /api/network/host-inventory`, `/inventory-status`, `/rebuild-inventory`") +a("- `ioc_extractor.py` - Regex IOC extraction (IP, domain, hash, email, URL)") +a("- `anomaly_detector.py` - Embedding-based outlier detection via bge-m3") +a("- `data_query.py` - Natural language to structured query translation") +a("- `load_balancer.py` - Round-robin load balancer for Ollama LLM nodes") +a("- `job_queue.py` - Async job queue (initially 3 workers, 3 job types)") +a("- `analysis.py` routes - 16 analysis endpoints") +a("") +a("#### Frontend") +a("- `NetworkMap.tsx` - Canvas 2D force-directed graph, HiDPI, node dragging, search, popover, module-level cache") +a("- `AnalysisDashboard.tsx` - 6-tab analysis dashboard") +a("- `client.ts` - `network.*` and `analysis.*` API namespaces") +a("") +a("### Results (Radio Hunt - 20 Velociraptor datasets, 394K rows)") +a("") +a("| Metric | Before | After |") +a("|--------|--------|-------|") +a("| Nodes shown | 409 misclassified domains | **163 unique hosts** |") +a("| Hosts identified | 0 | **163** |") +a("| With IP addresses | N/A | **48** (172.17.x.x LAN) |") +a("| With logged-in users | N/A | **43** (real names only) |") +a("| OS detected | None | **Windows 10** (inferred) |") +a("| Deduplication | None | **Full** (by FQDN/ClientId) |") +a("") +a("### LLM Infrastructure") +a("- **Roadrunner** (100.110.190.11:11434): llama3.1:latest, qwen2.5-coder:7b, qwen2.5:14b, bge-m3 embeddings") +a("- **Wile** (100.110.190.12:11434): llama3.1:70b-instruct-q4_K_M (heavy analysis)") +a("- **Open WebUI** (ai.guapo613.beer): Cluster management interface") + +path = r'd:\Projects\Dev\ThreatHunt\update.md' +with open(path, 'w', encoding='utf-8') as f: + f.write('\n'.join(lines) + '\n') +print(f'Written {len(lines)} lines to update.md') \ No newline at end of file