mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 05:50:21 -05:00
feat: Add Playbook Manager, Saved Searches, and Timeline View components
- Implemented PlaybookManager for creating and managing investigation playbooks with templates.
- Added SavedSearches component for managing bookmarked queries and recurring scans.
- Introduced TimelineView for visualizing forensic event timelines with zoomable charts.
- Enhanced backend processing with auto-queued jobs for dataset uploads and improved database concurrency.
- Updated frontend components for better user experience and performance optimizations.
- Documented changes in the update log for future reference.
This commit is contained in:
@@ -17,7 +17,7 @@ COPY frontend/tsconfig.json ./
|
|||||||
# Build application
|
# Build application
|
||||||
RUN npm run build
|
RUN npm run build
|
||||||
|
|
||||||
# Production stage — nginx reverse-proxy + static files
|
# Production stage — nginx reverse-proxy + static files
|
||||||
FROM nginx:alpine
|
FROM nginx:alpine
|
||||||
|
|
||||||
# Copy built React app
|
# Copy built React app
|
||||||
|
|||||||
148
_add_label_filter_networkmap.py
Normal file
148
_add_label_filter_networkmap.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
|
||||||
|
# 1) Add label mode type near graph types
|
||||||
|
marker="interface GEdge { source: string; target: string; weight: number }\ninterface Graph { nodes: GNode[]; edges: GEdge[] }\n"
|
||||||
|
if marker in t and "type LabelMode" not in t:
|
||||||
|
t=t.replace(marker, marker+"\ntype LabelMode = 'all' | 'highlight' | 'none';\n")
|
||||||
|
|
||||||
|
# 2) extend drawLabels signature
|
||||||
|
old_sig="""function drawLabels(
|
||||||
|
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||||
|
hovered: string | null, selected: string | null,
|
||||||
|
search: string, matchSet: Set<string>, vp: Viewport,
|
||||||
|
simplify: boolean,
|
||||||
|
) {
|
||||||
|
"""
|
||||||
|
new_sig="""function drawLabels(
|
||||||
|
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||||
|
hovered: string | null, selected: string | null,
|
||||||
|
search: string, matchSet: Set<string>, vp: Viewport,
|
||||||
|
simplify: boolean, labelMode: LabelMode,
|
||||||
|
) {
|
||||||
|
"""
|
||||||
|
if old_sig in t:
|
||||||
|
t=t.replace(old_sig,new_sig)
|
||||||
|
|
||||||
|
# 3) label mode guards inside drawLabels
|
||||||
|
old_guard=""" const dimmed = search.length > 0;
|
||||||
|
if (simplify && !search && !hovered && !selected) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
new_guard=""" if (labelMode === 'none') return;
|
||||||
|
const dimmed = search.length > 0;
|
||||||
|
if (labelMode === 'highlight' && !search && !hovered && !selected) return;
|
||||||
|
if (simplify && labelMode !== 'all' && !search && !hovered && !selected) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if old_guard in t:
|
||||||
|
t=t.replace(old_guard,new_guard)
|
||||||
|
|
||||||
|
old_show=""" const isHighlight = hovered === n.id || selected === n.id || matchSet.has(n.id);
|
||||||
|
const show = isHighlight || n.meta.type === 'host' || n.count >= 2;
|
||||||
|
if (!show) continue;
|
||||||
|
"""
|
||||||
|
new_show=""" const isHighlight = hovered === n.id || selected === n.id || matchSet.has(n.id);
|
||||||
|
const show = labelMode === 'all'
|
||||||
|
? (isHighlight || n.meta.type === 'host' || n.count >= 2)
|
||||||
|
: isHighlight;
|
||||||
|
if (!show) continue;
|
||||||
|
"""
|
||||||
|
if old_show in t:
|
||||||
|
t=t.replace(old_show,new_show)
|
||||||
|
|
||||||
|
# 4) drawGraph signature and call site
|
||||||
|
old_graph_sig="""function drawGraph(
|
||||||
|
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||||
|
hovered: string | null, selected: string | null, search: string,
|
||||||
|
vp: Viewport, animTime: number, dpr: number,
|
||||||
|
) {
|
||||||
|
"""
|
||||||
|
new_graph_sig="""function drawGraph(
|
||||||
|
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||||
|
hovered: string | null, selected: string | null, search: string,
|
||||||
|
vp: Viewport, animTime: number, dpr: number, labelMode: LabelMode,
|
||||||
|
) {
|
||||||
|
"""
|
||||||
|
if old_graph_sig in t:
|
||||||
|
t=t.replace(old_graph_sig,new_graph_sig)
|
||||||
|
|
||||||
|
old_drawlabels_call="drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify);"
|
||||||
|
new_drawlabels_call="drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify, labelMode);"
|
||||||
|
if old_drawlabels_call in t:
|
||||||
|
t=t.replace(old_drawlabels_call,new_drawlabels_call)
|
||||||
|
|
||||||
|
# 5) state for label mode
|
||||||
|
state_anchor=" const [selectedNode, setSelectedNode] = useState<GNode | null>(null);\n const [search, setSearch] = useState('');\n"
|
||||||
|
state_new=" const [selectedNode, setSelectedNode] = useState<GNode | null>(null);\n const [search, setSearch] = useState('');\n const [labelMode, setLabelMode] = useState<LabelMode>('highlight');\n"
|
||||||
|
if state_anchor in t:
|
||||||
|
t=t.replace(state_anchor,state_new)
|
||||||
|
|
||||||
|
# 6) pass labelMode in draw calls
|
||||||
|
old_tick_draw="drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? null, searchRef.current, vpRef.current, ts, dpr);"
|
||||||
|
new_tick_draw="drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? null, searchRef.current, vpRef.current, ts, dpr, labelMode);"
|
||||||
|
if old_tick_draw in t:
|
||||||
|
t=t.replace(old_tick_draw,new_tick_draw)
|
||||||
|
|
||||||
|
old_redraw_draw="if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? null, search, vpRef.current, animTimeRef.current, dpr);"
|
||||||
|
new_redraw_draw="if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? null, search, vpRef.current, animTimeRef.current, dpr, labelMode);"
|
||||||
|
if old_redraw_draw in t:
|
||||||
|
t=t.replace(old_redraw_draw,new_redraw_draw)
|
||||||
|
|
||||||
|
# 7) include labelMode in redraw deps
|
||||||
|
old_redraw_dep="] , [graph, hovered, selectedNode, search]);"
|
||||||
|
if old_redraw_dep in t:
|
||||||
|
t=t.replace(old_redraw_dep, "] , [graph, hovered, selectedNode, search, labelMode]);")
|
||||||
|
else:
|
||||||
|
t=t.replace(" }, [graph, hovered, selectedNode, search]);"," }, [graph, hovered, selectedNode, search, labelMode]);")
|
||||||
|
|
||||||
|
# 8) Add toolbar selector after search field
|
||||||
|
search_block=""" <TextField
|
||||||
|
size="small"
|
||||||
|
placeholder="Search hosts, IPs, users\u2026"
|
||||||
|
value={search}
|
||||||
|
onChange={e => setSearch(e.target.value)}
|
||||||
|
sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }}
|
||||||
|
slotProps={{
|
||||||
|
input: {
|
||||||
|
startAdornment: <SearchIcon sx={{ mr: 0.5, fontSize: 18, color: 'text.secondary' }} />,
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
"""
|
||||||
|
label_block=""" <TextField
|
||||||
|
size="small"
|
||||||
|
placeholder="Search hosts, IPs, users\u2026"
|
||||||
|
value={search}
|
||||||
|
onChange={e => setSearch(e.target.value)}
|
||||||
|
sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }}
|
||||||
|
slotProps={{
|
||||||
|
input: {
|
||||||
|
startAdornment: <SearchIcon sx={{ mr: 0.5, fontSize: 18, color: 'text.secondary' }} />,
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<FormControl size="small" sx={{ minWidth: 140 }}>
|
||||||
|
<InputLabel id="label-mode-selector">Labels</InputLabel>
|
||||||
|
<Select
|
||||||
|
labelId="label-mode-selector"
|
||||||
|
value={labelMode}
|
||||||
|
label="Labels"
|
||||||
|
onChange={e => setLabelMode(e.target.value as LabelMode)}
|
||||||
|
sx={{ '& .MuiSelect-select': { py: 0.8 } }}
|
||||||
|
>
|
||||||
|
<MenuItem value="none">None</MenuItem>
|
||||||
|
<MenuItem value="highlight">Selected/Search</MenuItem>
|
||||||
|
<MenuItem value="all">All</MenuItem>
|
||||||
|
</Select>
|
||||||
|
</FormControl>
|
||||||
|
"""
|
||||||
|
if search_block in t:
|
||||||
|
t=t.replace(search_block,label_block)
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('added network map label filter control and renderer modes')
|
||||||
18
_add_scanner_budget_config.py
Normal file
18
_add_scanner_budget_config.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
old=''' # -- Scanner settings -----------------------------------------------
|
||||||
|
SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch")
|
||||||
|
'''
|
||||||
|
new=''' # -- Scanner settings -----------------------------------------------
|
||||||
|
SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch")
|
||||||
|
SCANNER_MAX_ROWS_PER_SCAN: int = Field(
|
||||||
|
default=300000,
|
||||||
|
description="Global row budget for a single AUP scan request (0 = unlimited)",
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('scanner settings block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('added SCANNER_MAX_ROWS_PER_SCAN config')
|
||||||
46
_apply_frontend_scale_patch.py
Normal file
46
_apply_frontend_scale_patch.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
root = Path(r"d:\Projects\Dev\ThreatHunt")
|
||||||
|
|
||||||
|
# -------- client.ts --------
|
||||||
|
client = root / "frontend/src/api/client.ts"
|
||||||
|
text = client.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
if "export interface NetworkSummary" not in text:
|
||||||
|
insert_after = "export interface InventoryStatus {\n hunt_id: string;\n status: 'ready' | 'building' | 'none';\n}\n"
|
||||||
|
addition = insert_after + "\nexport interface NetworkSummaryHost {\n id: string;\n hostname: string;\n row_count: number;\n ip_count: number;\n user_count: number;\n}\n\nexport interface NetworkSummary {\n stats: InventoryStats;\n top_hosts: NetworkSummaryHost[];\n top_edges: InventoryConnection[];\n status?: 'building' | 'deferred';\n message?: string;\n}\n"
|
||||||
|
text = text.replace(insert_after, addition)
|
||||||
|
|
||||||
|
net_old = """export const network = {\n hostInventory: (huntId: string, force = false) =>\n api<HostInventory>(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`),\n inventoryStatus: (huntId: string) =>\n api<InventoryStatus>(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`),\n rebuildInventory: (huntId: string) =>\n api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }),\n};"""
|
||||||
|
net_new = """export const network = {\n hostInventory: (huntId: string, force = false) =>\n api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`),\n summary: (huntId: string, topN = 20) =>\n api<NetworkSummary | { status: 'building' | 'deferred'; message?: string }>(`/api/network/summary?hunt_id=${encodeURIComponent(huntId)}&top_n=${topN}`),\n subgraph: (huntId: string, maxHosts = 250, maxEdges = 1500, nodeId?: string) => {\n let qs = `/api/network/subgraph?hunt_id=${encodeURIComponent(huntId)}&max_hosts=${maxHosts}&max_edges=${maxEdges}`;\n if (nodeId) qs += `&node_id=${encodeURIComponent(nodeId)}`;\n return api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(qs);\n },\n inventoryStatus: (huntId: string) =>\n api<InventoryStatus>(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`),\n rebuildInventory: (huntId: string) =>\n api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }),\n};"""
|
||||||
|
if net_old in text:
|
||||||
|
text = text.replace(net_old, net_new)
|
||||||
|
|
||||||
|
client.write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
# -------- NetworkMap.tsx --------
|
||||||
|
nm = root / "frontend/src/components/NetworkMap.tsx"
|
||||||
|
text = nm.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
# add constants
|
||||||
|
if "LARGE_HUNT_HOST_THRESHOLD" not in text:
|
||||||
|
text = text.replace("let lastSelectedHuntId = '';\n", "let lastSelectedHuntId = '';\nconst LARGE_HUNT_HOST_THRESHOLD = 400;\nconst LARGE_HUNT_SUBGRAPH_HOSTS = 350;\nconst LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n")
|
||||||
|
|
||||||
|
# inject helper in component after sleep
|
||||||
|
marker = " const sleep = (ms: number) => new Promise<void>(resolve => setTimeout(resolve, ms));\n"
|
||||||
|
if "loadScaleAwareGraph" not in text:
|
||||||
|
helper = marker + "\n const loadScaleAwareGraph = useCallback(async (huntId: string, forceRefresh = false) => {\n setLoading(true); setError(''); setGraph(null); setStats(null);\n setSelectedNode(null); setPopoverAnchor(null);\n\n const waitReadyThen = async <T,>(fn: () => Promise<T>): Promise<T> => {\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const out: any = await fn();\n if (out && !out.status) return out as T;\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') {\n const out2: any = await fn();\n if (out2 && !out2.status) return out2 as T;\n }\n if (Date.now() - startedAt > 5 * 60 * 1000) throw new Error('Network data build timed out after 5 minutes');\n const jitter = Math.floor(Math.random() * 250);\n await sleep(delayMs + jitter);\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n };\n\n try {\n setProgress('Loading network summary');\n const summary: any = await waitReadyThen(() => network.summary(huntId, 20));\n const totalHosts = summary?.stats?.total_hosts || 0;\n\n if (totalHosts > LARGE_HUNT_HOST_THRESHOLD) {\n setProgress(`Large hunt detected (${totalHosts} hosts). Loading focused subgraph`);\n const sub: any = await waitReadyThen(() => network.subgraph(huntId, LARGE_HUNT_SUBGRAPH_HOSTS, LARGE_HUNT_SUBGRAPH_EDGES));\n if (!sub?.hosts || sub.hosts.length === 0) {\n setError('No hosts found for subgraph.');\n return;\n }\n const { w, h } = canvasSizeRef.current;\n const g = buildGraphFromInventory(sub.hosts, sub.connections || [], w, h);\n simulate(g, w / 2, h / 2, 60);\n simAlphaRef.current = 0.3;\n setStats(summary.stats);\n graphCache.set(huntId, { graph: g, stats: summary.stats, ts: Date.now() });\n setGraph(g);\n return;\n }\n\n // Small/medium hunts: load full inventory\n setProgress('Loading host inventory');\n const inv: any = await waitReadyThen(() => network.hostInventory(huntId, forceRefresh));\n if (!inv?.hosts || inv.hosts.length === 0) {\n setError('No hosts found. 
Upload CSV files with host-identifying columns (ClientId, Fqdn, Hostname) to this hunt.');\n return;\n }\n const { w, h } = canvasSizeRef.current;\n const g = buildGraphFromInventory(inv.hosts, inv.connections || [], w, h);\n simulate(g, w / 2, h / 2, 60);\n simAlphaRef.current = 0.3;\n setStats(summary.stats || inv.stats);\n graphCache.set(huntId, { graph: g, stats: summary.stats || inv.stats, ts: Date.now() });\n setGraph(g);\n } catch (e: any) {\n console.error('[NetworkMap] scale-aware load error:', e);\n setError(e.message || 'Failed to load network data');\n } finally {\n setLoading(false);\n setProgress('');\n }\n }, []);\n"
|
||||||
|
text = text.replace(marker, helper)
|
||||||
|
|
||||||
|
# simplify existing loadGraph function body to delegate
|
||||||
|
pattern_start = text.find(" // Load host inventory for selected hunt (with cache).")
|
||||||
|
if pattern_start != -1:
|
||||||
|
# replace the whole loadGraph useCallback block by simple delegator
|
||||||
|
import re
|
||||||
|
block_re = re.compile(r" // Load host inventory for selected hunt \(with cache\)\.[\s\S]*?\n \}, \[\]\); // Stable - reads canvasSizeRef, no state deps\n", re.M)
|
||||||
|
repl = " // Load graph data for selected hunt (delegates to scale-aware loader).\n const loadGraph = useCallback(async (huntId: string, forceRefresh = false) => {\n if (!huntId) return;\n\n // Check module-level cache first (5 min TTL)\n if (!forceRefresh) {\n const cached = graphCache.get(huntId);\n if (cached && Date.now() - cached.ts < 5 * 60 * 1000) {\n setGraph(cached.graph);\n setStats(cached.stats);\n setError('');\n simAlphaRef.current = 0;\n return;\n }\n }\n\n await loadScaleAwareGraph(huntId, forceRefresh);\n // eslint-disable-next-line react-hooks/exhaustive-deps\n }, []); // Stable - reads canvasSizeRef, no state deps\n"
|
||||||
|
text = block_re.sub(repl, text, count=1)
|
||||||
|
|
||||||
|
nm.write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
print("Patched frontend client + NetworkMap for scale-aware loading")
|
||||||
206
_apply_phase1_patch.py
Normal file
206
_apply_phase1_patch.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
root = Path(r"d:\Projects\Dev\ThreatHunt")
|
||||||
|
|
||||||
|
# 1) config.py additions
|
||||||
|
cfg = root / "backend/app/config.py"
|
||||||
|
text = cfg.read_text(encoding="utf-8")
|
||||||
|
needle = " # -- Scanner settings -----------------------------------------------\n SCANNER_BATCH_SIZE: int = Field(default=500, description=\"Rows per scanner batch\")\n"
|
||||||
|
insert = " # -- Scanner settings -----------------------------------------------\n SCANNER_BATCH_SIZE: int = Field(default=500, description=\"Rows per scanner batch\")\n\n # -- Job queue settings ----------------------------------------------\n JOB_QUEUE_MAX_BACKLOG: int = Field(\n default=2000, description=\"Soft cap for queued background jobs\"\n )\n JOB_QUEUE_RETAIN_COMPLETED: int = Field(\n default=3000, description=\"Maximum completed/failed jobs to retain in memory\"\n )\n JOB_QUEUE_CLEANUP_INTERVAL_SECONDS: int = Field(\n default=60, description=\"How often to run in-memory job cleanup\"\n )\n JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(\n default=3600, description=\"Age threshold for in-memory completed job cleanup\"\n )\n"
|
||||||
|
if needle in text:
|
||||||
|
text = text.replace(needle, insert)
|
||||||
|
cfg.write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
# 2) scanner.py default scope = dataset-only
|
||||||
|
scanner = root / "backend/app/services/scanner.py"
|
||||||
|
text = scanner.read_text(encoding="utf-8")
|
||||||
|
text = text.replace(" scan_hunts: bool = True,", " scan_hunts: bool = False,")
|
||||||
|
text = text.replace(" scan_annotations: bool = True,", " scan_annotations: bool = False,")
|
||||||
|
text = text.replace(" scan_messages: bool = True,", " scan_messages: bool = False,")
|
||||||
|
scanner.write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
# 3) keywords.py defaults = dataset-only
|
||||||
|
kw = root / "backend/app/api/routes/keywords.py"
|
||||||
|
text = kw.read_text(encoding="utf-8")
|
||||||
|
text = text.replace(" scan_hunts: bool = True", " scan_hunts: bool = False")
|
||||||
|
text = text.replace(" scan_annotations: bool = True", " scan_annotations: bool = False")
|
||||||
|
text = text.replace(" scan_messages: bool = True", " scan_messages: bool = False")
|
||||||
|
kw.write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
# 4) job_queue.py dedupe + periodic cleanup
|
||||||
|
jq = root / "backend/app/services/job_queue.py"
|
||||||
|
text = jq.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
text = text.replace(
|
||||||
|
"from typing import Any, Callable, Coroutine, Optional\n",
|
||||||
|
"from typing import Any, Callable, Coroutine, Optional\n\nfrom app.config import settings\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
text = text.replace(
|
||||||
|
" self._completion_callbacks: list[Callable[[Job], Coroutine]] = []\n",
|
||||||
|
" self._completion_callbacks: list[Callable[[Job], Coroutine]] = []\n self._cleanup_task: asyncio.Task | None = None\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
start_old = ''' async def start(self):
|
||||||
|
if self._started:
|
||||||
|
return
|
||||||
|
self._started = True
|
||||||
|
for i in range(self._max_workers):
|
||||||
|
task = asyncio.create_task(self._worker(i))
|
||||||
|
self._workers.append(task)
|
||||||
|
logger.info(f"Job queue started with {self._max_workers} workers")
|
||||||
|
'''
|
||||||
|
start_new = ''' async def start(self):
|
||||||
|
if self._started:
|
||||||
|
return
|
||||||
|
self._started = True
|
||||||
|
for i in range(self._max_workers):
|
||||||
|
task = asyncio.create_task(self._worker(i))
|
||||||
|
self._workers.append(task)
|
||||||
|
if not self._cleanup_task or self._cleanup_task.done():
|
||||||
|
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||||
|
logger.info(f"Job queue started with {self._max_workers} workers")
|
||||||
|
'''
|
||||||
|
text = text.replace(start_old, start_new)
|
||||||
|
|
||||||
|
stop_old = ''' async def stop(self):
|
||||||
|
self._started = False
|
||||||
|
for w in self._workers:
|
||||||
|
w.cancel()
|
||||||
|
await asyncio.gather(*self._workers, return_exceptions=True)
|
||||||
|
self._workers.clear()
|
||||||
|
logger.info("Job queue stopped")
|
||||||
|
'''
|
||||||
|
stop_new = ''' async def stop(self):
|
||||||
|
self._started = False
|
||||||
|
for w in self._workers:
|
||||||
|
w.cancel()
|
||||||
|
await asyncio.gather(*self._workers, return_exceptions=True)
|
||||||
|
self._workers.clear()
|
||||||
|
if self._cleanup_task:
|
||||||
|
self._cleanup_task.cancel()
|
||||||
|
await asyncio.gather(self._cleanup_task, return_exceptions=True)
|
||||||
|
self._cleanup_task = None
|
||||||
|
logger.info("Job queue stopped")
|
||||||
|
'''
|
||||||
|
text = text.replace(stop_old, stop_new)
|
||||||
|
|
||||||
|
submit_old = ''' def submit(self, job_type: JobType, **params) -> Job:
|
||||||
|
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
|
||||||
|
self._jobs[job.id] = job
|
||||||
|
self._queue.put_nowait(job.id)
|
||||||
|
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
|
||||||
|
return job
|
||||||
|
'''
|
||||||
|
submit_new = ''' def submit(self, job_type: JobType, **params) -> Job:
|
||||||
|
# Soft backpressure: prefer dedupe over queue amplification
|
||||||
|
dedupe_job = self._find_active_duplicate(job_type, params)
|
||||||
|
if dedupe_job is not None:
|
||||||
|
logger.info(
|
||||||
|
f"Job deduped: reusing {dedupe_job.id} ({job_type.value}) params={params}"
|
||||||
|
)
|
||||||
|
return dedupe_job
|
||||||
|
|
||||||
|
if self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG:
|
||||||
|
logger.warning(
|
||||||
|
"Job queue backlog high (%d >= %d). Accepting job but system may be degraded.",
|
||||||
|
self._queue.qsize(), settings.JOB_QUEUE_MAX_BACKLOG,
|
||||||
|
)
|
||||||
|
|
||||||
|
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
|
||||||
|
self._jobs[job.id] = job
|
||||||
|
self._queue.put_nowait(job.id)
|
||||||
|
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
|
||||||
|
return job
|
||||||
|
'''
|
||||||
|
text = text.replace(submit_old, submit_new)
|
||||||
|
|
||||||
|
insert_methods_after = " def get_job(self, job_id: str) -> Job | None:\n return self._jobs.get(job_id)\n"
|
||||||
|
new_methods = ''' def get_job(self, job_id: str) -> Job | None:
|
||||||
|
return self._jobs.get(job_id)
|
||||||
|
|
||||||
|
def _find_active_duplicate(self, job_type: JobType, params: dict) -> Job | None:
|
||||||
|
"""Return queued/running job with same key workload to prevent duplicate storms."""
|
||||||
|
key_fields = ["dataset_id", "hunt_id", "hostname", "question", "mode"]
|
||||||
|
sig = tuple((k, params.get(k)) for k in key_fields if params.get(k) is not None)
|
||||||
|
if not sig:
|
||||||
|
return None
|
||||||
|
for j in self._jobs.values():
|
||||||
|
if j.job_type != job_type:
|
||||||
|
continue
|
||||||
|
if j.status not in (JobStatus.QUEUED, JobStatus.RUNNING):
|
||||||
|
continue
|
||||||
|
other_sig = tuple((k, j.params.get(k)) for k in key_fields if j.params.get(k) is not None)
|
||||||
|
if sig == other_sig:
|
||||||
|
return j
|
||||||
|
return None
|
||||||
|
'''
|
||||||
|
text = text.replace(insert_methods_after, new_methods)
|
||||||
|
|
||||||
|
cleanup_old = ''' def cleanup(self, max_age_seconds: float = 3600):
|
||||||
|
now = time.time()
|
||||||
|
to_remove = [
|
||||||
|
jid for jid, j in self._jobs.items()
|
||||||
|
if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||||
|
and (now - j.created_at) > max_age_seconds
|
||||||
|
]
|
||||||
|
for jid in to_remove:
|
||||||
|
del self._jobs[jid]
|
||||||
|
if to_remove:
|
||||||
|
logger.info(f"Cleaned up {len(to_remove)} old jobs")
|
||||||
|
'''
|
||||||
|
cleanup_new = ''' def cleanup(self, max_age_seconds: float = 3600):
|
||||||
|
now = time.time()
|
||||||
|
terminal_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||||
|
to_remove = [
|
||||||
|
jid for jid, j in self._jobs.items()
|
||||||
|
if j.status in terminal_states and (now - j.created_at) > max_age_seconds
|
||||||
|
]
|
||||||
|
|
||||||
|
# Also cap retained terminal jobs to avoid unbounded memory growth
|
||||||
|
terminal_jobs = sorted(
|
||||||
|
[j for j in self._jobs.values() if j.status in terminal_states],
|
||||||
|
key=lambda j: j.created_at,
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
overflow = terminal_jobs[settings.JOB_QUEUE_RETAIN_COMPLETED :]
|
||||||
|
to_remove.extend([j.id for j in overflow])
|
||||||
|
|
||||||
|
removed = 0
|
||||||
|
for jid in set(to_remove):
|
||||||
|
if jid in self._jobs:
|
||||||
|
del self._jobs[jid]
|
||||||
|
removed += 1
|
||||||
|
if removed:
|
||||||
|
logger.info(f"Cleaned up {removed} old jobs")
|
||||||
|
|
||||||
|
async def _cleanup_loop(self):
|
||||||
|
interval = max(10, settings.JOB_QUEUE_CLEANUP_INTERVAL_SECONDS)
|
||||||
|
while self._started:
|
||||||
|
try:
|
||||||
|
self.cleanup(max_age_seconds=settings.JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Job queue cleanup loop error: {e}")
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
'''
|
||||||
|
text = text.replace(cleanup_old, cleanup_new)
|
||||||
|
|
||||||
|
jq.write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
# 5) NetworkMap polling backoff/jitter max wait
|
||||||
|
nm = root / "frontend/src/components/NetworkMap.tsx"
|
||||||
|
text = nm.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
text = text.replace(
|
||||||
|
" // Poll until ready, then re-fetch\n for (;;) {\n await new Promise(r => setTimeout(r, 2000));\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') break;\n }\n",
|
||||||
|
" // Poll until ready (exponential backoff), then re-fetch\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const jitter = Math.floor(Math.random() * 250);\n await new Promise(r => setTimeout(r, delayMs + jitter));\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') break;\n if (Date.now() - startedAt > 5 * 60 * 1000) {\n throw new Error('Host inventory build timed out after 5 minutes');\n }\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
text = text.replace(
|
||||||
|
" const waitUntilReady = async (): Promise<boolean> => {\n // Poll inventory-status every 2s until 'ready' (or cancelled)\n setProgress('Host inventory is being prepared in the background');\n setLoading(true);\n for (;;) {\n await new Promise(r => setTimeout(r, 2000));\n if (cancelled) return false;\n try {\n const st = await network.inventoryStatus(selectedHuntId);\n if (cancelled) return false;\n if (st.status === 'ready') return true;\n // still building or none (job may not have started yet) - keep polling\n } catch { if (cancelled) return false; }\n }\n };\n",
|
||||||
|
" const waitUntilReady = async (): Promise<boolean> => {\n // Poll inventory-status with exponential backoff until 'ready' (or cancelled)\n setProgress('Host inventory is being prepared in the background');\n setLoading(true);\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const jitter = Math.floor(Math.random() * 250);\n await new Promise(r => setTimeout(r, delayMs + jitter));\n if (cancelled) return false;\n try {\n const st = await network.inventoryStatus(selectedHuntId);\n if (cancelled) return false;\n if (st.status === 'ready') return true;\n if (Date.now() - startedAt > 5 * 60 * 1000) {\n setError('Host inventory build timed out. Please retry.');\n return false;\n }\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n // still building or none (job may not have started yet) - keep polling\n } catch {\n if (cancelled) return false;\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n }\n };\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
nm.write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
print("Patched: config.py, scanner.py, keywords.py, job_queue.py, NetworkMap.tsx")
|
||||||
207
_apply_phase2_patch.py
Normal file
207
_apply_phase2_patch.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
|
||||||
|
root = Path(r"d:\Projects\Dev\ThreatHunt")
|
||||||
|
|
||||||
|
# ---------- config.py ----------
|
||||||
|
cfg = root / "backend/app/config.py"
|
||||||
|
text = cfg.read_text(encoding="utf-8")
|
||||||
|
marker = " JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(\n default=3600, description=\"Age threshold for in-memory completed job cleanup\"\n )\n"
|
||||||
|
add = marker + "\n # -- Startup throttling ------------------------------------------------\n STARTUP_WARMUP_MAX_HUNTS: int = Field(\n default=5, description=\"Max hunts to warm inventory cache for at startup\"\n )\n STARTUP_REPROCESS_MAX_DATASETS: int = Field(\n default=25, description=\"Max unprocessed datasets to enqueue at startup\"\n )\n\n # -- Network API scale guards -----------------------------------------\n NETWORK_SUBGRAPH_MAX_HOSTS: int = Field(\n default=400, description=\"Hard cap for hosts returned by network subgraph endpoint\"\n )\n NETWORK_SUBGRAPH_MAX_EDGES: int = Field(\n default=3000, description=\"Hard cap for edges returned by network subgraph endpoint\"\n )\n"
|
||||||
|
if marker in text and "STARTUP_WARMUP_MAX_HUNTS" not in text:
|
||||||
|
text = text.replace(marker, add)
|
||||||
|
cfg.write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
# ---------- job_queue.py ----------
|
||||||
|
jq = root / "backend/app/services/job_queue.py"
|
||||||
|
text = jq.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
# add helper methods after get_stats
|
||||||
|
anchor = " def get_stats(self) -> dict:\n by_status = {}\n for j in self._jobs.values():\n by_status[j.status.value] = by_status.get(j.status.value, 0) + 1\n return {\n \"total\": len(self._jobs),\n \"queued\": self._queue.qsize(),\n \"by_status\": by_status,\n \"workers\": self._max_workers,\n \"active_workers\": sum(1 for j in self._jobs.values() if j.status == JobStatus.RUNNING),\n }\n"
|
||||||
|
if "def is_backlogged(" not in text:
|
||||||
|
insert = anchor + "\n def is_backlogged(self) -> bool:\n return self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG\n\n def can_accept(self, reserve: int = 0) -> bool:\n return (self._queue.qsize() + max(0, reserve)) < settings.JOB_QUEUE_MAX_BACKLOG\n"
|
||||||
|
text = text.replace(anchor, insert)
|
||||||
|
|
||||||
|
jq.write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
# ---------- host_inventory.py keyset pagination ----------
|
||||||
|
hi = root / "backend/app/services/host_inventory.py"
|
||||||
|
text = hi.read_text(encoding="utf-8")
|
||||||
|
old = ''' batch_size = 5000
|
||||||
|
offset = 0
|
||||||
|
while True:
|
||||||
|
rr = await db.execute(
|
||||||
|
select(DatasetRow)
|
||||||
|
.where(DatasetRow.dataset_id == ds.id)
|
||||||
|
.order_by(DatasetRow.row_index)
|
||||||
|
.offset(offset).limit(batch_size)
|
||||||
|
)
|
||||||
|
rows = rr.scalars().all()
|
||||||
|
if not rows:
|
||||||
|
break
|
||||||
|
'''
|
||||||
|
new = ''' batch_size = 5000
|
||||||
|
last_row_index = -1
|
||||||
|
while True:
|
||||||
|
rr = await db.execute(
|
||||||
|
select(DatasetRow)
|
||||||
|
.where(DatasetRow.dataset_id == ds.id)
|
||||||
|
.where(DatasetRow.row_index > last_row_index)
|
||||||
|
.order_by(DatasetRow.row_index)
|
||||||
|
.limit(batch_size)
|
||||||
|
)
|
||||||
|
rows = rr.scalars().all()
|
||||||
|
if not rows:
|
||||||
|
break
|
||||||
|
'''
|
||||||
|
if old in text:
|
||||||
|
text = text.replace(old, new)
|
||||||
|
text = text.replace(" offset += batch_size\n if len(rows) < batch_size:\n break\n", " last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n")
|
||||||
|
hi.write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
# ---------- network.py add summary/subgraph + backpressure ----------
|
||||||
|
net = root / "backend/app/api/routes/network.py"
|
||||||
|
text = net.read_text(encoding="utf-8")
|
||||||
|
text = text.replace("from fastapi import APIRouter, Depends, HTTPException, Query", "from fastapi import APIRouter, Depends, HTTPException, Query")
|
||||||
|
if "from app.config import settings" not in text:
|
||||||
|
text = text.replace("from app.db import get_db\n", "from app.config import settings\nfrom app.db import get_db\n")
|
||||||
|
|
||||||
|
# add helpers and endpoints before inventory-status endpoint
|
||||||
|
if "def _build_summary" not in text:
|
||||||
|
helper_block = '''
|
||||||
|
|
||||||
|
def _build_summary(inv: dict, top_n: int = 20) -> dict:
|
||||||
|
hosts = inv.get("hosts", [])
|
||||||
|
conns = inv.get("connections", [])
|
||||||
|
top_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:top_n]
|
||||||
|
top_edges = sorted(conns, key=lambda c: c.get("count", 0), reverse=True)[:top_n]
|
||||||
|
return {
|
||||||
|
"stats": inv.get("stats", {}),
|
||||||
|
"top_hosts": [
|
||||||
|
{
|
||||||
|
"id": h.get("id"),
|
||||||
|
"hostname": h.get("hostname"),
|
||||||
|
"row_count": h.get("row_count", 0),
|
||||||
|
"ip_count": len(h.get("ips", [])),
|
||||||
|
"user_count": len(h.get("users", [])),
|
||||||
|
}
|
||||||
|
for h in top_hosts
|
||||||
|
],
|
||||||
|
"top_edges": top_edges,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_subgraph(inv: dict, node_id: str | None, max_hosts: int, max_edges: int) -> dict:
|
||||||
|
hosts = inv.get("hosts", [])
|
||||||
|
conns = inv.get("connections", [])
|
||||||
|
|
||||||
|
max_hosts = max(1, min(max_hosts, settings.NETWORK_SUBGRAPH_MAX_HOSTS))
|
||||||
|
max_edges = max(1, min(max_edges, settings.NETWORK_SUBGRAPH_MAX_EDGES))
|
||||||
|
|
||||||
|
if node_id:
|
||||||
|
rel_edges = [c for c in conns if c.get("source") == node_id or c.get("target") == node_id]
|
||||||
|
rel_edges = sorted(rel_edges, key=lambda c: c.get("count", 0), reverse=True)[:max_edges]
|
||||||
|
ids = {node_id}
|
||||||
|
for c in rel_edges:
|
||||||
|
ids.add(c.get("source"))
|
||||||
|
ids.add(c.get("target"))
|
||||||
|
rel_hosts = [h for h in hosts if h.get("id") in ids][:max_hosts]
|
||||||
|
else:
|
||||||
|
rel_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:max_hosts]
|
||||||
|
allowed = {h.get("id") for h in rel_hosts}
|
||||||
|
rel_edges = [
|
||||||
|
c for c in sorted(conns, key=lambda c: c.get("count", 0), reverse=True)
|
||||||
|
if c.get("source") in allowed and c.get("target") in allowed
|
||||||
|
][:max_edges]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"hosts": rel_hosts,
|
||||||
|
"connections": rel_edges,
|
||||||
|
"stats": {
|
||||||
|
**inv.get("stats", {}),
|
||||||
|
"subgraph_hosts": len(rel_hosts),
|
||||||
|
"subgraph_connections": len(rel_edges),
|
||||||
|
"truncated": len(rel_hosts) < len(hosts) or len(rel_edges) < len(conns),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/summary")
|
||||||
|
async def get_inventory_summary(
|
||||||
|
hunt_id: str = Query(..., description="Hunt ID"),
|
||||||
|
top_n: int = Query(20, ge=1, le=200),
|
||||||
|
):
|
||||||
|
"""Return a lightweight summary view for large hunts."""
|
||||||
|
cached = inventory_cache.get(hunt_id)
|
||||||
|
if cached is None:
|
||||||
|
if not inventory_cache.is_building(hunt_id):
|
||||||
|
if job_queue.is_backlogged():
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=202,
|
||||||
|
content={"status": "deferred", "message": "Queue busy, retry shortly"},
|
||||||
|
)
|
||||||
|
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||||
|
return JSONResponse(status_code=202, content={"status": "building"})
|
||||||
|
return _build_summary(cached, top_n=top_n)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/subgraph")
|
||||||
|
async def get_inventory_subgraph(
|
||||||
|
hunt_id: str = Query(..., description="Hunt ID"),
|
||||||
|
node_id: str | None = Query(None, description="Optional focal node"),
|
||||||
|
max_hosts: int = Query(200, ge=1, le=5000),
|
||||||
|
max_edges: int = Query(1500, ge=1, le=20000),
|
||||||
|
):
|
||||||
|
"""Return a bounded subgraph for scale-safe rendering."""
|
||||||
|
cached = inventory_cache.get(hunt_id)
|
||||||
|
if cached is None:
|
||||||
|
if not inventory_cache.is_building(hunt_id):
|
||||||
|
if job_queue.is_backlogged():
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=202,
|
||||||
|
content={"status": "deferred", "message": "Queue busy, retry shortly"},
|
||||||
|
)
|
||||||
|
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||||
|
return JSONResponse(status_code=202, content={"status": "building"})
|
||||||
|
return _build_subgraph(cached, node_id=node_id, max_hosts=max_hosts, max_edges=max_edges)
|
||||||
|
'''
|
||||||
|
text = text.replace("\n\n@router.get(\"/inventory-status\")", helper_block + "\n\n@router.get(\"/inventory-status\")")
|
||||||
|
|
||||||
|
# add backpressure in host-inventory enqueue points
|
||||||
|
text = text.replace(
|
||||||
|
" if not inventory_cache.is_building(hunt_id):\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)",
|
||||||
|
" if not inventory_cache.is_building(hunt_id):\n if job_queue.is_backlogged():\n return JSONResponse(status_code=202, content={\"status\": \"deferred\", \"message\": \"Queue busy, retry shortly\"})\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)"
|
||||||
|
)
|
||||||
|
text = text.replace(
|
||||||
|
" if not inventory_cache.is_building(hunt_id):\n logger.info(f\"Cache miss for {hunt_id}, triggering background build\")\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)",
|
||||||
|
" if not inventory_cache.is_building(hunt_id):\n logger.info(f\"Cache miss for {hunt_id}, triggering background build\")\n if job_queue.is_backlogged():\n return JSONResponse(status_code=202, content={\"status\": \"deferred\", \"message\": \"Queue busy, retry shortly\"})\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)"
|
||||||
|
)
|
||||||
|
|
||||||
|
net.write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
# ---------- analysis.py backpressure on manual submit ----------
|
||||||
|
analysis = root / "backend/app/api/routes/analysis.py"
|
||||||
|
text = analysis.read_text(encoding="utf-8")
|
||||||
|
text = text.replace(
|
||||||
|
" job = job_queue.submit(jt, **params)\n return {\"job_id\": job.id, \"status\": job.status.value, \"job_type\": job_type}",
|
||||||
|
" if not job_queue.can_accept():\n raise HTTPException(status_code=429, detail=\"Job queue is busy. Retry shortly.\")\n job = job_queue.submit(jt, **params)\n return {\"job_id\": job.id, \"status\": job.status.value, \"job_type\": job_type}"
|
||||||
|
)
|
||||||
|
analysis.write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
# ---------- main.py startup throttles ----------
|
||||||
|
main = root / "backend/app/main.py"
|
||||||
|
text = main.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
text = text.replace(
|
||||||
|
" for hid in hunt_ids:\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)\n if hunt_ids:\n logger.info(f\"Queued host inventory warm-up for {len(hunt_ids)} hunts\")",
|
||||||
|
" warm_hunts = hunt_ids[: settings.STARTUP_WARMUP_MAX_HUNTS]\n for hid in warm_hunts:\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)\n if warm_hunts:\n logger.info(f\"Queued host inventory warm-up for {len(warm_hunts)} hunts (total hunts with data: {len(hunt_ids)})\")"
|
||||||
|
)
|
||||||
|
|
||||||
|
text = text.replace(
|
||||||
|
" if unprocessed_ids:\n for ds_id in unprocessed_ids:\n job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)\n job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)\n job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)\n job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)\n logger.info(f\"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets\")\n async with async_session_factory() as update_db:\n from sqlalchemy import update\n from app.db.models import Dataset\n await update_db.execute(\n update(Dataset)\n .where(Dataset.id.in_(unprocessed_ids))\n .values(processing_status=\"processing\")\n )\n await update_db.commit()",
|
||||||
|
" if unprocessed_ids:\n to_reprocess = unprocessed_ids[: settings.STARTUP_REPROCESS_MAX_DATASETS]\n for ds_id in to_reprocess:\n job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)\n job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)\n job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)\n job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)\n logger.info(f\"Queued processing pipeline for {len(to_reprocess)} datasets at startup (unprocessed total: {len(unprocessed_ids)})\")\n async with async_session_factory() as update_db:\n from sqlalchemy import update\n from app.db.models import Dataset\n await update_db.execute(\n update(Dataset)\n .where(Dataset.id.in_(to_reprocess))\n .values(processing_status=\"processing\")\n )\n await update_db.commit()"
|
||||||
|
)
|
||||||
|
|
||||||
|
main.write_text(text, encoding="utf-8")
|
||||||
|
|
||||||
|
print("Patched Phase 2 files")
|
||||||
75
_aup_add_dataset_scope_ui.py
Normal file
75
_aup_add_dataset_scope_ui.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
|
||||||
|
# default selection when hunt changes: first 3 datasets instead of all
|
||||||
|
old=''' datasets.list(0, 500, selectedHuntId).then(res => {
|
||||||
|
if (cancelled) return;
|
||||||
|
setDsList(res.datasets);
|
||||||
|
setSelectedDs(new Set(res.datasets.map(d => d.id)));
|
||||||
|
}).catch(() => {});
|
||||||
|
'''
|
||||||
|
new=''' datasets.list(0, 500, selectedHuntId).then(res => {
|
||||||
|
if (cancelled) return;
|
||||||
|
setDsList(res.datasets);
|
||||||
|
setSelectedDs(new Set(res.datasets.slice(0, 3).map(d => d.id)));
|
||||||
|
}).catch(() => {});
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('hunt-change dataset init block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
|
||||||
|
# insert dataset scope multi-select under hunt info
|
||||||
|
anchor=''' {!selectedHuntId && (
|
||||||
|
<Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
|
||||||
|
All datasets will be scanned if no hunt is selected
|
||||||
|
</Typography>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
|
||||||
|
{/* Theme selector */}
|
||||||
|
'''
|
||||||
|
insert=''' {!selectedHuntId && (
|
||||||
|
<Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
|
||||||
|
Select a hunt to enable scoped scanning
|
||||||
|
</Typography>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<FormControl size="small" fullWidth sx={{ mt: 1.2 }} disabled={!selectedHuntId || dsList.length === 0}>
|
||||||
|
<InputLabel id="aup-dataset-label">Datasets</InputLabel>
|
||||||
|
<Select
|
||||||
|
labelId="aup-dataset-label"
|
||||||
|
multiple
|
||||||
|
value={Array.from(selectedDs)}
|
||||||
|
label="Datasets"
|
||||||
|
renderValue={(selected) => `${(selected as string[]).length} selected`}
|
||||||
|
onChange={(e) => setSelectedDs(new Set(e.target.value as string[]))}
|
||||||
|
>
|
||||||
|
{dsList.map(d => (
|
||||||
|
<MenuItem key={d.id} value={d.id}>
|
||||||
|
<Checkbox size="small" checked={selectedDs.has(d.id)} />
|
||||||
|
<Typography variant="body2" sx={{ ml: 0.5 }}>
|
||||||
|
{d.name} ({d.row_count.toLocaleString()} rows)
|
||||||
|
</Typography>
|
||||||
|
</MenuItem>
|
||||||
|
))}
|
||||||
|
</Select>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
{selectedHuntId && dsList.length > 0 && (
|
||||||
|
<Stack direction="row" spacing={1} sx={{ mt: 1 }}>
|
||||||
|
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.slice(0, 3).map(d => d.id)))}>Top 3</Button>
|
||||||
|
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.map(d => d.id)))}>All</Button>
|
||||||
|
<Button size="small" onClick={() => setSelectedDs(new Set())}>Clear</Button>
|
||||||
|
</Stack>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
|
||||||
|
{/* Theme selector */}
|
||||||
|
'''
|
||||||
|
if anchor not in t:
|
||||||
|
raise SystemExit('dataset scope anchor not found')
|
||||||
|
t=t.replace(anchor,insert)
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('added AUP dataset multi-select scoping and safer defaults')
|
||||||
182
_aup_add_host_user_to_hits.py
Normal file
182
_aup_add_host_user_to_hits.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
|
||||||
|
# 1) Extend ScanHit dataclass
|
||||||
|
old='''@dataclass
|
||||||
|
class ScanHit:
|
||||||
|
theme_name: str
|
||||||
|
theme_color: str
|
||||||
|
keyword: str
|
||||||
|
source_type: str # dataset_row | hunt | annotation | message
|
||||||
|
source_id: str | int
|
||||||
|
field: str
|
||||||
|
matched_value: str
|
||||||
|
row_index: int | None = None
|
||||||
|
dataset_name: str | None = None
|
||||||
|
'''
|
||||||
|
new='''@dataclass
|
||||||
|
class ScanHit:
|
||||||
|
theme_name: str
|
||||||
|
theme_color: str
|
||||||
|
keyword: str
|
||||||
|
source_type: str # dataset_row | hunt | annotation | message
|
||||||
|
source_id: str | int
|
||||||
|
field: str
|
||||||
|
matched_value: str
|
||||||
|
row_index: int | None = None
|
||||||
|
dataset_name: str | None = None
|
||||||
|
hostname: str | None = None
|
||||||
|
username: str | None = None
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('ScanHit dataclass block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
|
||||||
|
# 2) Add helper to infer hostname/user from a row
|
||||||
|
insert_after='''BATCH_SIZE = 200
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScanHit:
|
||||||
|
'''
|
||||||
|
helper='''BATCH_SIZE = 200
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_hostname_and_user(data: dict) -> tuple[str | None, str | None]:
|
||||||
|
"""Best-effort extraction of hostname and user from a dataset row."""
|
||||||
|
if not data:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
host_keys = (
|
||||||
|
'hostname', 'host_name', 'host', 'computer_name', 'computer',
|
||||||
|
'fqdn', 'client_id', 'agent_id', 'endpoint_id',
|
||||||
|
)
|
||||||
|
user_keys = (
|
||||||
|
'username', 'user_name', 'user', 'account_name',
|
||||||
|
'logged_in_user', 'samaccountname', 'sam_account_name',
|
||||||
|
)
|
||||||
|
|
||||||
|
def pick(keys):
|
||||||
|
for k in keys:
|
||||||
|
for actual_key, v in data.items():
|
||||||
|
if actual_key.lower() == k and v not in (None, ''):
|
||||||
|
return str(v)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return pick(host_keys), pick(user_keys)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScanHit:
|
||||||
|
'''
|
||||||
|
if insert_after in t and '_infer_hostname_and_user' not in t:
|
||||||
|
t=t.replace(insert_after,helper)
|
||||||
|
|
||||||
|
# 3) Extend _match_text signature and ScanHit construction
|
||||||
|
old_sig=''' def _match_text(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
patterns: dict,
|
||||||
|
source_type: str,
|
||||||
|
source_id: str | int,
|
||||||
|
field_name: str,
|
||||||
|
hits: list[ScanHit],
|
||||||
|
row_index: int | None = None,
|
||||||
|
dataset_name: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
'''
|
||||||
|
new_sig=''' def _match_text(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
patterns: dict,
|
||||||
|
source_type: str,
|
||||||
|
source_id: str | int,
|
||||||
|
field_name: str,
|
||||||
|
hits: list[ScanHit],
|
||||||
|
row_index: int | None = None,
|
||||||
|
dataset_name: str | None = None,
|
||||||
|
hostname: str | None = None,
|
||||||
|
username: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
'''
|
||||||
|
if old_sig not in t:
|
||||||
|
raise SystemExit('_match_text signature not found')
|
||||||
|
t=t.replace(old_sig,new_sig)
|
||||||
|
|
||||||
|
old_hit=''' hits.append(ScanHit(
|
||||||
|
theme_name=theme_name,
|
||||||
|
theme_color=theme_color,
|
||||||
|
keyword=kw_value,
|
||||||
|
source_type=source_type,
|
||||||
|
source_id=source_id,
|
||||||
|
field=field_name,
|
||||||
|
matched_value=matched_preview,
|
||||||
|
row_index=row_index,
|
||||||
|
dataset_name=dataset_name,
|
||||||
|
))
|
||||||
|
'''
|
||||||
|
new_hit=''' hits.append(ScanHit(
|
||||||
|
theme_name=theme_name,
|
||||||
|
theme_color=theme_color,
|
||||||
|
keyword=kw_value,
|
||||||
|
source_type=source_type,
|
||||||
|
source_id=source_id,
|
||||||
|
field=field_name,
|
||||||
|
matched_value=matched_preview,
|
||||||
|
row_index=row_index,
|
||||||
|
dataset_name=dataset_name,
|
||||||
|
hostname=hostname,
|
||||||
|
username=username,
|
||||||
|
))
|
||||||
|
'''
|
||||||
|
if old_hit not in t:
|
||||||
|
raise SystemExit('ScanHit append block not found')
|
||||||
|
t=t.replace(old_hit,new_hit)
|
||||||
|
|
||||||
|
# 4) Pass inferred hostname/username in dataset scan path
|
||||||
|
old_call=''' for row in rows:
|
||||||
|
result.rows_scanned += 1
|
||||||
|
data = row.data or {}
|
||||||
|
for col_name, cell_value in data.items():
|
||||||
|
if cell_value is None:
|
||||||
|
continue
|
||||||
|
text = str(cell_value)
|
||||||
|
self._match_text(
|
||||||
|
text,
|
||||||
|
patterns,
|
||||||
|
"dataset_row",
|
||||||
|
row.id,
|
||||||
|
col_name,
|
||||||
|
result.hits,
|
||||||
|
row_index=row.row_index,
|
||||||
|
dataset_name=ds_name,
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
new_call=''' for row in rows:
|
||||||
|
result.rows_scanned += 1
|
||||||
|
data = row.data or {}
|
||||||
|
hostname, username = _infer_hostname_and_user(data)
|
||||||
|
for col_name, cell_value in data.items():
|
||||||
|
if cell_value is None:
|
||||||
|
continue
|
||||||
|
text = str(cell_value)
|
||||||
|
self._match_text(
|
||||||
|
text,
|
||||||
|
patterns,
|
||||||
|
"dataset_row",
|
||||||
|
row.id,
|
||||||
|
col_name,
|
||||||
|
result.hits,
|
||||||
|
row_index=row.row_index,
|
||||||
|
dataset_name=ds_name,
|
||||||
|
hostname=hostname,
|
||||||
|
username=username,
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
if old_call not in t:
|
||||||
|
raise SystemExit('dataset _match_text call block not found')
|
||||||
|
t=t.replace(old_call,new_call)
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('updated scanner hits with hostname+username context')
|
||||||
32
_aup_extend_scanhit_api.py
Normal file
32
_aup_extend_scanhit_api.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
old='''class ScanHit(BaseModel):
|
||||||
|
theme_name: str
|
||||||
|
theme_color: str
|
||||||
|
keyword: str
|
||||||
|
source_type: str
|
||||||
|
source_id: str | int
|
||||||
|
field: str
|
||||||
|
matched_value: str
|
||||||
|
row_index: int | None = None
|
||||||
|
dataset_name: str | None = None
|
||||||
|
'''
|
||||||
|
new='''class ScanHit(BaseModel):
|
||||||
|
theme_name: str
|
||||||
|
theme_color: str
|
||||||
|
keyword: str
|
||||||
|
source_type: str
|
||||||
|
source_id: str | int
|
||||||
|
field: str
|
||||||
|
matched_value: str
|
||||||
|
row_index: int | None = None
|
||||||
|
dataset_name: str | None = None
|
||||||
|
hostname: str | None = None
|
||||||
|
username: str | None = None
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('ScanHit pydantic model block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('extended API ScanHit model with hostname+username')
|
||||||
21
_aup_extend_scanhit_frontend_type.py
Normal file
21
_aup_extend_scanhit_frontend_type.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/api/client.ts')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
old='''export interface ScanHit {
|
||||||
|
theme_name: string; theme_color: string; keyword: string;
|
||||||
|
source_type: string; source_id: string | number; field: string;
|
||||||
|
matched_value: string; row_index: number | null; dataset_name: string | null;
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
new='''export interface ScanHit {
|
||||||
|
theme_name: string; theme_color: string; keyword: string;
|
||||||
|
source_type: string; source_id: string | number; field: string;
|
||||||
|
matched_value: string; row_index: number | null; dataset_name: string | null;
|
||||||
|
hostname?: string | null; username?: string | null;
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('frontend ScanHit interface block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('extended frontend ScanHit type with hostname+username')
|
||||||
57
_aup_keywords_scope_and_missing.py
Normal file
57
_aup_keywords_scope_and_missing.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
|
||||||
|
# add fast guard against unscoped global dataset scans
|
||||||
|
insert_after='''async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):\n scanner = KeywordScanner(db)\n\n'''
|
||||||
|
if insert_after not in t:
|
||||||
|
raise SystemExit('run_scan header block not found')
|
||||||
|
if 'Select at least one dataset' not in t:
|
||||||
|
guard=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:\n raise HTTPException(400, "Select at least one dataset or enable additional sources (hunts/annotations/messages)")\n\n'''
|
||||||
|
t=t.replace(insert_after, insert_after+guard)
|
||||||
|
|
||||||
|
old=''' if missing:
|
||||||
|
missing_entries: list[dict] = []
|
||||||
|
for dataset_id in missing:
|
||||||
|
partial = await scanner.scan(dataset_ids=[dataset_id], theme_ids=body.theme_ids)
|
||||||
|
keyword_scan_cache.put(dataset_id, partial)
|
||||||
|
missing_entries.append({"result": partial, "built_at": None})
|
||||||
|
|
||||||
|
merged = _merge_cached_results(
|
||||||
|
cached_entries + missing_entries,
|
||||||
|
allowed_theme_names if body.theme_ids else None,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"total_hits": merged["total_hits"],
|
||||||
|
"hits": merged["hits"],
|
||||||
|
"themes_scanned": len(themes),
|
||||||
|
"keywords_scanned": keywords_scanned,
|
||||||
|
"rows_scanned": merged["rows_scanned"],
|
||||||
|
"cache_used": len(cached_entries) > 0,
|
||||||
|
"cache_status": "partial" if cached_entries else "miss",
|
||||||
|
"cached_at": merged["cached_at"],
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
new=''' if missing:
|
||||||
|
partial = await scanner.scan(dataset_ids=missing, theme_ids=body.theme_ids)
|
||||||
|
merged = _merge_cached_results(
|
||||||
|
cached_entries + [{"result": partial, "built_at": None}],
|
||||||
|
allowed_theme_names if body.theme_ids else None,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"total_hits": merged["total_hits"],
|
||||||
|
"hits": merged["hits"],
|
||||||
|
"themes_scanned": len(themes),
|
||||||
|
"keywords_scanned": keywords_scanned,
|
||||||
|
"rows_scanned": merged["rows_scanned"],
|
||||||
|
"cache_used": len(cached_entries) > 0,
|
||||||
|
"cache_status": "partial" if cached_entries else "miss",
|
||||||
|
"cached_at": merged["cached_at"],
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('partial-cache missing block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('hardened keywords scan scope + optimized missing-cache path')
|
||||||
18
_aup_reduce_budget.py
Normal file
18
_aup_reduce_budget.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
old=''' SCANNER_MAX_ROWS_PER_SCAN: int = Field(
|
||||||
|
default=300000,
|
||||||
|
description="Global row budget for a single AUP scan request (0 = unlimited)",
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
new=''' SCANNER_MAX_ROWS_PER_SCAN: int = Field(
|
||||||
|
default=120000,
|
||||||
|
description="Global row budget for a single AUP scan request (0 = unlimited)",
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('SCANNER_MAX_ROWS_PER_SCAN block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('reduced SCANNER_MAX_ROWS_PER_SCAN default to 120000')
|
||||||
42
_aup_update_grid_columns.py
Normal file
42
_aup_update_grid_columns.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
old='''const RESULT_COLUMNS: GridColDef[] = [
|
||||||
|
{
|
||||||
|
field: 'theme_name', headerName: 'Theme', width: 140,
|
||||||
|
renderCell: (params) => (
|
||||||
|
<Chip label={params.value} size="small"
|
||||||
|
sx={{ bgcolor: params.row.theme_color, color: '#fff', fontWeight: 600 }} />
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{ field: 'keyword', headerName: 'Keyword', width: 140 },
|
||||||
|
{ field: 'source_type', headerName: 'Source', width: 120 },
|
||||||
|
{ field: 'dataset_name', headerName: 'Dataset', width: 150 },
|
||||||
|
{ field: 'field', headerName: 'Field', width: 130 },
|
||||||
|
{ field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 200 },
|
||||||
|
{ field: 'row_index', headerName: 'Row #', width: 80, type: 'number' },
|
||||||
|
];
|
||||||
|
'''
|
||||||
|
new='''const RESULT_COLUMNS: GridColDef[] = [
|
||||||
|
{
|
||||||
|
field: 'theme_name', headerName: 'Theme', width: 140,
|
||||||
|
renderCell: (params) => (
|
||||||
|
<Chip label={params.value} size="small"
|
||||||
|
sx={{ bgcolor: params.row.theme_color, color: '#fff', fontWeight: 600 }} />
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{ field: 'keyword', headerName: 'Keyword', width: 140 },
|
||||||
|
{ field: 'dataset_name', headerName: 'Dataset', width: 170 },
|
||||||
|
{ field: 'hostname', headerName: 'Hostname', width: 170, valueGetter: (v, row) => row.hostname || '' },
|
||||||
|
{ field: 'username', headerName: 'User', width: 160, valueGetter: (v, row) => row.username || '' },
|
||||||
|
{ field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 220 },
|
||||||
|
{ field: 'field', headerName: 'Field', width: 130 },
|
||||||
|
{ field: 'source_type', headerName: 'Source', width: 120 },
|
||||||
|
{ field: 'row_index', headerName: 'Row #', width: 90, type: 'number' },
|
||||||
|
];
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('RESULT_COLUMNS block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('updated AUP results grid columns with dataset/hostname/user/matched value focus')
|
||||||
40
_edit_aup.py
Normal file
40
_edit_aup.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""Patch AUPScanner.tsx: default the scan-scope toggles to off, request
cached scan results (prefer_cache), and render a Cached/Live chip in the
scan-summary alert.

One-shot codemod script.  Unlike the first version, every replacement is
applied with an explicit presence check, so a stale anchor is reported
instead of silently doing nothing.
"""
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
t = p.read_text(encoding='utf-8')

# (old, new) anchor pairs: flip the three scan-scope toggles' default state
# from on to off, and ask the backend to reuse cached scan results.
replacements = [
    ("  const [scanHunts, setScanHunts] = useState(true);",
     "  const [scanHunts, setScanHunts] = useState(false);"),
    ("  const [scanAnnotations, setScanAnnotations] = useState(true);",
     "  const [scanAnnotations, setScanAnnotations] = useState(false);"),
    ("  const [scanMessages, setScanMessages] = useState(true);",
     "  const [scanMessages, setScanMessages] = useState(false);"),
    ("    scan_messages: scanMessages,\n  });",
     "    scan_messages: scanMessages,\n    prefer_cache: true,\n  });"),
]
for old_text, new_text in replacements:
    if old_text in t:
        t = t.replace(old_text, new_text)
    else:
        # Report misses instead of silently skipping (anchor may have drifted).
        print(f'warning: anchor not found: {old_text[:60]!r}')

# add cache chip in summary alert
old = '''        {scanResult && (
          <Alert severity={scanResult.total_hits > 0 ? 'warning' : 'success'} sx={{ py: 0.5 }}>
            <strong>{scanResult.total_hits}</strong> hits across{' '}
            <strong>{scanResult.rows_scanned}</strong> rows |{' '}
            {scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned
          </Alert>
        )}
'''
new = '''        {scanResult && (
          <Alert severity={scanResult.total_hits > 0 ? 'warning' : 'success'} sx={{ py: 0.5 }}>
            <strong>{scanResult.total_hits}</strong> hits across{' '}
            <strong>{scanResult.rows_scanned}</strong> rows |{' '}
            {scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned
            {scanResult.cache_status && (
              <Chip
                size="small"
                label={scanResult.cache_status === 'hit' ? 'Cached' : 'Live'}
                sx={{ ml: 1, height: 20 }}
                color={scanResult.cache_status === 'hit' ? 'success' : 'default'}
                variant="outlined"
              />
            )}
          </Alert>
        )}
'''
if old in t:
    t = t.replace(old, new)
else:
    print('warning: summary block not replaced')

p.write_text(t, encoding='utf-8')
print('updated AUPScanner.tsx')
|
||||||
36
_edit_client.py
Normal file
36
_edit_client.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""Patch frontend/src/api/client.ts: add the HuntProgress type and the
hunts.progress() endpoint, and extend ScanResponse / keywords.scan with
cache-related fields.

One-shot codemod.  Each edit is idempotent (guarded by a "already present"
check) and every replacement verifies its anchor, warning on a miss instead
of silently no-opping.  The unused `re` import of the first version has
been dropped.
"""
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/api/client.ts')
t = p.read_text(encoding='utf-8')


def replace_checked(text: str, old: str, new: str, label: str) -> str:
    """Replace old->new in text, warning (not silently skipping) on a miss."""
    if old not in text:
        print(f'warning: anchor for {label} not found')
        return text
    return text.replace(old, new)


# Add HuntProgress interface after Hunt interface
if 'export interface HuntProgress' not in t:
    hunt_anchor = 'export interface Hunt {\n  id: string; name: string; description: string | null; status: string;\n  owner_id: string | null; created_at: string; updated_at: string;\n  dataset_count: number; hypothesis_count: number;\n}\n\n'
    insert = '''export interface HuntProgress {
  hunt_id: string;
  status: 'idle' | 'processing' | 'ready';
  progress_percent: number;
  dataset_total: number;
  dataset_completed: number;
  dataset_processing: number;
  dataset_errors: number;
  active_jobs: number;
  queued_jobs: number;
  network_status: 'none' | 'building' | 'ready';
  stages: Record<string, any>;
}

'''
    t = replace_checked(t, hunt_anchor, hunt_anchor + insert, 'HuntProgress insert')

# Add hunts.progress method
if 'progress: (id: string)' not in t:
    t = replace_checked(
        t,
        "  delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),\n};",
        "  delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),\n  progress: (id: string) => api<HuntProgress>(`/api/hunts/${id}/progress`),\n};",
        'hunts.progress method',
    )

# Extend ScanResponse
if 'cache_used?: boolean' not in t:
    t = replace_checked(
        t,
        'export interface ScanResponse {\n  total_hits: number; hits: ScanHit[]; themes_scanned: number;\n  keywords_scanned: number; rows_scanned: number;\n}\n',
        'export interface ScanResponse {\n  total_hits: number; hits: ScanHit[]; themes_scanned: number;\n  keywords_scanned: number; rows_scanned: number;\n  cache_used?: boolean; cache_status?: string; cached_at?: string | null;\n}\n',
        'ScanResponse cache fields',
    )

# Extend keywords.scan opts
t = replace_checked(
    t,
    '    scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;\n  }) =>',
    '    scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;\n    prefer_cache?: boolean; force_rescan?: boolean;\n  }) =>',
    'keywords.scan cache options',
)

p.write_text(t, encoding='utf-8')
print('updated client.ts')
|
||||||
20
_edit_config_reconcile.py
Normal file
20
_edit_config_reconcile.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""Insert the STARTUP_RECONCILE_STALE_TASKS setting into backend config.py,
directly after the existing STARTUP_REPROCESS_MAX_DATASETS field."""
from pathlib import Path

config_path = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
source = config_path.read_text(encoding='utf-8')

# Existing field used as the insertion anchor.
old_block = '''    STARTUP_REPROCESS_MAX_DATASETS: int = Field(
        default=25, description="Max unprocessed datasets to enqueue at startup"
    )
'''
# New setting appended immediately after the anchor.
addition = '''    STARTUP_RECONCILE_STALE_TASKS: bool = Field(
        default=True,
        description="Mark stale queued/running processing tasks as failed on startup",
    )
'''

if old_block not in source:
    raise SystemExit('startup anchor not found')

config_path.write_text(source.replace(old_block, old_block + addition), encoding='utf-8')
print('updated config with STARTUP_RECONCILE_STALE_TASKS')
|
||||||
39
_edit_datasets.py
Normal file
39
_edit_datasets.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""Patch backend datasets route: invalidate the keyword-scan cache whenever
a dataset is deleted, so stale AUP scan results are not served afterwards."""
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py')
t=p.read_text(encoding='utf-8')
# Add the scanner-cache import once, anchored on the existing
# inventory_cache import line (guard keeps the script idempotent).
if 'from app.services.scanner import keyword_scan_cache' not in t:
    t=t.replace('from app.services.host_inventory import inventory_cache','from app.services.host_inventory import inventory_cache\nfrom app.services.scanner import keyword_scan_cache')
# Exact current text of the delete_dataset endpoint, used as an anchor.
old='''@router.delete(
    "/{dataset_id}",
    summary="Delete a dataset",
)
async def delete_dataset(
    dataset_id: str,
    db: AsyncSession = Depends(get_db),
):
    repo = DatasetRepository(db)
    deleted = await repo.delete_dataset(dataset_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Dataset not found")
    return {"message": "Dataset deleted", "id": dataset_id}
'''
# Replacement: identical except for the cache invalidation before returning.
new='''@router.delete(
    "/{dataset_id}",
    summary="Delete a dataset",
)
async def delete_dataset(
    dataset_id: str,
    db: AsyncSession = Depends(get_db),
):
    repo = DatasetRepository(db)
    deleted = await repo.delete_dataset(dataset_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Dataset not found")
    keyword_scan_cache.invalidate_dataset(dataset_id)
    return {"message": "Dataset deleted", "id": dataset_id}
'''
# Fail loudly if the route source drifted and the anchor no longer matches.
if old not in t:
    raise SystemExit('delete block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated datasets.py')
|
||||||
110
_edit_datasets_tasks.py
Normal file
110
_edit_datasets_tasks.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""Patch backend datasets route: record a persistent ProcessingTask row for
every background job queued when a dataset is uploaded.

One-shot codemod.  Fixes the previous import-splicing approach, which
commented out and then re-replaced the existing
`from app.db.models import ...` line; that second replace dropped the tail
of the original import and left invalid Python behind.  The new import is
now inserted as its own line above the existing one, and the accidental
no-op `t.replace(x, x)` has been removed.
"""
from pathlib import Path

p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py')
t=p.read_text(encoding='utf-8')

# Bring ProcessingTask into scope without disturbing the existing models
# import: add a dedicated import line immediately above it (first
# occurrence only).
if 'ProcessingTask' not in t:
    t=t.replace(
        'from app.db.models import',
        'from app.db.models import ProcessingTask\nfrom app.db.models import',
        1,
    )

# Exact current text of the post-upload job-queueing block.
old='''    # 1. AI Triage (chains to HOST_PROFILE automatically on completion)
    job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
    jobs_queued.append("triage")

    # 2. Anomaly detection (embedding-based outlier detection)
    job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
    jobs_queued.append("anomaly")

    # 3. AUP keyword scan
    job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
    jobs_queued.append("keyword_scan")

    # 4. IOC extraction
    job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
    jobs_queued.append("ioc_extract")

    # 5. Host inventory (network map) - requires hunt_id
    if hunt_id:
        inventory_cache.invalidate(hunt_id)
        job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
        jobs_queued.append("host_inventory")
'''
# Replacement: same jobs, but each submission also records a ProcessingTask
# row (linked by job_id) so progress survives restarts.
new='''    task_rows: list[ProcessingTask] = []

    # 1. AI Triage (chains to HOST_PROFILE automatically on completion)
    triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
    jobs_queued.append("triage")
    task_rows.append(ProcessingTask(
        hunt_id=hunt_id,
        dataset_id=dataset.id,
        job_id=triage_job.id,
        stage="triage",
        status="queued",
        progress=0.0,
        message="Queued",
    ))

    # 2. Anomaly detection (embedding-based outlier detection)
    anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
    jobs_queued.append("anomaly")
    task_rows.append(ProcessingTask(
        hunt_id=hunt_id,
        dataset_id=dataset.id,
        job_id=anomaly_job.id,
        stage="anomaly",
        status="queued",
        progress=0.0,
        message="Queued",
    ))

    # 3. AUP keyword scan
    kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
    jobs_queued.append("keyword_scan")
    task_rows.append(ProcessingTask(
        hunt_id=hunt_id,
        dataset_id=dataset.id,
        job_id=kw_job.id,
        stage="keyword_scan",
        status="queued",
        progress=0.0,
        message="Queued",
    ))

    # 4. IOC extraction
    ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
    jobs_queued.append("ioc_extract")
    task_rows.append(ProcessingTask(
        hunt_id=hunt_id,
        dataset_id=dataset.id,
        job_id=ioc_job.id,
        stage="ioc_extract",
        status="queued",
        progress=0.0,
        message="Queued",
    ))

    # 5. Host inventory (network map) - requires hunt_id
    if hunt_id:
        inventory_cache.invalidate(hunt_id)
        inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
        jobs_queued.append("host_inventory")
        task_rows.append(ProcessingTask(
            hunt_id=hunt_id,
            dataset_id=dataset.id,
            job_id=inv_job.id,
            stage="host_inventory",
            status="queued",
            progress=0.0,
            message="Queued",
        ))

    if task_rows:
        db.add_all(task_rows)
        await db.flush()
'''
# Fail loudly if the queueing block drifted and the anchor no longer matches.
if old not in t:
    raise SystemExit('queue block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated datasets upload queue + processing tasks')
|
||||||
254
_edit_hunts.py
Normal file
254
_edit_hunts.py
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
"""Overwrite backend/app/api/routes/hunts.py with a rebuilt version that adds
the GET /api/hunts/{hunt_id}/progress endpoint alongside the existing CRUD
routes.  The entire replacement file lives in the `new` string below and is
written verbatim (this script does not merge — it replaces the whole file)."""
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/hunts.py')
# Complete replacement contents for hunts.py.
new='''"""API routes for hunt management."""

import logging

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import get_db
from app.db.models import Hunt, Dataset
from app.services.job_queue import job_queue
from app.services.host_inventory import inventory_cache

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/hunts", tags=["hunts"])


class HuntCreate(BaseModel):
    name: str = Field(..., max_length=256)
    description: str | None = None


class HuntUpdate(BaseModel):
    name: str | None = None
    description: str | None = None
    status: str | None = None


class HuntResponse(BaseModel):
    id: str
    name: str
    description: str | None
    status: str
    owner_id: str | None
    created_at: str
    updated_at: str
    dataset_count: int = 0
    hypothesis_count: int = 0


class HuntListResponse(BaseModel):
    hunts: list[HuntResponse]
    total: int


class HuntProgressResponse(BaseModel):
    hunt_id: str
    status: str
    progress_percent: float
    dataset_total: int
    dataset_completed: int
    dataset_processing: int
    dataset_errors: int
    active_jobs: int
    queued_jobs: int
    network_status: str
    stages: dict


@router.post("", response_model=HuntResponse, summary="Create a new hunt")
async def create_hunt(body: HuntCreate, db: AsyncSession = Depends(get_db)):
    hunt = Hunt(name=body.name, description=body.description)
    db.add(hunt)
    await db.flush()
    return HuntResponse(
        id=hunt.id,
        name=hunt.name,
        description=hunt.description,
        status=hunt.status,
        owner_id=hunt.owner_id,
        created_at=hunt.created_at.isoformat(),
        updated_at=hunt.updated_at.isoformat(),
    )


@router.get("", response_model=HuntListResponse, summary="List hunts")
async def list_hunts(
    status: str | None = Query(None),
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    stmt = select(Hunt).order_by(Hunt.updated_at.desc())
    if status:
        stmt = stmt.where(Hunt.status == status)
    stmt = stmt.limit(limit).offset(offset)
    result = await db.execute(stmt)
    hunts = result.scalars().all()

    count_stmt = select(func.count(Hunt.id))
    if status:
        count_stmt = count_stmt.where(Hunt.status == status)
    total = (await db.execute(count_stmt)).scalar_one()

    return HuntListResponse(
        hunts=[
            HuntResponse(
                id=h.id,
                name=h.name,
                description=h.description,
                status=h.status,
                owner_id=h.owner_id,
                created_at=h.created_at.isoformat(),
                updated_at=h.updated_at.isoformat(),
                dataset_count=len(h.datasets) if h.datasets else 0,
                hypothesis_count=len(h.hypotheses) if h.hypotheses else 0,
            )
            for h in hunts
        ],
        total=total,
    )


@router.get("/{hunt_id}", response_model=HuntResponse, summary="Get hunt details")
async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
    result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    hunt = result.scalar_one_or_none()
    if not hunt:
        raise HTTPException(status_code=404, detail="Hunt not found")
    return HuntResponse(
        id=hunt.id,
        name=hunt.name,
        description=hunt.description,
        status=hunt.status,
        owner_id=hunt.owner_id,
        created_at=hunt.created_at.isoformat(),
        updated_at=hunt.updated_at.isoformat(),
        dataset_count=len(hunt.datasets) if hunt.datasets else 0,
        hypothesis_count=len(hunt.hypotheses) if hunt.hypotheses else 0,
    )


@router.get("/{hunt_id}/progress", response_model=HuntProgressResponse, summary="Get hunt processing progress")
async def get_hunt_progress(hunt_id: str, db: AsyncSession = Depends(get_db)):
    hunt = await db.get(Hunt, hunt_id)
    if not hunt:
        raise HTTPException(status_code=404, detail="Hunt not found")

    ds_rows = await db.execute(
        select(Dataset.id, Dataset.processing_status)
        .where(Dataset.hunt_id == hunt_id)
    )
    datasets = ds_rows.all()
    dataset_ids = {row[0] for row in datasets}

    dataset_total = len(datasets)
    dataset_completed = sum(1 for _, st in datasets if st == "completed")
    dataset_errors = sum(1 for _, st in datasets if st == "completed_with_errors")
    dataset_processing = max(0, dataset_total - dataset_completed - dataset_errors)

    jobs = job_queue.list_jobs(limit=5000)
    relevant_jobs = [
        j for j in jobs
        if j.get("params", {}).get("hunt_id") == hunt_id
        or j.get("params", {}).get("dataset_id") in dataset_ids
    ]
    active_jobs = sum(1 for j in relevant_jobs if j.get("status") == "running")
    queued_jobs = sum(1 for j in relevant_jobs if j.get("status") == "queued")

    if inventory_cache.get(hunt_id) is not None:
        network_status = "ready"
        network_ratio = 1.0
    elif inventory_cache.is_building(hunt_id):
        network_status = "building"
        network_ratio = 0.5
    else:
        network_status = "none"
        network_ratio = 0.0

    dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
    overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
    progress_percent = round(overall_ratio * 100.0, 1)

    status = "ready"
    if dataset_total == 0:
        status = "idle"
    elif progress_percent < 100:
        status = "processing"

    stages = {
        "datasets": {
            "total": dataset_total,
            "completed": dataset_completed,
            "processing": dataset_processing,
            "errors": dataset_errors,
            "percent": round(dataset_ratio * 100.0, 1),
        },
        "network": {
            "status": network_status,
            "percent": round(network_ratio * 100.0, 1),
        },
        "jobs": {
            "active": active_jobs,
            "queued": queued_jobs,
            "total_seen": len(relevant_jobs),
        },
    }

    return HuntProgressResponse(
        hunt_id=hunt_id,
        status=status,
        progress_percent=progress_percent,
        dataset_total=dataset_total,
        dataset_completed=dataset_completed,
        dataset_processing=dataset_processing,
        dataset_errors=dataset_errors,
        active_jobs=active_jobs,
        queued_jobs=queued_jobs,
        network_status=network_status,
        stages=stages,
    )


@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
async def update_hunt(
    hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)
):
    result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    hunt = result.scalar_one_or_none()
    if not hunt:
        raise HTTPException(status_code=404, detail="Hunt not found")
    if body.name is not None:
        hunt.name = body.name
    if body.description is not None:
        hunt.description = body.description
    if body.status is not None:
        hunt.status = body.status
    await db.flush()
    return HuntResponse(
        id=hunt.id,
        name=hunt.name,
        description=hunt.description,
        status=hunt.status,
        owner_id=hunt.owner_id,
        created_at=hunt.created_at.isoformat(),
        updated_at=hunt.updated_at.isoformat(),
    )


@router.delete("/{hunt_id}", summary="Delete a hunt")
async def delete_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
    result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    hunt = result.scalar_one_or_none()
    if not hunt:
        raise HTTPException(status_code=404, detail="Hunt not found")
    await db.delete(hunt)
    return {"message": "Hunt deleted", "id": hunt_id}
'''
# Write the whole file in one shot (no anchor matching needed).
p.write_text(new,encoding='utf-8')
print('updated hunts.py')
|
||||||
102
_edit_hunts_progress_tasks.py
Normal file
102
_edit_hunts_progress_tasks.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""Patch the hunt-progress endpoint in hunts.py so it merges persistent
ProcessingTask rows with the in-memory job queue: task counts feed the
progress ratio, active/queued counts take the max of memory vs DB, and a
per-stage rollup is added to the stages payload."""
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/hunts.py')
t=p.read_text(encoding='utf-8')
# Extend the models import with ProcessingTask (idempotent guard).
if 'ProcessingTask' not in t:
    t=t.replace('from app.db.models import Hunt, Dataset','from app.db.models import Hunt, Dataset, ProcessingTask')

# Anchor 1: the in-memory job tally inside get_hunt_progress.
old='''    jobs = job_queue.list_jobs(limit=5000)
    relevant_jobs = [
        j for j in jobs
        if j.get("params", {}).get("hunt_id") == hunt_id
        or j.get("params", {}).get("dataset_id") in dataset_ids
    ]
    active_jobs = sum(1 for j in relevant_jobs if j.get("status") == "running")
    queued_jobs = sum(1 for j in relevant_jobs if j.get("status") == "queued")

    if inventory_cache.get(hunt_id) is not None:
'''
# Replacement 1: also read ProcessingTask rows, roll them up per stage, and
# merge DB counts with the in-memory ones.
new='''    jobs = job_queue.list_jobs(limit=5000)
    relevant_jobs = [
        j for j in jobs
        if j.get("params", {}).get("hunt_id") == hunt_id
        or j.get("params", {}).get("dataset_id") in dataset_ids
    ]
    active_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "running")
    queued_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "queued")

    task_rows = await db.execute(
        select(ProcessingTask.stage, ProcessingTask.status, ProcessingTask.progress)
        .where(ProcessingTask.hunt_id == hunt_id)
    )
    tasks = task_rows.all()

    task_total = len(tasks)
    task_done = sum(1 for _, st, _ in tasks if st in ("completed", "failed", "cancelled"))
    task_running = sum(1 for _, st, _ in tasks if st == "running")
    task_queued = sum(1 for _, st, _ in tasks if st == "queued")
    task_ratio = (task_done / task_total) if task_total > 0 else None

    active_jobs = max(active_jobs_mem, task_running)
    queued_jobs = max(queued_jobs_mem, task_queued)

    stage_rollup: dict[str, dict] = {}
    for stage, status, progress in tasks:
        bucket = stage_rollup.setdefault(stage, {"total": 0, "done": 0, "running": 0, "queued": 0, "progress_sum": 0.0})
        bucket["total"] += 1
        if status in ("completed", "failed", "cancelled"):
            bucket["done"] += 1
        elif status == "running":
            bucket["running"] += 1
        elif status == "queued":
            bucket["queued"] += 1
        bucket["progress_sum"] += float(progress or 0.0)

    for stage_name, bucket in stage_rollup.items():
        total = max(1, bucket["total"])
        bucket["percent"] = round(bucket["progress_sum"] / total, 1)

    if inventory_cache.get(hunt_id) is not None:
'''
if old not in t:
    raise SystemExit('job block not found')
t=t.replace(old,new)

# Anchor 2: the overall progress-ratio computation.
old2='''    dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
    overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
    progress_percent = round(overall_ratio * 100.0, 1)
'''
# Replacement 2: blend the task ratio into the overall ratio when available.
new2='''    dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
    if task_ratio is None:
        overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
    else:
        overall_ratio = min(1.0, (dataset_ratio * 0.50) + (task_ratio * 0.35) + (network_ratio * 0.15))
    progress_percent = round(overall_ratio * 100.0, 1)
'''
if old2 not in t:
    raise SystemExit('ratio block not found')
t=t.replace(old2,new2)

# Anchor 3: the "jobs" section of the stages payload.
old3='''        "jobs": {
            "active": active_jobs,
            "queued": queued_jobs,
            "total_seen": len(relevant_jobs),
        },
    }
'''
# Replacement 3: expose task counters and the per-stage rollup.
new3='''        "jobs": {
            "active": active_jobs,
            "queued": queued_jobs,
            "total_seen": len(relevant_jobs),
            "task_total": task_total,
            "task_done": task_done,
            "task_percent": round((task_ratio or 0.0) * 100.0, 1) if task_total else None,
        },
        "task_stages": stage_rollup,
    }
'''
if old3 not in t:
    raise SystemExit('stages jobs block not found')
t=t.replace(old3,new3)

p.write_text(t,encoding='utf-8')
print('updated hunt progress to merge persistent processing tasks')
|
||||||
46
_edit_job_queue.py
Normal file
46
_edit_job_queue.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Patch job_queue.py's keyword-scan handler so each dataset-only scan
result is stored in keyword_scan_cache for fast reuse by the API."""
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
t=p.read_text(encoding='utf-8')
# Exact current text of the handler, used as an anchor.
old='''async def _handle_keyword_scan(job: Job):
    """AUP keyword scan handler."""
    from app.db import async_session_factory
    from app.services.scanner import KeywordScanner

    dataset_id = job.params.get("dataset_id")
    job.message = f"Running AUP keyword scan on dataset {dataset_id}"

    async with async_session_factory() as db:
        scanner = KeywordScanner(db)
        result = await scanner.scan(dataset_ids=[dataset_id])

    hits = result.get("total_hits", 0)
    job.message = f"Keyword scan complete: {hits} hits"
    logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
    return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
'''
# Replacement: identical except for importing keyword_scan_cache and
# caching the result keyed by dataset id.
new='''async def _handle_keyword_scan(job: Job):
    """AUP keyword scan handler."""
    from app.db import async_session_factory
    from app.services.scanner import KeywordScanner, keyword_scan_cache

    dataset_id = job.params.get("dataset_id")
    job.message = f"Running AUP keyword scan on dataset {dataset_id}"

    async with async_session_factory() as db:
        scanner = KeywordScanner(db)
        result = await scanner.scan(dataset_ids=[dataset_id])

    # Cache dataset-only result for fast API reuse
    if dataset_id:
        keyword_scan_cache.put(dataset_id, result)

    hits = result.get("total_hits", 0)
    job.message = f"Keyword scan complete: {hits} hits"
    logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
    return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
'''
# Fail loudly if the handler source drifted and no longer matches the anchor.
if old not in t:
    raise SystemExit('target block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated job_queue keyword scan handler')
|
||||||
13
_edit_jobqueue_reconcile.py
Normal file
13
_edit_jobqueue_reconcile.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""Patch job_queue.py: insert reconcile_stale_processing_tasks(), which marks
queued/running processing tasks left over from a previous run as failed.
The new coroutine is inserted immediately above register_all_handlers()."""
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
t=p.read_text(encoding='utf-8')
# Anchor: the function is inserted just before this definition.
marker='''def register_all_handlers():
    """Register all job handlers and completion callbacks."""
'''
# Full source of the new coroutine, held as one escaped string.
ins='''\n\nasync def reconcile_stale_processing_tasks() -> int:\n    """Mark queued/running processing tasks from prior runs as failed."""\n    from datetime import datetime, timezone\n    from sqlalchemy import update\n\n    try:\n        from app.db import async_session_factory\n        from app.db.models import ProcessingTask\n\n        now = datetime.now(timezone.utc)\n        async with async_session_factory() as db:\n            result = await db.execute(\n                update(ProcessingTask)\n                .where(ProcessingTask.status.in_([\"queued\", \"running\"]))\n                .values(\n                    status=\"failed\",\n                    error=\"Recovered after service restart before task completion\",\n                    message=\"Recovered stale task after restart\",\n                    completed_at=now,\n                )\n            )\n            await db.commit()\n        updated = int(result.rowcount or 0)\n\n        if updated:\n            logger.warning(\n                \"Reconciled %d stale processing tasks (queued/running -> failed) during startup\",\n                updated,\n            )\n        return updated\n    except Exception as e:\n        logger.warning(f\"Failed to reconcile stale processing tasks: {e}\")\n        return 0\n\n\n'''
# Idempotent: only insert if the function body is not already present.
if ins.strip() not in t:
    if marker not in t:
        raise SystemExit('register marker not found')
    t=t.replace(marker,ins+marker)
p.write_text(t,encoding='utf-8')
print('added reconcile_stale_processing_tasks to job_queue')
|
||||||
64
_edit_jobqueue_sync.py
Normal file
64
_edit_jobqueue_sync.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
ins='''\n\nasync def _sync_processing_task(job: Job):\n """Persist latest job state into processing_tasks (if linked by job_id)."""\n from datetime import datetime, timezone\n from sqlalchemy import update\n\n try:\n from app.db import async_session_factory\n from app.db.models import ProcessingTask\n\n values = {\n "status": job.status.value,\n "progress": float(job.progress),\n "message": job.message,\n "error": job.error,\n }\n if job.started_at:\n values["started_at"] = datetime.fromtimestamp(job.started_at, tz=timezone.utc)\n if job.completed_at:\n values["completed_at"] = datetime.fromtimestamp(job.completed_at, tz=timezone.utc)\n\n async with async_session_factory() as db:\n await db.execute(\n update(ProcessingTask)\n .where(ProcessingTask.job_id == job.id)\n .values(**values)\n )\n await db.commit()\n except Exception as e:\n logger.warning(f"Failed to sync processing task for job {job.id}: {e}")\n'''
|
||||||
|
marker='\n\n# -- Singleton + job handlers --\n'
|
||||||
|
if ins.strip() not in t:
|
||||||
|
t=t.replace(marker, ins+marker)
|
||||||
|
|
||||||
|
old=''' job.status = JobStatus.RUNNING
|
||||||
|
job.started_at = time.time()
|
||||||
|
job.message = "Running..."
|
||||||
|
logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
|
||||||
|
|
||||||
|
try:
|
||||||
|
'''
|
||||||
|
new=''' job.status = JobStatus.RUNNING
|
||||||
|
job.started_at = time.time()
|
||||||
|
if job.progress <= 0:
|
||||||
|
job.progress = 5.0
|
||||||
|
job.message = "Running..."
|
||||||
|
await _sync_processing_task(job)
|
||||||
|
logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
|
||||||
|
|
||||||
|
try:
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('worker running block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
|
||||||
|
old2=''' job.completed_at = time.time()
|
||||||
|
logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
|
||||||
|
except Exception as e:
|
||||||
|
if not job.is_cancelled:
|
||||||
|
job.status = JobStatus.FAILED
|
||||||
|
job.error = str(e)
|
||||||
|
job.message = f"Failed: {e}"
|
||||||
|
job.completed_at = time.time()
|
||||||
|
logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# Fire completion callbacks
|
||||||
|
'''
|
||||||
|
new2=''' job.completed_at = time.time()
|
||||||
|
logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
|
||||||
|
except Exception as e:
|
||||||
|
if not job.is_cancelled:
|
||||||
|
job.status = JobStatus.FAILED
|
||||||
|
job.error = str(e)
|
||||||
|
job.message = f"Failed: {e}"
|
||||||
|
job.completed_at = time.time()
|
||||||
|
logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
if job.is_cancelled and not job.completed_at:
|
||||||
|
job.completed_at = time.time()
|
||||||
|
|
||||||
|
await _sync_processing_task(job)
|
||||||
|
|
||||||
|
# Fire completion callbacks
|
||||||
|
'''
|
||||||
|
if old2 not in t:
|
||||||
|
raise SystemExit('worker completion block not found')
|
||||||
|
t=t.replace(old2,new2)
|
||||||
|
|
||||||
|
p.write_text(t, encoding='utf-8')
|
||||||
|
print('updated job_queue persistent task syncing')
|
||||||
39
_edit_jobqueue_triage_task.py
Normal file
39
_edit_jobqueue_triage_task.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
old=''' if hunt_id:
|
||||||
|
job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
|
||||||
|
logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
|
||||||
|
except Exception as e:
|
||||||
|
'''
|
||||||
|
new=''' if hunt_id:
|
||||||
|
hp_job = job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
|
||||||
|
try:
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.db.models import ProcessingTask
|
||||||
|
async with async_session_factory() as db:
|
||||||
|
existing = await db.execute(
|
||||||
|
select(ProcessingTask.id).where(ProcessingTask.job_id == hp_job.id)
|
||||||
|
)
|
||||||
|
if existing.first() is None:
|
||||||
|
db.add(ProcessingTask(
|
||||||
|
hunt_id=hunt_id,
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
job_id=hp_job.id,
|
||||||
|
stage="host_profile",
|
||||||
|
status="queued",
|
||||||
|
progress=0.0,
|
||||||
|
message="Queued",
|
||||||
|
))
|
||||||
|
await db.commit()
|
||||||
|
except Exception as persist_err:
|
||||||
|
logger.warning(f"Failed to persist chained HOST_PROFILE task: {persist_err}")
|
||||||
|
|
||||||
|
logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
|
||||||
|
except Exception as e:
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('triage chain block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('updated triage chain to persist host_profile task row')
|
||||||
321
_edit_keywords.py
Normal file
321
_edit_keywords.py
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
|
||||||
|
new_text='''"""API routes for AUP keyword themes, keyword CRUD, and scanning."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.db.models import KeywordTheme, Keyword
|
||||||
|
from app.services.scanner import KeywordScanner, keyword_scan_cache
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/keywords", tags=["keywords"])
|
||||||
|
|
||||||
|
|
||||||
|
class ThemeCreate(BaseModel):
|
||||||
|
name: str = Field(..., min_length=1, max_length=128)
|
||||||
|
color: str = Field(default="#9e9e9e", max_length=16)
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ThemeUpdate(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
color: str | None = None
|
||||||
|
enabled: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordOut(BaseModel):
|
||||||
|
id: int
|
||||||
|
theme_id: str
|
||||||
|
value: str
|
||||||
|
is_regex: bool
|
||||||
|
created_at: str
|
||||||
|
|
||||||
|
|
||||||
|
class ThemeOut(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
color: str
|
||||||
|
enabled: bool
|
||||||
|
is_builtin: bool
|
||||||
|
created_at: str
|
||||||
|
keyword_count: int
|
||||||
|
keywords: list[KeywordOut]
|
||||||
|
|
||||||
|
|
||||||
|
class ThemeListResponse(BaseModel):
|
||||||
|
themes: list[ThemeOut]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordCreate(BaseModel):
|
||||||
|
value: str = Field(..., min_length=1, max_length=256)
|
||||||
|
is_regex: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordBulkCreate(BaseModel):
|
||||||
|
values: list[str] = Field(..., min_items=1)
|
||||||
|
is_regex: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ScanRequest(BaseModel):
|
||||||
|
dataset_ids: list[str] | None = None
|
||||||
|
theme_ids: list[str] | None = None
|
||||||
|
scan_hunts: bool = False
|
||||||
|
scan_annotations: bool = False
|
||||||
|
scan_messages: bool = False
|
||||||
|
prefer_cache: bool = True
|
||||||
|
force_rescan: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ScanHit(BaseModel):
|
||||||
|
theme_name: str
|
||||||
|
theme_color: str
|
||||||
|
keyword: str
|
||||||
|
source_type: str
|
||||||
|
source_id: str | int
|
||||||
|
field: str
|
||||||
|
matched_value: str
|
||||||
|
row_index: int | None = None
|
||||||
|
dataset_name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ScanResponse(BaseModel):
|
||||||
|
total_hits: int
|
||||||
|
hits: list[ScanHit]
|
||||||
|
themes_scanned: int
|
||||||
|
keywords_scanned: int
|
||||||
|
rows_scanned: int
|
||||||
|
cache_used: bool = False
|
||||||
|
cache_status: str = "miss"
|
||||||
|
cached_at: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _theme_to_out(t: KeywordTheme) -> ThemeOut:
|
||||||
|
return ThemeOut(
|
||||||
|
id=t.id,
|
||||||
|
name=t.name,
|
||||||
|
color=t.color,
|
||||||
|
enabled=t.enabled,
|
||||||
|
is_builtin=t.is_builtin,
|
||||||
|
created_at=t.created_at.isoformat(),
|
||||||
|
keyword_count=len(t.keywords),
|
||||||
|
keywords=[
|
||||||
|
KeywordOut(
|
||||||
|
id=k.id,
|
||||||
|
theme_id=k.theme_id,
|
||||||
|
value=k.value,
|
||||||
|
is_regex=k.is_regex,
|
||||||
|
created_at=k.created_at.isoformat(),
|
||||||
|
)
|
||||||
|
for k in t.keywords
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_cached_results(entries: list[dict], allowed_theme_names: set[str] | None = None) -> dict:
|
||||||
|
hits: list[dict] = []
|
||||||
|
total_rows = 0
|
||||||
|
cached_at: str | None = None
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
result = entry["result"]
|
||||||
|
total_rows += int(result.get("rows_scanned", 0) or 0)
|
||||||
|
if entry.get("built_at"):
|
||||||
|
if not cached_at or entry["built_at"] > cached_at:
|
||||||
|
cached_at = entry["built_at"]
|
||||||
|
for h in result.get("hits", []):
|
||||||
|
if allowed_theme_names is not None and h.get("theme_name") not in allowed_theme_names:
|
||||||
|
continue
|
||||||
|
hits.append(h)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_hits": len(hits),
|
||||||
|
"hits": hits,
|
||||||
|
"rows_scanned": total_rows,
|
||||||
|
"cached_at": cached_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/themes", response_model=ThemeListResponse)
|
||||||
|
async def list_themes(db: AsyncSession = Depends(get_db)):
|
||||||
|
result = await db.execute(select(KeywordTheme).order_by(KeywordTheme.name))
|
||||||
|
themes = result.scalars().all()
|
||||||
|
return ThemeListResponse(themes=[_theme_to_out(t) for t in themes], total=len(themes))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/themes", response_model=ThemeOut, status_code=201)
|
||||||
|
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
|
||||||
|
exists = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == body.name))
|
||||||
|
if exists:
|
||||||
|
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
||||||
|
theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
|
||||||
|
db.add(theme)
|
||||||
|
await db.flush()
|
||||||
|
await db.refresh(theme)
|
||||||
|
keyword_scan_cache.clear()
|
||||||
|
return _theme_to_out(theme)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/themes/{theme_id}", response_model=ThemeOut)
|
||||||
|
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
|
||||||
|
theme = await db.get(KeywordTheme, theme_id)
|
||||||
|
if not theme:
|
||||||
|
raise HTTPException(404, "Theme not found")
|
||||||
|
if body.name is not None:
|
||||||
|
dup = await db.scalar(
|
||||||
|
select(KeywordTheme.id).where(KeywordTheme.name == body.name, KeywordTheme.id != theme_id)
|
||||||
|
)
|
||||||
|
if dup:
|
||||||
|
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
||||||
|
theme.name = body.name
|
||||||
|
if body.color is not None:
|
||||||
|
theme.color = body.color
|
||||||
|
if body.enabled is not None:
|
||||||
|
theme.enabled = body.enabled
|
||||||
|
await db.flush()
|
||||||
|
await db.refresh(theme)
|
||||||
|
keyword_scan_cache.clear()
|
||||||
|
return _theme_to_out(theme)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/themes/{theme_id}", status_code=204)
|
||||||
|
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
|
theme = await db.get(KeywordTheme, theme_id)
|
||||||
|
if not theme:
|
||||||
|
raise HTTPException(404, "Theme not found")
|
||||||
|
await db.delete(theme)
|
||||||
|
keyword_scan_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
|
||||||
|
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
|
||||||
|
theme = await db.get(KeywordTheme, theme_id)
|
||||||
|
if not theme:
|
||||||
|
raise HTTPException(404, "Theme not found")
|
||||||
|
kw = Keyword(theme_id=theme_id, value=body.value, is_regex=body.is_regex)
|
||||||
|
db.add(kw)
|
||||||
|
await db.flush()
|
||||||
|
await db.refresh(kw)
|
||||||
|
keyword_scan_cache.clear()
|
||||||
|
return KeywordOut(
|
||||||
|
id=kw.id, theme_id=kw.theme_id, value=kw.value,
|
||||||
|
is_regex=kw.is_regex, created_at=kw.created_at.isoformat(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
|
||||||
|
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
|
||||||
|
theme = await db.get(KeywordTheme, theme_id)
|
||||||
|
if not theme:
|
||||||
|
raise HTTPException(404, "Theme not found")
|
||||||
|
added = 0
|
||||||
|
for val in body.values:
|
||||||
|
val = val.strip()
|
||||||
|
if not val:
|
||||||
|
continue
|
||||||
|
db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex))
|
||||||
|
added += 1
|
||||||
|
await db.flush()
|
||||||
|
keyword_scan_cache.clear()
|
||||||
|
return {"added": added, "theme_id": theme_id}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/keywords/{keyword_id}", status_code=204)
|
||||||
|
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
|
||||||
|
kw = await db.get(Keyword, keyword_id)
|
||||||
|
if not kw:
|
||||||
|
raise HTTPException(404, "Keyword not found")
|
||||||
|
await db.delete(kw)
|
||||||
|
keyword_scan_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/scan", response_model=ScanResponse)
|
||||||
|
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
||||||
|
scanner = KeywordScanner(db)
|
||||||
|
|
||||||
|
can_use_cache = (
|
||||||
|
body.prefer_cache
|
||||||
|
and not body.force_rescan
|
||||||
|
and bool(body.dataset_ids)
|
||||||
|
and not body.scan_hunts
|
||||||
|
and not body.scan_annotations
|
||||||
|
and not body.scan_messages
|
||||||
|
)
|
||||||
|
|
||||||
|
if can_use_cache:
|
||||||
|
themes = await scanner._load_themes(body.theme_ids)
|
||||||
|
allowed_theme_names = {t.name for t in themes}
|
||||||
|
keywords_scanned = sum(len(theme.keywords) for theme in themes)
|
||||||
|
|
||||||
|
cached_entries: list[dict] = []
|
||||||
|
missing: list[str] = []
|
||||||
|
for dataset_id in (body.dataset_ids or []):
|
||||||
|
entry = keyword_scan_cache.get(dataset_id)
|
||||||
|
if not entry:
|
||||||
|
missing.append(dataset_id)
|
||||||
|
continue
|
||||||
|
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
|
||||||
|
|
||||||
|
if not missing and cached_entries:
|
||||||
|
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
|
||||||
|
return {
|
||||||
|
"total_hits": merged["total_hits"],
|
||||||
|
"hits": merged["hits"],
|
||||||
|
"themes_scanned": len(themes),
|
||||||
|
"keywords_scanned": keywords_scanned,
|
||||||
|
"rows_scanned": merged["rows_scanned"],
|
||||||
|
"cache_used": True,
|
||||||
|
"cache_status": "hit",
|
||||||
|
"cached_at": merged["cached_at"],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await scanner.scan(
|
||||||
|
dataset_ids=body.dataset_ids,
|
||||||
|
theme_ids=body.theme_ids,
|
||||||
|
scan_hunts=body.scan_hunts,
|
||||||
|
scan_annotations=body.scan_annotations,
|
||||||
|
scan_messages=body.scan_messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
**result,
|
||||||
|
"cache_used": False,
|
||||||
|
"cache_status": "miss",
|
||||||
|
"cached_at": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/scan/quick", response_model=ScanResponse)
|
||||||
|
async def quick_scan(
|
||||||
|
dataset_id: str = Query(..., description="Dataset to scan"),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
entry = keyword_scan_cache.get(dataset_id)
|
||||||
|
if entry is not None:
|
||||||
|
result = entry.result
|
||||||
|
return {
|
||||||
|
**result,
|
||||||
|
"cache_used": True,
|
||||||
|
"cache_status": "hit",
|
||||||
|
"cached_at": entry.built_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner = KeywordScanner(db)
|
||||||
|
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||||
|
keyword_scan_cache.put(dataset_id, result)
|
||||||
|
return {
|
||||||
|
**result,
|
||||||
|
"cache_used": False,
|
||||||
|
"cache_status": "miss",
|
||||||
|
"cached_at": None,
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
p.write_text(new_text,encoding='utf-8')
|
||||||
|
print('updated keywords.py')
|
||||||
31
_edit_main_reconcile.py
Normal file
31
_edit_main_reconcile.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/main.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
old=''' # Start job queue
|
||||||
|
from app.services.job_queue import job_queue, register_all_handlers, JobType
|
||||||
|
register_all_handlers()
|
||||||
|
await job_queue.start()
|
||||||
|
logger.info("Job queue started (%d workers)", job_queue._max_workers)
|
||||||
|
'''
|
||||||
|
new=''' # Start job queue
|
||||||
|
from app.services.job_queue import (
|
||||||
|
job_queue,
|
||||||
|
register_all_handlers,
|
||||||
|
reconcile_stale_processing_tasks,
|
||||||
|
JobType,
|
||||||
|
)
|
||||||
|
|
||||||
|
if settings.STARTUP_RECONCILE_STALE_TASKS:
|
||||||
|
reconciled = await reconcile_stale_processing_tasks()
|
||||||
|
if reconciled:
|
||||||
|
logger.info("Startup reconciliation marked %d stale tasks", reconciled)
|
||||||
|
|
||||||
|
register_all_handlers()
|
||||||
|
await job_queue.start()
|
||||||
|
logger.info("Job queue started (%d workers)", job_queue._max_workers)
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('startup queue block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('wired startup reconciliation in main lifespan')
|
||||||
45
_edit_models_processing.py
Normal file
45
_edit_models_processing.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/db/models.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
if 'class ProcessingTask(Base):' in t:
|
||||||
|
print('processing task model already exists')
|
||||||
|
raise SystemExit(0)
|
||||||
|
insert='''
|
||||||
|
|
||||||
|
# -- Persistent Processing Tasks (Phase 2) ---
|
||||||
|
|
||||||
|
class ProcessingTask(Base):
|
||||||
|
__tablename__ = "processing_tasks"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||||
|
hunt_id: Mapped[Optional[str]] = mapped_column(
|
||||||
|
String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True, index=True
|
||||||
|
)
|
||||||
|
dataset_id: Mapped[Optional[str]] = mapped_column(
|
||||||
|
String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True, index=True
|
||||||
|
)
|
||||||
|
job_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
|
||||||
|
stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
|
status: Mapped[str] = mapped_column(String(20), default="queued", index=True)
|
||||||
|
progress: Mapped[float] = mapped_column(Float, default=0.0)
|
||||||
|
message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||||
|
error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||||
|
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||||
|
)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_processing_tasks_hunt_stage", "hunt_id", "stage"),
|
||||||
|
Index("ix_processing_tasks_dataset_stage", "dataset_id", "stage"),
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
# insert before Playbook section
|
||||||
|
marker='\n\n# -- Playbook / Investigation Templates (Feature 3) ---\n'
|
||||||
|
if marker not in t:
|
||||||
|
raise SystemExit('marker not found for insertion')
|
||||||
|
t=t.replace(marker, insert+marker)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('added ProcessingTask model')
|
||||||
59
_edit_networkmap_hit.py
Normal file
59
_edit_networkmap_hit.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
insert='''
|
||||||
|
function isPointOnNodeLabel(node: GNode, wx: number, wy: number, vp: Viewport): boolean {
|
||||||
|
const fontSize = Math.max(9, Math.round(12 / vp.scale));
|
||||||
|
const approxCharW = Math.max(5, fontSize * 0.58);
|
||||||
|
const line1 = node.label || '';
|
||||||
|
const line2 = node.meta.ips.length > 0 ? node.meta.ips[0] : '';
|
||||||
|
const tw = Math.max(line1.length * approxCharW, line2 ? line2.length * approxCharW : 0);
|
||||||
|
const px = 5, py = 2;
|
||||||
|
const totalH = line2 ? fontSize * 2 + py * 2 : fontSize + py * 2;
|
||||||
|
const lx = node.x, ly = node.y - node.radius - 6;
|
||||||
|
const rx = lx - tw / 2 - px;
|
||||||
|
const ry = ly - totalH;
|
||||||
|
const rw = tw + px * 2;
|
||||||
|
const rh = totalH;
|
||||||
|
return wx >= rx && wx <= (rx + rw) && wy >= ry && wy <= (ry + rh);
|
||||||
|
}
|
||||||
|
|
||||||
|
'''
|
||||||
|
if 'function isPointOnNodeLabel' not in t:
|
||||||
|
t=t.replace('// == Hit-test =============================================================\n', '// == Hit-test =============================================================\n'+insert)
|
||||||
|
|
||||||
|
old='''function hitTest(
|
||||||
|
graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport,
|
||||||
|
): GNode | null {
|
||||||
|
const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp);
|
||||||
|
for (const n of graph.nodes) {
|
||||||
|
const dx = n.x - wx, dy = n.y - wy;
|
||||||
|
if (dx * dx + dy * dy < (n.radius + 5) ** 2) return n;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
new='''function hitTest(
|
||||||
|
graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport,
|
||||||
|
): GNode | null {
|
||||||
|
const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp);
|
||||||
|
|
||||||
|
// Node-circle hit has priority
|
||||||
|
for (const n of graph.nodes) {
|
||||||
|
const dx = n.x - wx, dy = n.y - wy;
|
||||||
|
if (dx * dx + dy * dy < (n.radius + 5) ** 2) return n;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then label hit (so clicking text works too)
|
||||||
|
for (const n of graph.nodes) {
|
||||||
|
if (isPointOnNodeLabel(n, wx, wy, vp)) return n;
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('hitTest block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('updated NetworkMap hit-test for labels')
|
||||||
272
_edit_scanner.py
Normal file
272
_edit_scanner.py
Normal file
@@ -0,0 +1,272 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
p = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
|
||||||
|
text = p.read_text(encoding='utf-8')
|
||||||
|
new_text = '''"""AUP Keyword Scanner searches dataset rows, hunts, annotations, and
|
||||||
|
messages for keyword matches.
|
||||||
|
|
||||||
|
Scanning is done in Python (not SQL LIKE on JSON columns) for portability
|
||||||
|
across SQLite / PostgreSQL and to provide per-cell match context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db.models import (
|
||||||
|
KeywordTheme,
|
||||||
|
DatasetRow,
|
||||||
|
Dataset,
|
||||||
|
Hunt,
|
||||||
|
Annotation,
|
||||||
|
Message,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
BATCH_SIZE = 200
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScanHit:
|
||||||
|
theme_name: str
|
||||||
|
theme_color: str
|
||||||
|
keyword: str
|
||||||
|
source_type: str # dataset_row | hunt | annotation | message
|
||||||
|
source_id: str | int
|
||||||
|
field: str
|
||||||
|
matched_value: str
|
||||||
|
row_index: int | None = None
|
||||||
|
dataset_name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScanResult:
|
||||||
|
total_hits: int = 0
|
||||||
|
hits: list[ScanHit] = field(default_factory=list)
|
||||||
|
themes_scanned: int = 0
|
||||||
|
keywords_scanned: int = 0
|
||||||
|
rows_scanned: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KeywordScanCacheEntry:
|
||||||
|
dataset_id: str
|
||||||
|
result: dict
|
||||||
|
built_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordScanCache:
|
||||||
|
"""In-memory per-dataset cache for dataset-only keyword scans.
|
||||||
|
|
||||||
|
This enables fast-path reads when users run AUP scans against datasets that
|
||||||
|
were already scanned during upload pipeline processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._entries: dict[str, KeywordScanCacheEntry] = {}
|
||||||
|
|
||||||
|
def put(self, dataset_id: str, result: dict):
|
||||||
|
self._entries[dataset_id] = KeywordScanCacheEntry(dataset_id=dataset_id, result=result)
|
||||||
|
|
||||||
|
def get(self, dataset_id: str) -> KeywordScanCacheEntry | None:
|
||||||
|
return self._entries.get(dataset_id)
|
||||||
|
|
||||||
|
def invalidate_dataset(self, dataset_id: str):
|
||||||
|
self._entries.pop(dataset_id, None)
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self._entries.clear()
|
||||||
|
|
||||||
|
|
||||||
|
keyword_scan_cache = KeywordScanCache()
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordScanner:
|
||||||
|
"""Scans multiple data sources for keyword/regex matches."""
|
||||||
|
|
||||||
|
def __init__(self, db: AsyncSession):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
# Public API
|
||||||
|
|
||||||
|
async def scan(
|
||||||
|
self,
|
||||||
|
dataset_ids: list[str] | None = None,
|
||||||
|
theme_ids: list[str] | None = None,
|
||||||
|
scan_hunts: bool = False,
|
||||||
|
scan_annotations: bool = False,
|
||||||
|
scan_messages: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""Run a full AUP scan and return dict matching ScanResponse."""
|
||||||
|
# Load themes + keywords
|
||||||
|
themes = await self._load_themes(theme_ids)
|
||||||
|
if not themes:
|
||||||
|
return ScanResult().__dict__
|
||||||
|
|
||||||
|
# Pre-compile patterns per theme
|
||||||
|
patterns = self._compile_patterns(themes)
|
||||||
|
result = ScanResult(
|
||||||
|
themes_scanned=len(themes),
|
||||||
|
keywords_scanned=sum(len(kws) for kws in patterns.values()),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scan dataset rows
|
||||||
|
await self._scan_datasets(patterns, result, dataset_ids)
|
||||||
|
|
||||||
|
# Scan hunts
|
||||||
|
if scan_hunts:
|
||||||
|
await self._scan_hunts(patterns, result)
|
||||||
|
|
||||||
|
# Scan annotations
|
||||||
|
if scan_annotations:
|
||||||
|
await self._scan_annotations(patterns, result)
|
||||||
|
|
||||||
|
# Scan messages
|
||||||
|
if scan_messages:
|
||||||
|
await self._scan_messages(patterns, result)
|
||||||
|
|
||||||
|
result.total_hits = len(result.hits)
|
||||||
|
return {
|
||||||
|
"total_hits": result.total_hits,
|
||||||
|
"hits": [h.__dict__ for h in result.hits],
|
||||||
|
"themes_scanned": result.themes_scanned,
|
||||||
|
"keywords_scanned": result.keywords_scanned,
|
||||||
|
"rows_scanned": result.rows_scanned,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Internal
|
||||||
|
|
||||||
|
async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
|
||||||
|
q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712
|
||||||
|
if theme_ids:
|
||||||
|
q = q.where(KeywordTheme.id.in_(theme_ids))
|
||||||
|
result = await self.db.execute(q)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
def _compile_patterns(
|
||||||
|
self, themes: list[KeywordTheme]
|
||||||
|
) -> dict[tuple[str, str, str], list[tuple[str, re.Pattern]]]:
|
||||||
|
"""Returns {(theme_id, theme_name, theme_color): [(keyword_value, compiled_pattern), ...]}"""
|
||||||
|
patterns: dict[tuple[str, str, str], list[tuple[str, re.Pattern]]] = {}
|
||||||
|
for theme in themes:
|
||||||
|
key = (theme.id, theme.name, theme.color)
|
||||||
|
compiled = []
|
||||||
|
for kw in theme.keywords:
|
||||||
|
try:
|
||||||
|
if kw.is_regex:
|
||||||
|
pat = re.compile(kw.value, re.IGNORECASE)
|
||||||
|
else:
|
||||||
|
pat = re.compile(re.escape(kw.value), re.IGNORECASE)
|
||||||
|
compiled.append((kw.value, pat))
|
||||||
|
except re.error:
|
||||||
|
logger.warning("Invalid regex pattern '%s' in theme '%s', skipping",
|
||||||
|
kw.value, theme.name)
|
||||||
|
patterns[key] = compiled
|
||||||
|
return patterns
|
||||||
|
|
||||||
|
def _match_text(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
patterns: dict,
|
||||||
|
source_type: str,
|
||||||
|
source_id: str | int,
|
||||||
|
field_name: str,
|
||||||
|
hits: list[ScanHit],
|
||||||
|
row_index: int | None = None,
|
||||||
|
dataset_name: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Check text against all compiled patterns, append hits."""
|
||||||
|
if not text:
|
||||||
|
return
|
||||||
|
for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
|
||||||
|
for kw_value, pat in keyword_patterns:
|
||||||
|
if pat.search(text):
|
||||||
|
matched_preview = text[:200] + ("" if len(text) > 200 else "")
|
||||||
|
hits.append(ScanHit(
|
||||||
|
theme_name=theme_name,
|
||||||
|
theme_color=theme_color,
|
||||||
|
keyword=kw_value,
|
||||||
|
source_type=source_type,
|
||||||
|
source_id=source_id,
|
||||||
|
field=field_name,
|
||||||
|
matched_value=matched_preview,
|
||||||
|
row_index=row_index,
|
||||||
|
dataset_name=dataset_name,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _scan_datasets(
|
||||||
|
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
|
||||||
|
) -> None:
|
||||||
|
"""Scan dataset rows in batches."""
|
||||||
|
ds_q = select(Dataset.id, Dataset.name)
|
||||||
|
if dataset_ids:
|
||||||
|
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
|
||||||
|
ds_result = await self.db.execute(ds_q)
|
||||||
|
ds_map = {r[0]: r[1] for r in ds_result.fetchall()}
|
||||||
|
|
||||||
|
if not ds_map:
|
||||||
|
return
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
row_q_base = select(DatasetRow).where(
|
||||||
|
DatasetRow.dataset_id.in_(list(ds_map.keys()))
|
||||||
|
).order_by(DatasetRow.id)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
rows_result = await self.db.execute(
|
||||||
|
row_q_base.offset(offset).limit(BATCH_SIZE)
|
||||||
|
)
|
||||||
|
rows = rows_result.scalars().all()
|
||||||
|
if not rows:
|
||||||
|
break
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
result.rows_scanned += 1
|
||||||
|
data = row.data or {}
|
||||||
|
for col_name, cell_value in data.items():
|
||||||
|
if cell_value is None:
|
||||||
|
continue
|
||||||
|
text = str(cell_value)
|
||||||
|
self._match_text(
|
||||||
|
text, patterns, "dataset_row", row.id,
|
||||||
|
col_name, result.hits,
|
||||||
|
row_index=row.row_index,
|
||||||
|
dataset_name=ds_map.get(row.dataset_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
offset += BATCH_SIZE
|
||||||
|
import asyncio
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
if len(rows) < BATCH_SIZE:
|
||||||
|
break
|
||||||
|
|
||||||
|
async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
|
||||||
|
"""Scan hunt names and descriptions."""
|
||||||
|
hunts_result = await self.db.execute(select(Hunt))
|
||||||
|
for hunt in hunts_result.scalars().all():
|
||||||
|
self._match_text(hunt.name, patterns, "hunt", hunt.id, "name", result.hits)
|
||||||
|
if hunt.description:
|
||||||
|
self._match_text(hunt.description, patterns, "hunt", hunt.id, "description", result.hits)
|
||||||
|
|
||||||
|
async def _scan_annotations(self, patterns: dict, result: ScanResult) -> None:
|
||||||
|
"""Scan annotation text."""
|
||||||
|
ann_result = await self.db.execute(select(Annotation))
|
||||||
|
for ann in ann_result.scalars().all():
|
||||||
|
self._match_text(ann.text, patterns, "annotation", ann.id, "text", result.hits)
|
||||||
|
|
||||||
|
async def _scan_messages(self, patterns: dict, result: ScanResult) -> None:
|
||||||
|
"""Scan conversation messages (user messages only)."""
|
||||||
|
msg_result = await self.db.execute(
|
||||||
|
select(Message).where(Message.role == "user")
|
||||||
|
)
|
||||||
|
for msg in msg_result.scalars().all():
|
||||||
|
self._match_text(msg.content, patterns, "message", msg.id, "content", result.hits)
|
||||||
|
'''
|
||||||
|
|
||||||
|
p.write_text(new_text, encoding='utf-8')
|
||||||
|
print('updated scanner.py')
|
||||||
31
_edit_test_api.py
Normal file
31
_edit_test_api.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/tests/test_api.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
insert='''
|
||||||
|
async def test_hunt_progress(self, client):
|
||||||
|
create = await client.post("/api/hunts", json={"name": "Progress Hunt"})
|
||||||
|
hunt_id = create.json()["id"]
|
||||||
|
|
||||||
|
# attach one dataset so progress has scope
|
||||||
|
from tests.conftest import SAMPLE_CSV
|
||||||
|
import io
|
||||||
|
files = {"file": ("progress.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
|
||||||
|
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
|
||||||
|
assert up.status_code == 200
|
||||||
|
|
||||||
|
res = await client.get(f"/api/hunts/{hunt_id}/progress")
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json()
|
||||||
|
assert body["hunt_id"] == hunt_id
|
||||||
|
assert "progress_percent" in body
|
||||||
|
assert "dataset_total" in body
|
||||||
|
assert "network_status" in body
|
||||||
|
'''
|
||||||
|
needle=''' async def test_get_nonexistent_hunt(self, client):
|
||||||
|
resp = await client.get("/api/hunts/nonexistent-id")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
'''
|
||||||
|
if needle in t and 'test_hunt_progress' not in t:
|
||||||
|
t=t.replace(needle, needle+'\n'+insert)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('updated test_api.py')
|
||||||
32
_edit_test_keywords.py
Normal file
32
_edit_test_keywords.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/tests/test_keywords.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
add='''
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_quick_scan_cache_hit(client: AsyncClient):
|
||||||
|
"""Second quick scan should return cache hit metadata."""
|
||||||
|
theme_res = await client.post("/api/keywords/themes", json={"name": "Quick Cache Theme", "color": "#00aa00"})
|
||||||
|
tid = theme_res.json()["id"]
|
||||||
|
await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "chrome.exe"})
|
||||||
|
|
||||||
|
from tests.conftest import SAMPLE_CSV
|
||||||
|
import io
|
||||||
|
files = {"file": ("cache_quick.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
|
||||||
|
upload = await client.post("/api/datasets/upload", files=files)
|
||||||
|
ds_id = upload.json()["id"]
|
||||||
|
|
||||||
|
first = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
|
||||||
|
assert first.status_code == 200
|
||||||
|
assert first.json().get("cache_status") in ("miss", "hit")
|
||||||
|
|
||||||
|
second = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
|
||||||
|
assert second.status_code == 200
|
||||||
|
body = second.json()
|
||||||
|
assert body.get("cache_used") is True
|
||||||
|
assert body.get("cache_status") == "hit"
|
||||||
|
'''
|
||||||
|
if 'test_quick_scan_cache_hit' not in t:
|
||||||
|
t=t + add
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('updated test_keywords.py')
|
||||||
26
_edit_upload.py
Normal file
26
_edit_upload.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/FileUpload.tsx')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
# import useEffect
|
||||||
|
t=t.replace("import React, { useState, useCallback, useRef } from 'react';","import React, { useState, useCallback, useRef, useEffect } from 'react';")
|
||||||
|
# import HuntProgress type
|
||||||
|
t=t.replace("import { datasets, hunts, type UploadResult, type Hunt } from '../api/client';","import { datasets, hunts, type UploadResult, type Hunt, type HuntProgress } from '../api/client';")
|
||||||
|
# add state
|
||||||
|
if 'const [huntProgress, setHuntProgress]' not in t:
|
||||||
|
t=t.replace(" const [huntList, setHuntList] = useState<Hunt[]>([]);\n const [huntId, setHuntId] = useState('');"," const [huntList, setHuntList] = useState<Hunt[]>([]);\n const [huntId, setHuntId] = useState('');\n const [huntProgress, setHuntProgress] = useState<HuntProgress | null>(null);")
|
||||||
|
# add polling effect after hunts list effect
|
||||||
|
marker=" React.useEffect(() => {\n hunts.list(0, 100).then(r => setHuntList(r.hunts)).catch(() => {});\n }, []);\n"
|
||||||
|
if marker in t and 'setInterval' not in t.split(marker,1)[1][:500]:
|
||||||
|
add='''\n useEffect(() => {\n let timer: any = null;\n let cancelled = false;\n\n const pull = async () => {\n if (!huntId) {\n if (!cancelled) setHuntProgress(null);\n return;\n }\n try {\n const p = await hunts.progress(huntId);\n if (!cancelled) setHuntProgress(p);\n } catch {\n if (!cancelled) setHuntProgress(null);\n }\n };\n\n pull();\n if (huntId) timer = setInterval(pull, 2000);\n return () => { cancelled = true; if (timer) clearInterval(timer); };\n }, [huntId, jobs.length]);\n'''
|
||||||
|
t=t.replace(marker, marker+add)
|
||||||
|
|
||||||
|
# insert master progress UI after overall summary
|
||||||
|
insert_after=''' {overallTotal > 0 && (\n <Stack direction="row" alignItems="center" spacing={1} sx={{ mt: 2 }}>\n <Typography variant="body2" color="text.secondary">\n {overallDone + overallErr} / {overallTotal} files processed\n {overallErr > 0 && ` ({overallErr} failed)`}\n </Typography>\n <Box sx={{ flexGrow: 1 }} />\n {overallDone + overallErr === overallTotal && overallTotal > 0 && (\n <Tooltip title="Clear completed">\n <IconButton size="small" onClick={clearCompleted}><ClearIcon fontSize="small" /></IconButton>\n </Tooltip>\n )}\n </Stack>\n )}\n'''
|
||||||
|
add_block='''\n {huntId && huntProgress && (\n <Paper sx={{ p: 1.5, mt: 1.5 }}>\n <Stack direction="row" alignItems="center" spacing={1} sx={{ mb: 0.8 }}>\n <Typography variant="body2" sx={{ fontWeight: 600 }}>\n Master Processing Progress\n </Typography>\n <Chip\n size="small"\n label={huntProgress.status.toUpperCase()}\n color={huntProgress.status === 'ready' ? 'success' : huntProgress.status === 'processing' ? 'warning' : 'default'}\n variant="outlined"\n />\n <Box sx={{ flexGrow: 1 }} />\n <Typography variant="caption" color="text.secondary">\n {huntProgress.progress_percent.toFixed(1)}%\n </Typography>\n </Stack>\n <LinearProgress\n variant="determinate"\n value={Math.max(0, Math.min(100, huntProgress.progress_percent))}\n sx={{ height: 8, borderRadius: 4 }}\n />\n <Stack direction="row" spacing={1} sx={{ mt: 1 }} flexWrap="wrap" useFlexGap>\n <Chip size="small" label={`Datasets ${huntProgress.dataset_completed}/${huntProgress.dataset_total}`} variant="outlined" />\n <Chip size="small" label={`Active jobs ${huntProgress.active_jobs}`} variant="outlined" />\n <Chip size="small" label={`Queued jobs ${huntProgress.queued_jobs}`} variant="outlined" />\n <Chip size="small" label={`Network ${huntProgress.network_status}`} variant="outlined" />\n </Stack>\n </Paper>\n )}\n'''
|
||||||
|
if insert_after in t:
|
||||||
|
t=t.replace(insert_after, insert_after+add_block)
|
||||||
|
else:
|
||||||
|
print('warning: summary block not found')
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('updated FileUpload.tsx')
|
||||||
42
_edit_upload2.py
Normal file
42
_edit_upload2.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/FileUpload.tsx')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
marker=''' {/* Per-file progress list */}
|
||||||
|
'''
|
||||||
|
add=''' {huntId && huntProgress && (
|
||||||
|
<Paper sx={{ p: 1.5, mt: 1.5 }}>
|
||||||
|
<Stack direction="row" alignItems="center" spacing={1} sx={{ mb: 0.8 }}>
|
||||||
|
<Typography variant="body2" sx={{ fontWeight: 600 }}>
|
||||||
|
Master Processing Progress
|
||||||
|
</Typography>
|
||||||
|
<Chip
|
||||||
|
size="small"
|
||||||
|
label={huntProgress.status.toUpperCase()}
|
||||||
|
color={huntProgress.status === 'ready' ? 'success' : huntProgress.status === 'processing' ? 'warning' : 'default'}
|
||||||
|
variant="outlined"
|
||||||
|
/>
|
||||||
|
<Box sx={{ flexGrow: 1 }} />
|
||||||
|
<Typography variant="caption" color="text.secondary">
|
||||||
|
{huntProgress.progress_percent.toFixed(1)}%
|
||||||
|
</Typography>
|
||||||
|
</Stack>
|
||||||
|
<LinearProgress
|
||||||
|
variant="determinate"
|
||||||
|
value={Math.max(0, Math.min(100, huntProgress.progress_percent))}
|
||||||
|
sx={{ height: 8, borderRadius: 4 }}
|
||||||
|
/>
|
||||||
|
<Stack direction="row" spacing={1} sx={{ mt: 1 }} flexWrap="wrap" useFlexGap>
|
||||||
|
<Chip size="small" label={`Datasets ${huntProgress.dataset_completed}/${huntProgress.dataset_total}`} variant="outlined" />
|
||||||
|
<Chip size="small" label={`Active jobs ${huntProgress.active_jobs}`} variant="outlined" />
|
||||||
|
<Chip size="small" label={`Queued jobs ${huntProgress.queued_jobs}`} variant="outlined" />
|
||||||
|
<Chip size="small" label={`Network ${huntProgress.network_status}`} variant="outlined" />
|
||||||
|
</Stack>
|
||||||
|
</Paper>
|
||||||
|
)}
|
||||||
|
|
||||||
|
'''
|
||||||
|
if marker not in t:
|
||||||
|
raise SystemExit('marker not found')
|
||||||
|
t=t.replace(marker, add+marker)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('inserted master progress block')
|
||||||
55
_enforce_scanner_budget.py
Normal file
55
_enforce_scanner_budget.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
if 'from app.config import settings' not in t:
|
||||||
|
t=t.replace('from sqlalchemy.ext.asyncio import AsyncSession\n','from sqlalchemy.ext.asyncio import AsyncSession\n\nfrom app.config import settings\n')
|
||||||
|
|
||||||
|
old=''' import asyncio
|
||||||
|
|
||||||
|
for ds_id, ds_name in ds_map.items():
|
||||||
|
last_id = 0
|
||||||
|
while True:
|
||||||
|
'''
|
||||||
|
new=''' import asyncio
|
||||||
|
|
||||||
|
max_rows = max(0, int(settings.SCANNER_MAX_ROWS_PER_SCAN))
|
||||||
|
budget_reached = False
|
||||||
|
|
||||||
|
for ds_id, ds_name in ds_map.items():
|
||||||
|
if max_rows and result.rows_scanned >= max_rows:
|
||||||
|
budget_reached = True
|
||||||
|
break
|
||||||
|
|
||||||
|
last_id = 0
|
||||||
|
while True:
|
||||||
|
if max_rows and result.rows_scanned >= max_rows:
|
||||||
|
budget_reached = True
|
||||||
|
break
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('scanner loop block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
|
||||||
|
old2=''' if len(rows) < BATCH_SIZE:
|
||||||
|
break
|
||||||
|
|
||||||
|
'''
|
||||||
|
new2=''' if len(rows) < BATCH_SIZE:
|
||||||
|
break
|
||||||
|
|
||||||
|
if budget_reached:
|
||||||
|
break
|
||||||
|
|
||||||
|
if budget_reached:
|
||||||
|
logger.warning(
|
||||||
|
"AUP scan row budget reached (%d rows). Returning partial results.",
|
||||||
|
result.rows_scanned,
|
||||||
|
)
|
||||||
|
|
||||||
|
'''
|
||||||
|
if old2 not in t:
|
||||||
|
raise SystemExit('scanner break block not found')
|
||||||
|
t=t.replace(old2,new2,1)
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('added scanner global row budget enforcement')
|
||||||
12
_fix_aup_dep.py
Normal file
12
_fix_aup_dep.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
old=''' }, [selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);
|
||||||
|
'''
|
||||||
|
new=''' }, [selectedHuntId, selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('runScan deps block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('fixed AUPScanner runScan dependency list')
|
||||||
7
_fix_import_datasets.py
Normal file
7
_fix_import_datasets.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
if 'from app.db.models import ProcessingTask' not in t:
|
||||||
|
t=t.replace('from app.db import get_db\n', 'from app.db import get_db\nfrom app.db.models import ProcessingTask\n')
|
||||||
|
p.write_text(t, encoding='utf-8')
|
||||||
|
print('added ProcessingTask import')
|
||||||
25
_fix_keywords_empty_guard.py
Normal file
25
_fix_keywords_empty_guard.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
old=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
|
||||||
|
raise HTTPException(400, "Select at least one dataset or enable additional sources (hunts/annotations/messages)")
|
||||||
|
|
||||||
|
'''
|
||||||
|
new=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
|
||||||
|
return {
|
||||||
|
"total_hits": 0,
|
||||||
|
"hits": [],
|
||||||
|
"themes_scanned": 0,
|
||||||
|
"keywords_scanned": 0,
|
||||||
|
"rows_scanned": 0,
|
||||||
|
"cache_used": False,
|
||||||
|
"cache_status": "miss",
|
||||||
|
"cached_at": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('scope guard block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('adjusted empty scan guard to return fast empty result (200)')
|
||||||
47
_fix_label_selector_networkmap.py
Normal file
47
_fix_label_selector_networkmap.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
|
||||||
|
# Add label selector in toolbar before refresh button
|
||||||
|
insert_after=""" <TextField
|
||||||
|
size=\"small\"
|
||||||
|
placeholder=\"Search hosts, IPs, users\\u2026\"
|
||||||
|
value={search}
|
||||||
|
onChange={e => setSearch(e.target.value)}
|
||||||
|
sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }}
|
||||||
|
slotProps={{
|
||||||
|
input: {
|
||||||
|
startAdornment: <SearchIcon sx={{ mr: 0.5, fontSize: 18, color: 'text.secondary' }} />,
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
"""
|
||||||
|
label_ctrl="""
|
||||||
|
<FormControl size=\"small\" sx={{ minWidth: 150 }}>
|
||||||
|
<InputLabel id=\"label-mode-selector\">Labels</InputLabel>
|
||||||
|
<Select
|
||||||
|
labelId=\"label-mode-selector\"
|
||||||
|
value={labelMode}
|
||||||
|
label=\"Labels\"
|
||||||
|
onChange={e => setLabelMode(e.target.value as LabelMode)}
|
||||||
|
sx={{ '& .MuiSelect-select': { py: 0.8 } }}
|
||||||
|
>
|
||||||
|
<MenuItem value=\"none\">None</MenuItem>
|
||||||
|
<MenuItem value=\"highlight\">Selected/Search</MenuItem>
|
||||||
|
<MenuItem value=\"all\">All</MenuItem>
|
||||||
|
</Select>
|
||||||
|
</FormControl>
|
||||||
|
"""
|
||||||
|
if 'label-mode-selector' not in t:
|
||||||
|
if insert_after not in t:
|
||||||
|
raise SystemExit('search block not found for label selector insertion')
|
||||||
|
t=t.replace(insert_after, insert_after+label_ctrl)
|
||||||
|
|
||||||
|
# Fix useCallback dependency for startAnimLoop
|
||||||
|
old=' }, [canvasSize]);'
|
||||||
|
new=' }, [canvasSize, labelMode]);'
|
||||||
|
if old in t:
|
||||||
|
t=t.replace(old,new,1)
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('inserted label selector UI and fixed callback dependency')
|
||||||
10
_fix_last_dep_networkmap.py
Normal file
10
_fix_last_dep_networkmap.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
count=t.count('}, [canvasSize]);')
|
||||||
|
if count:
|
||||||
|
t=t.replace('}, [canvasSize]);','}, [canvasSize, labelMode]);')
|
||||||
|
# In case formatter created spaced variant
|
||||||
|
t=t.replace('}, [canvasSize ]);','}, [canvasSize, labelMode]);')
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('patched remaining canvasSize callback deps:', count)
|
||||||
71
_harden_aup_scope_ui.py
Normal file
71
_harden_aup_scope_ui.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
|
||||||
|
# Auto-select first hunt with datasets after load
|
||||||
|
old=''' const [tRes, hRes] = await Promise.all([
|
||||||
|
keywords.listThemes(),
|
||||||
|
hunts.list(0, 200),
|
||||||
|
]);
|
||||||
|
setThemes(tRes.themes);
|
||||||
|
setHuntList(hRes.hunts);
|
||||||
|
'''
|
||||||
|
new=''' const [tRes, hRes] = await Promise.all([
|
||||||
|
keywords.listThemes(),
|
||||||
|
hunts.list(0, 200),
|
||||||
|
]);
|
||||||
|
setThemes(tRes.themes);
|
||||||
|
setHuntList(hRes.hunts);
|
||||||
|
if (!selectedHuntId && hRes.hunts.length > 0) {
|
||||||
|
const best = hRes.hunts.find(h => h.dataset_count > 0) || hRes.hunts[0];
|
||||||
|
setSelectedHuntId(best.id);
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('loadData block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
|
||||||
|
# Guard runScan
|
||||||
|
old2=''' const runScan = useCallback(async () => {
|
||||||
|
setScanning(true);
|
||||||
|
setScanResult(null);
|
||||||
|
try {
|
||||||
|
'''
|
||||||
|
new2=''' const runScan = useCallback(async () => {
|
||||||
|
if (!selectedHuntId) {
|
||||||
|
enqueueSnackbar('Please select a hunt before running AUP scan', { variant: 'warning' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (selectedDs.size === 0) {
|
||||||
|
enqueueSnackbar('No datasets selected for this hunt', { variant: 'warning' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setScanning(true);
|
||||||
|
setScanResult(null);
|
||||||
|
try {
|
||||||
|
'''
|
||||||
|
if old2 not in t:
|
||||||
|
raise SystemExit('runScan header not found')
|
||||||
|
t=t.replace(old2,new2)
|
||||||
|
|
||||||
|
# update loadData deps
|
||||||
|
old3=''' }, [enqueueSnackbar]);
|
||||||
|
'''
|
||||||
|
new3=''' }, [enqueueSnackbar, selectedHuntId]);
|
||||||
|
'''
|
||||||
|
if old3 not in t:
|
||||||
|
raise SystemExit('loadData deps not found')
|
||||||
|
t=t.replace(old3,new3,1)
|
||||||
|
|
||||||
|
# disable button if no hunt or no datasets
|
||||||
|
old4=''' onClick={runScan} disabled={scanning}
|
||||||
|
'''
|
||||||
|
new4=''' onClick={runScan} disabled={scanning || !selectedHuntId || selectedDs.size === 0}
|
||||||
|
'''
|
||||||
|
if old4 not in t:
|
||||||
|
raise SystemExit('scan button props not found')
|
||||||
|
t=t.replace(old4,new4)
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('hardened AUPScanner to require explicit hunt/dataset scope')
|
||||||
84
_optimize_keywords_partial_cache.py
Normal file
84
_optimize_keywords_partial_cache.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
old=''' if can_use_cache:
|
||||||
|
themes = await scanner._load_themes(body.theme_ids)
|
||||||
|
allowed_theme_names = {t.name for t in themes}
|
||||||
|
keywords_scanned = sum(len(theme.keywords) for theme in themes)
|
||||||
|
|
||||||
|
cached_entries: list[dict] = []
|
||||||
|
missing: list[str] = []
|
||||||
|
for dataset_id in (body.dataset_ids or []):
|
||||||
|
entry = keyword_scan_cache.get(dataset_id)
|
||||||
|
if not entry:
|
||||||
|
missing.append(dataset_id)
|
||||||
|
continue
|
||||||
|
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
|
||||||
|
|
||||||
|
if not missing and cached_entries:
|
||||||
|
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
|
||||||
|
return {
|
||||||
|
"total_hits": merged["total_hits"],
|
||||||
|
"hits": merged["hits"],
|
||||||
|
"themes_scanned": len(themes),
|
||||||
|
"keywords_scanned": keywords_scanned,
|
||||||
|
"rows_scanned": merged["rows_scanned"],
|
||||||
|
"cache_used": True,
|
||||||
|
"cache_status": "hit",
|
||||||
|
"cached_at": merged["cached_at"],
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
new=''' if can_use_cache:
|
||||||
|
themes = await scanner._load_themes(body.theme_ids)
|
||||||
|
allowed_theme_names = {t.name for t in themes}
|
||||||
|
keywords_scanned = sum(len(theme.keywords) for theme in themes)
|
||||||
|
|
||||||
|
cached_entries: list[dict] = []
|
||||||
|
missing: list[str] = []
|
||||||
|
for dataset_id in (body.dataset_ids or []):
|
||||||
|
entry = keyword_scan_cache.get(dataset_id)
|
||||||
|
if not entry:
|
||||||
|
missing.append(dataset_id)
|
||||||
|
continue
|
||||||
|
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
|
||||||
|
|
||||||
|
if not missing and cached_entries:
|
||||||
|
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
|
||||||
|
return {
|
||||||
|
"total_hits": merged["total_hits"],
|
||||||
|
"hits": merged["hits"],
|
||||||
|
"themes_scanned": len(themes),
|
||||||
|
"keywords_scanned": keywords_scanned,
|
||||||
|
"rows_scanned": merged["rows_scanned"],
|
||||||
|
"cache_used": True,
|
||||||
|
"cache_status": "hit",
|
||||||
|
"cached_at": merged["cached_at"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if missing:
|
||||||
|
missing_entries: list[dict] = []
|
||||||
|
for dataset_id in missing:
|
||||||
|
partial = await scanner.scan(dataset_ids=[dataset_id], theme_ids=body.theme_ids)
|
||||||
|
keyword_scan_cache.put(dataset_id, partial)
|
||||||
|
missing_entries.append({"result": partial, "built_at": None})
|
||||||
|
|
||||||
|
merged = _merge_cached_results(
|
||||||
|
cached_entries + missing_entries,
|
||||||
|
allowed_theme_names if body.theme_ids else None,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"total_hits": merged["total_hits"],
|
||||||
|
"hits": merged["hits"],
|
||||||
|
"themes_scanned": len(themes),
|
||||||
|
"keywords_scanned": keywords_scanned,
|
||||||
|
"rows_scanned": merged["rows_scanned"],
|
||||||
|
"cache_used": len(cached_entries) > 0,
|
||||||
|
"cache_status": "partial" if cached_entries else "miss",
|
||||||
|
"cached_at": merged["cached_at"],
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('cache block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('updated keyword /scan to use partial cache + scan missing datasets only')
|
||||||
61
_optimize_scanner_keyset.py
Normal file
61
_optimize_scanner_keyset.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
start=t.index(' async def _scan_datasets(')
|
||||||
|
end=t.index(' async def _scan_hunts', start)
|
||||||
|
new_func=''' async def _scan_datasets(
|
||||||
|
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
|
||||||
|
) -> None:
|
||||||
|
"""Scan dataset rows in batches using keyset pagination (no OFFSET)."""
|
||||||
|
ds_q = select(Dataset.id, Dataset.name)
|
||||||
|
if dataset_ids:
|
||||||
|
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
|
||||||
|
ds_result = await self.db.execute(ds_q)
|
||||||
|
ds_map = {r[0]: r[1] for r in ds_result.fetchall()}
|
||||||
|
|
||||||
|
if not ds_map:
|
||||||
|
return
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
for ds_id, ds_name in ds_map.items():
|
||||||
|
last_id = 0
|
||||||
|
while True:
|
||||||
|
rows_result = await self.db.execute(
|
||||||
|
select(DatasetRow)
|
||||||
|
.where(DatasetRow.dataset_id == ds_id)
|
||||||
|
.where(DatasetRow.id > last_id)
|
||||||
|
.order_by(DatasetRow.id)
|
||||||
|
.limit(BATCH_SIZE)
|
||||||
|
)
|
||||||
|
rows = rows_result.scalars().all()
|
||||||
|
if not rows:
|
||||||
|
break
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
result.rows_scanned += 1
|
||||||
|
data = row.data or {}
|
||||||
|
for col_name, cell_value in data.items():
|
||||||
|
if cell_value is None:
|
||||||
|
continue
|
||||||
|
text = str(cell_value)
|
||||||
|
self._match_text(
|
||||||
|
text,
|
||||||
|
patterns,
|
||||||
|
"dataset_row",
|
||||||
|
row.id,
|
||||||
|
col_name,
|
||||||
|
result.hits,
|
||||||
|
row_index=row.row_index,
|
||||||
|
dataset_name=ds_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
last_id = rows[-1].id
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
if len(rows) < BATCH_SIZE:
|
||||||
|
break
|
||||||
|
|
||||||
|
'''
|
||||||
|
out=t[:start]+new_func+t[end:]
|
||||||
|
p.write_text(out,encoding='utf-8')
|
||||||
|
print('optimized scanner _scan_datasets to keyset pagination')
|
||||||
36
_patch_inventory_stats.py
Normal file
36
_patch_inventory_stats.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
old=''' return {
|
||||||
|
"hosts": host_list,
|
||||||
|
"connections": conn_list,
|
||||||
|
"stats": {
|
||||||
|
"total_hosts": len(host_list),
|
||||||
|
"total_datasets_scanned": len(all_datasets),
|
||||||
|
"datasets_with_hosts": ds_with_hosts,
|
||||||
|
"total_rows_scanned": total_rows,
|
||||||
|
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||||
|
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
new=''' return {
|
||||||
|
"hosts": host_list,
|
||||||
|
"connections": conn_list,
|
||||||
|
"stats": {
|
||||||
|
"total_hosts": len(host_list),
|
||||||
|
"total_datasets_scanned": len(all_datasets),
|
||||||
|
"datasets_with_hosts": ds_with_hosts,
|
||||||
|
"total_rows_scanned": total_rows,
|
||||||
|
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||||
|
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||||
|
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||||
|
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('return block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('patched inventory stats metadata')
|
||||||
10
_patch_inventory_stats2.py
Normal file
10
_patch_inventory_stats2.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
needle=' "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n'
|
||||||
|
if '"row_budget_per_dataset"' not in t:
|
||||||
|
if needle not in t:
|
||||||
|
raise SystemExit('needle not found')
|
||||||
|
t=t.replace(needle, needle + ' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,\n "sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0,\n')
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('inserted inventory budget stats lines')
|
||||||
14
_patch_network_sleep.py
Normal file
14
_patch_network_sleep.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
p = Path(r"d:\Projects\Dev\ThreatHunt\frontend\src\components\NetworkMap.tsx")
|
||||||
|
text = p.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
anchor = " useEffect(() => { canvasSizeRef.current = canvasSize; }, [canvasSize]);\n"
|
||||||
|
insert = anchor + "\n const sleep = (ms: number) => new Promise<void>(resolve => setTimeout(resolve, ms));\n"
|
||||||
|
if "const sleep = (ms: number)" not in text and anchor in text:
|
||||||
|
text = text.replace(anchor, insert)
|
||||||
|
|
||||||
|
text = text.replace("await new Promise(r => setTimeout(r, delayMs + jitter));", "await sleep(delayMs + jitter);")
|
||||||
|
|
||||||
|
p.write_text(text, encoding="utf-8")
|
||||||
|
print("Patched sleep helper + polling awaits")
|
||||||
37
_patch_network_wait.py
Normal file
37
_patch_network_wait.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
|
||||||
|
p = Path(r"d:\Projects\Dev\ThreatHunt\frontend\src\components\NetworkMap.tsx")
|
||||||
|
text = p.read_text(encoding="utf-8")
|
||||||
|
pattern = re.compile(r"const waitUntilReady = async \(\): Promise<boolean> => \{[\s\S]*?\n\s*\};", re.M)
|
||||||
|
replacement = '''const waitUntilReady = async (): Promise<boolean> => {
|
||||||
|
// Poll inventory-status with exponential backoff until 'ready' (or cancelled)
|
||||||
|
setProgress('Host inventory is being prepared in the background');
|
||||||
|
setLoading(true);
|
||||||
|
let delayMs = 1500;
|
||||||
|
const startedAt = Date.now();
|
||||||
|
for (;;) {
|
||||||
|
const jitter = Math.floor(Math.random() * 250);
|
||||||
|
await new Promise(r => setTimeout(r, delayMs + jitter));
|
||||||
|
if (cancelled) return false;
|
||||||
|
try {
|
||||||
|
const st = await network.inventoryStatus(selectedHuntId);
|
||||||
|
if (cancelled) return false;
|
||||||
|
if (st.status === 'ready') return true;
|
||||||
|
if (Date.now() - startedAt > 5 * 60 * 1000) {
|
||||||
|
setError('Host inventory build timed out. Please retry.');
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
delayMs = Math.min(10000, Math.floor(delayMs * 1.5));
|
||||||
|
// still building or none (job may not have started yet) - keep polling
|
||||||
|
} catch {
|
||||||
|
if (cancelled) return false;
|
||||||
|
delayMs = Math.min(10000, Math.floor(delayMs * 1.5));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};'''
|
||||||
|
new_text, n = pattern.subn(replacement, text, count=1)
|
||||||
|
if n != 1:
|
||||||
|
raise SystemExit(f"Failed to patch waitUntilReady, matches={n}")
|
||||||
|
p.write_text(new_text, encoding="utf-8")
|
||||||
|
print("Patched waitUntilReady")
|
||||||
26
_perf_edit_config_inventory.py
Normal file
26
_perf_edit_config_inventory.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
old=''' NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
|
||||||
|
default=25000,
|
||||||
|
description="Row budget per dataset when building host inventory (0 = unlimited)",
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
new=''' NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
|
||||||
|
default=5000,
|
||||||
|
description="Row budget per dataset when building host inventory (0 = unlimited)",
|
||||||
|
)
|
||||||
|
NETWORK_INVENTORY_MAX_TOTAL_ROWS: int = Field(
|
||||||
|
default=120000,
|
||||||
|
description="Global row budget across all datasets for host inventory build (0 = unlimited)",
|
||||||
|
)
|
||||||
|
NETWORK_INVENTORY_MAX_CONNECTIONS: int = Field(
|
||||||
|
default=120000,
|
||||||
|
description="Max unique connection tuples retained during host inventory build",
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('network inventory block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('updated network inventory budgets in config')
|
||||||
164
_perf_edit_host_inventory_budgets.py
Normal file
164
_perf_edit_host_inventory_budgets.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
|
||||||
|
# insert budget vars near existing counters
|
||||||
|
old=''' connections: dict[tuple, int] = defaultdict(int)
|
||||||
|
total_rows = 0
|
||||||
|
ds_with_hosts = 0
|
||||||
|
'''
|
||||||
|
new=''' connections: dict[tuple, int] = defaultdict(int)
|
||||||
|
total_rows = 0
|
||||||
|
ds_with_hosts = 0
|
||||||
|
sampled_dataset_count = 0
|
||||||
|
total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
|
||||||
|
max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
|
||||||
|
global_budget_reached = False
|
||||||
|
dropped_connections = 0
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('counter block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
|
||||||
|
# update batch size and sampled count increments + global budget checks
|
||||||
|
old2=''' batch_size = 10000
|
||||||
|
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
|
||||||
|
rows_scanned_this_dataset = 0
|
||||||
|
sampled_dataset = False
|
||||||
|
last_row_index = -1
|
||||||
|
while True:
|
||||||
|
'''
|
||||||
|
new2=''' batch_size = 5000
|
||||||
|
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
|
||||||
|
rows_scanned_this_dataset = 0
|
||||||
|
sampled_dataset = False
|
||||||
|
last_row_index = -1
|
||||||
|
while True:
|
||||||
|
if total_row_budget and total_rows >= total_row_budget:
|
||||||
|
global_budget_reached = True
|
||||||
|
break
|
||||||
|
'''
|
||||||
|
if old2 not in t:
|
||||||
|
raise SystemExit('batch block not found')
|
||||||
|
t=t.replace(old2,new2)
|
||||||
|
|
||||||
|
old3=''' if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
|
||||||
|
sampled_dataset = True
|
||||||
|
break
|
||||||
|
|
||||||
|
data = ro.data or {}
|
||||||
|
total_rows += 1
|
||||||
|
rows_scanned_this_dataset += 1
|
||||||
|
'''
|
||||||
|
new3=''' if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
|
||||||
|
sampled_dataset = True
|
||||||
|
break
|
||||||
|
if total_row_budget and total_rows >= total_row_budget:
|
||||||
|
sampled_dataset = True
|
||||||
|
global_budget_reached = True
|
||||||
|
break
|
||||||
|
|
||||||
|
data = ro.data or {}
|
||||||
|
total_rows += 1
|
||||||
|
rows_scanned_this_dataset += 1
|
||||||
|
'''
|
||||||
|
if old3 not in t:
|
||||||
|
raise SystemExit('row scan block not found')
|
||||||
|
t=t.replace(old3,new3)
|
||||||
|
|
||||||
|
# cap connection map growth
|
||||||
|
old4=''' for c in cols['remote_ip']:
|
||||||
|
rip = _clean(data.get(c))
|
||||||
|
if _is_valid_ip(rip):
|
||||||
|
rport = ''
|
||||||
|
for pc in cols['remote_port']:
|
||||||
|
rport = _clean(data.get(pc))
|
||||||
|
if rport:
|
||||||
|
break
|
||||||
|
connections[(host_key, rip, rport)] += 1
|
||||||
|
'''
|
||||||
|
new4=''' for c in cols['remote_ip']:
|
||||||
|
rip = _clean(data.get(c))
|
||||||
|
if _is_valid_ip(rip):
|
||||||
|
rport = ''
|
||||||
|
for pc in cols['remote_port']:
|
||||||
|
rport = _clean(data.get(pc))
|
||||||
|
if rport:
|
||||||
|
break
|
||||||
|
conn_key = (host_key, rip, rport)
|
||||||
|
if max_connections and len(connections) >= max_connections and conn_key not in connections:
|
||||||
|
dropped_connections += 1
|
||||||
|
continue
|
||||||
|
connections[conn_key] += 1
|
||||||
|
'''
|
||||||
|
if old4 not in t:
|
||||||
|
raise SystemExit('connection block not found')
|
||||||
|
t=t.replace(old4,new4)
|
||||||
|
|
||||||
|
# sampled_dataset counter
|
||||||
|
old5=''' if sampled_dataset:
|
||||||
|
logger.info(
|
||||||
|
"Host inventory row budget reached for dataset %s (%d rows)",
|
||||||
|
ds.id,
|
||||||
|
rows_scanned_this_dataset,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
'''
|
||||||
|
new5=''' if sampled_dataset:
|
||||||
|
sampled_dataset_count += 1
|
||||||
|
logger.info(
|
||||||
|
"Host inventory row budget reached for dataset %s (%d rows)",
|
||||||
|
ds.id,
|
||||||
|
rows_scanned_this_dataset,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
'''
|
||||||
|
if old5 not in t:
|
||||||
|
raise SystemExit('sampled block not found')
|
||||||
|
t=t.replace(old5,new5)
|
||||||
|
|
||||||
|
# break dataset loop if global budget reached
|
||||||
|
old6=''' if len(rows) < batch_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Post-process hosts
|
||||||
|
'''
|
||||||
|
new6=''' if len(rows) < batch_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
if global_budget_reached:
|
||||||
|
logger.info(
|
||||||
|
"Host inventory global row budget reached for hunt %s at %d rows",
|
||||||
|
hunt_id,
|
||||||
|
total_rows,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Post-process hosts
|
||||||
|
'''
|
||||||
|
if old6 not in t:
|
||||||
|
raise SystemExit('post-process boundary block not found')
|
||||||
|
t=t.replace(old6,new6)
|
||||||
|
|
||||||
|
# add stats
|
||||||
|
old7=''' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||||
|
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
new7=''' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||||
|
"row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
|
||||||
|
"connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
|
||||||
|
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
|
||||||
|
"sampled_datasets": sampled_dataset_count,
|
||||||
|
"global_budget_reached": global_budget_reached,
|
||||||
|
"dropped_connections": dropped_connections,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
if old7 not in t:
|
||||||
|
raise SystemExit('stats block not found')
|
||||||
|
t=t.replace(old7,new7)
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('updated host inventory with global row and connection budgets')
|
||||||
39
_perf_edit_networkmap_render.py
Normal file
39
_perf_edit_networkmap_render.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
repls={
|
||||||
|
"const LARGE_HUNT_SUBGRAPH_HOSTS = 350;":"const LARGE_HUNT_SUBGRAPH_HOSTS = 220;",
|
||||||
|
"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;":"const LARGE_HUNT_SUBGRAPH_EDGES = 1200;",
|
||||||
|
"const RENDER_SIMPLIFY_NODE_THRESHOLD = 220;":"const RENDER_SIMPLIFY_NODE_THRESHOLD = 120;",
|
||||||
|
"const RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;":"const RENDER_SIMPLIFY_EDGE_THRESHOLD = 500;",
|
||||||
|
"const EDGE_DRAW_TARGET = 1000;":"const EDGE_DRAW_TARGET = 600;"
|
||||||
|
}
|
||||||
|
for a,b in repls.items():
|
||||||
|
if a not in t:
|
||||||
|
raise SystemExit(f'missing constant: {a}')
|
||||||
|
t=t.replace(a,b)
|
||||||
|
|
||||||
|
old=''' // Then label hit (so clicking text works too)
|
||||||
|
for (const n of graph.nodes) {
|
||||||
|
if (isPointOnNodeLabel(n, wx, wy, vp)) return n;
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
new=''' // Then label hit (so clicking text works too on manageable graph sizes)
|
||||||
|
if (graph.nodes.length <= 220) {
|
||||||
|
for (const n of graph.nodes) {
|
||||||
|
if (isPointOnNodeLabel(n, wx, wy, vp)) return n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('label hit block not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
|
||||||
|
old2='simulate(g, w / 2, h / 2, 60);'
|
||||||
|
if t.count(old2) < 2:
|
||||||
|
raise SystemExit('expected two simulate calls')
|
||||||
|
t=t.replace(old2,'simulate(g, w / 2, h / 2, 20);',1)
|
||||||
|
t=t.replace(old2,'simulate(g, w / 2, h / 2, 30);',1)
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('tightened network map rendering + load limits')
|
||||||
107
_perf_patch_backend.py
Normal file
107
_perf_patch_backend.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
# config updates
|
||||||
|
cfg=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
|
||||||
|
t=cfg.read_text(encoding='utf-8')
|
||||||
|
anchor=''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
|
||||||
|
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
ins=''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
|
||||||
|
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
|
||||||
|
)
|
||||||
|
NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
|
||||||
|
default=200000,
|
||||||
|
description="Row budget per dataset when building host inventory (0 = unlimited)",
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
if 'NETWORK_INVENTORY_MAX_ROWS_PER_DATASET' not in t:
|
||||||
|
if anchor not in t:
|
||||||
|
raise SystemExit('config network anchor not found')
|
||||||
|
t=t.replace(anchor,ins)
|
||||||
|
cfg.write_text(t,encoding='utf-8')
|
||||||
|
|
||||||
|
# host inventory updates
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
if 'from app.config import settings' not in t:
|
||||||
|
t=t.replace('from app.db.models import Dataset, DatasetRow\n', 'from app.db.models import Dataset, DatasetRow\nfrom app.config import settings\n')
|
||||||
|
|
||||||
|
t=t.replace(' batch_size = 5000\n last_row_index = -1\n while True:\n', ' batch_size = 10000\n max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))\n rows_scanned_this_dataset = 0\n sampled_dataset = False\n last_row_index = -1\n while True:\n')
|
||||||
|
|
||||||
|
old=''' for ro in rows:
|
||||||
|
data = ro.data or {}
|
||||||
|
total_rows += 1
|
||||||
|
|
||||||
|
fqdn = ''
|
||||||
|
'''
|
||||||
|
new=''' for ro in rows:
|
||||||
|
if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
|
||||||
|
sampled_dataset = True
|
||||||
|
break
|
||||||
|
|
||||||
|
data = ro.data or {}
|
||||||
|
total_rows += 1
|
||||||
|
rows_scanned_this_dataset += 1
|
||||||
|
|
||||||
|
fqdn = ''
|
||||||
|
'''
|
||||||
|
if old not in t:
|
||||||
|
raise SystemExit('row loop anchor not found')
|
||||||
|
t=t.replace(old,new)
|
||||||
|
|
||||||
|
old2=''' last_row_index = rows[-1].row_index
|
||||||
|
if len(rows) < batch_size:
|
||||||
|
break
|
||||||
|
'''
|
||||||
|
new2=''' if sampled_dataset:
|
||||||
|
logger.info(
|
||||||
|
"Host inventory row budget reached for dataset %s (%d rows)",
|
||||||
|
ds.id,
|
||||||
|
rows_scanned_this_dataset,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
last_row_index = rows[-1].row_index
|
||||||
|
if len(rows) < batch_size:
|
||||||
|
break
|
||||||
|
'''
|
||||||
|
if old2 not in t:
|
||||||
|
raise SystemExit('batch loop end anchor not found')
|
||||||
|
t=t.replace(old2,new2)
|
||||||
|
|
||||||
|
old3=''' return {
|
||||||
|
"hosts": host_list,
|
||||||
|
"connections": conn_list,
|
||||||
|
"stats": {
|
||||||
|
"total_hosts": len(host_list),
|
||||||
|
"total_datasets_scanned": len(all_datasets),
|
||||||
|
"datasets_with_hosts": ds_with_hosts,
|
||||||
|
"total_rows_scanned": total_rows,
|
||||||
|
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||||
|
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
new3=''' sampled = settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"hosts": host_list,
|
||||||
|
"connections": conn_list,
|
||||||
|
"stats": {
|
||||||
|
"total_hosts": len(host_list),
|
||||||
|
"total_datasets_scanned": len(all_datasets),
|
||||||
|
"datasets_with_hosts": ds_with_hosts,
|
||||||
|
"total_rows_scanned": total_rows,
|
||||||
|
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||||
|
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||||
|
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||||
|
"sampled_mode": sampled,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
if old3 not in t:
|
||||||
|
raise SystemExit('return stats anchor not found')
|
||||||
|
t=t.replace(old3,new3)
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('patched config + host inventory row budget')
|
||||||
38
_perf_patch_backend2.py
Normal file
38
_perf_patch_backend2.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
cfg=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
|
||||||
|
t=cfg.read_text(encoding='utf-8')
|
||||||
|
if 'NETWORK_INVENTORY_MAX_ROWS_PER_DATASET' not in t:
|
||||||
|
t=t.replace(
|
||||||
|
''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
|
||||||
|
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
|
||||||
|
)
|
||||||
|
''',
|
||||||
|
''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
|
||||||
|
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
|
||||||
|
)
|
||||||
|
NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
|
||||||
|
default=200000,
|
||||||
|
description="Row budget per dataset when building host inventory (0 = unlimited)",
|
||||||
|
)
|
||||||
|
''')
|
||||||
|
cfg.write_text(t,encoding='utf-8')
|
||||||
|
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
if 'from app.config import settings' not in t:
|
||||||
|
t=t.replace('from app.db.models import Dataset, DatasetRow\n','from app.db.models import Dataset, DatasetRow\nfrom app.config import settings\n')
|
||||||
|
|
||||||
|
t=t.replace(' batch_size = 5000\n last_row_index = -1\n while True:\n',
|
||||||
|
' batch_size = 10000\n max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))\n rows_scanned_this_dataset = 0\n sampled_dataset = False\n last_row_index = -1\n while True:\n')
|
||||||
|
|
||||||
|
t=t.replace(' for ro in rows:\n data = ro.data or {}\n total_rows += 1\n\n',
|
||||||
|
' for ro in rows:\n if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:\n sampled_dataset = True\n break\n\n data = ro.data or {}\n total_rows += 1\n rows_scanned_this_dataset += 1\n\n')
|
||||||
|
|
||||||
|
t=t.replace(' last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n',
|
||||||
|
' if sampled_dataset:\n logger.info(\n "Host inventory row budget reached for dataset %s (%d rows)",\n ds.id,\n rows_scanned_this_dataset,\n )\n break\n\n last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n')
|
||||||
|
|
||||||
|
t=t.replace(' return {\n "hosts": host_list,\n "connections": conn_list,\n "stats": {\n "total_hosts": len(host_list),\n "total_datasets_scanned": len(all_datasets),\n "datasets_with_hosts": ds_with_hosts,\n "total_rows_scanned": total_rows,\n "hosts_with_ips": sum(1 for h in host_list if h[\'ips\']),\n "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n },\n }\n',
|
||||||
|
' sampled = settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0\n\n return {\n "hosts": host_list,\n "connections": conn_list,\n "stats": {\n "total_hosts": len(host_list),\n "total_datasets_scanned": len(all_datasets),\n "datasets_with_hosts": ds_with_hosts,\n "total_rows_scanned": total_rows,\n "hosts_with_ips": sum(1 for h in host_list if h[\'ips\']),\n "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,\n "sampled_mode": sampled,\n },\n }\n')
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('patched backend inventory performance settings')
|
||||||
220
_perf_patch_networkmap.py
Normal file
220
_perf_patch_networkmap.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
|
||||||
|
# constants
|
||||||
|
if 'RENDER_SIMPLIFY_NODE_THRESHOLD' not in t:
|
||||||
|
t=t.replace(
|
||||||
|
"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n",
|
||||||
|
"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\nconst RENDER_SIMPLIFY_NODE_THRESHOLD = 220;\nconst RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;\nconst EDGE_DRAW_TARGET = 1000;\n")
|
||||||
|
|
||||||
|
# drawBackground signature
|
||||||
|
t_old='''function drawBackground(
|
||||||
|
ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,
|
||||||
|
) {
|
||||||
|
'''
|
||||||
|
if t_old in t:
|
||||||
|
t=t.replace(t_old,
|
||||||
|
'''function drawBackground(
|
||||||
|
ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,
|
||||||
|
simplify: boolean,
|
||||||
|
) {
|
||||||
|
''')
|
||||||
|
|
||||||
|
# skip grid when simplify
|
||||||
|
if 'if (!simplify) {' not in t:
|
||||||
|
t=t.replace(
|
||||||
|
''' ctx.save();
|
||||||
|
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||||
|
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||||
|
const startX = -vp.x / vp.scale - GRID_SPACING;
|
||||||
|
const startY = -vp.y / vp.scale - GRID_SPACING;
|
||||||
|
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||||
|
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||||
|
ctx.fillStyle = GRID_DOT_COLOR;
|
||||||
|
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
|
||||||
|
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
|
||||||
|
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx.restore();
|
||||||
|
''',
|
||||||
|
''' if (!simplify) {
|
||||||
|
ctx.save();
|
||||||
|
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||||
|
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||||
|
const startX = -vp.x / vp.scale - GRID_SPACING;
|
||||||
|
const startY = -vp.y / vp.scale - GRID_SPACING;
|
||||||
|
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||||
|
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||||
|
ctx.fillStyle = GRID_DOT_COLOR;
|
||||||
|
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
|
||||||
|
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
|
||||||
|
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx.restore();
|
||||||
|
}
|
||||||
|
''')
|
||||||
|
|
||||||
|
# drawEdges signature
|
||||||
|
t=t.replace('''function drawEdges(
|
||||||
|
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||||
|
hovered: string | null, selected: string | null,
|
||||||
|
nodeMap: Map<string, GNode>, animTime: number,
|
||||||
|
) {
|
||||||
|
for (const e of graph.edges) {
|
||||||
|
''',
|
||||||
|
'''function drawEdges(
|
||||||
|
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||||
|
hovered: string | null, selected: string | null,
|
||||||
|
nodeMap: Map<string, GNode>, animTime: number,
|
||||||
|
simplify: boolean,
|
||||||
|
) {
|
||||||
|
const edgeStep = simplify ? Math.max(1, Math.ceil(graph.edges.length / EDGE_DRAW_TARGET)) : 1;
|
||||||
|
for (let ei = 0; ei < graph.edges.length; ei += edgeStep) {
|
||||||
|
const e = graph.edges[ei];
|
||||||
|
''')
|
||||||
|
|
||||||
|
# simplify edge path
|
||||||
|
t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);',
|
||||||
|
'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }')
|
||||||
|
|
||||||
|
t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);',
|
||||||
|
'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }')
|
||||||
|
|
||||||
|
# reduce glow when simplify
|
||||||
|
t=t.replace(''' ctx.save();
|
||||||
|
ctx.shadowColor = 'rgba(96,165,250,0.5)'; ctx.shadowBlur = 8;
|
||||||
|
ctx.strokeStyle = 'rgba(96,165,250,0.3)';
|
||||||
|
ctx.lineWidth = Math.min(5, 2 + e.weight * 0.2);
|
||||||
|
ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }
|
||||||
|
ctx.stroke(); ctx.restore();
|
||||||
|
''',
|
||||||
|
''' if (!simplify) {
|
||||||
|
ctx.save();
|
||||||
|
ctx.shadowColor = 'rgba(96,165,250,0.5)'; ctx.shadowBlur = 8;
|
||||||
|
ctx.strokeStyle = 'rgba(96,165,250,0.3)';
|
||||||
|
ctx.lineWidth = Math.min(5, 2 + e.weight * 0.2);
|
||||||
|
ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }
|
||||||
|
ctx.stroke(); ctx.restore();
|
||||||
|
}
|
||||||
|
''')
|
||||||
|
|
||||||
|
# drawLabels signature and early return
|
||||||
|
t=t.replace('''function drawLabels(
|
||||||
|
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||||
|
hovered: string | null, selected: string | null,
|
||||||
|
search: string, matchSet: Set<string>, vp: Viewport,
|
||||||
|
) {
|
||||||
|
''',
|
||||||
|
'''function drawLabels(
|
||||||
|
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||||
|
hovered: string | null, selected: string | null,
|
||||||
|
search: string, matchSet: Set<string>, vp: Viewport,
|
||||||
|
simplify: boolean,
|
||||||
|
) {
|
||||||
|
''')
|
||||||
|
|
||||||
|
if 'if (simplify && !search && !hovered && !selected) {' not in t:
|
||||||
|
t=t.replace(' const dimmed = search.length > 0;\n',
|
||||||
|
' const dimmed = search.length > 0;\n if (simplify && !search && !hovered && !selected) {\n return;\n }\n')
|
||||||
|
|
||||||
|
# drawGraph adapt
|
||||||
|
t=t.replace(''' drawBackground(ctx, w, h, vp, dpr);
|
||||||
|
ctx.save();
|
||||||
|
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||||
|
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||||
|
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime);
|
||||||
|
drawNodes(ctx, graph, hovered, selected, search, matchSet);
|
||||||
|
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp);
|
||||||
|
ctx.restore();
|
||||||
|
''',
|
||||||
|
''' const simplify = graph.nodes.length > RENDER_SIMPLIFY_NODE_THRESHOLD || graph.edges.length > RENDER_SIMPLIFY_EDGE_THRESHOLD;
|
||||||
|
drawBackground(ctx, w, h, vp, dpr, simplify);
|
||||||
|
ctx.save();
|
||||||
|
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||||
|
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||||
|
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime, simplify);
|
||||||
|
drawNodes(ctx, graph, hovered, selected, search, matchSet);
|
||||||
|
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify);
|
||||||
|
ctx.restore();
|
||||||
|
''')
|
||||||
|
|
||||||
|
# hover RAF ref
|
||||||
|
if 'const hoverRafRef = useRef<number>(0);' not in t:
|
||||||
|
t=t.replace(' const graphRef = useRef<Graph | null>(null);\n', ' const graphRef = useRef<Graph | null>(null);\n const hoverRafRef = useRef<number>(0);\n')
|
||||||
|
|
||||||
|
# throttle hover hit test on mousemove
|
||||||
|
old_mm=''' const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
|
||||||
|
setHovered(node?.id ?? null);
|
||||||
|
}, [graph, redraw, startAnimLoop]);
|
||||||
|
'''
|
||||||
|
new_mm=''' cancelAnimationFrame(hoverRafRef.current);
|
||||||
|
const clientX = e.clientX;
|
||||||
|
const clientY = e.clientY;
|
||||||
|
hoverRafRef.current = requestAnimationFrame(() => {
|
||||||
|
const node = hitTest(graph, canvasRef.current as HTMLCanvasElement, clientX, clientY, vpRef.current);
|
||||||
|
setHovered(prev => (prev === (node?.id ?? null) ? prev : (node?.id ?? null)));
|
||||||
|
});
|
||||||
|
}, [graph, redraw, startAnimLoop]);
|
||||||
|
'''
|
||||||
|
if old_mm in t:
|
||||||
|
t=t.replace(old_mm,new_mm)
|
||||||
|
|
||||||
|
# cleanup hover raf on unmount in existing animation cleanup effect
|
||||||
|
if 'cancelAnimationFrame(hoverRafRef.current);' not in t:
|
||||||
|
t=t.replace(''' useEffect(() => {
|
||||||
|
if (graph) startAnimLoop();
|
||||||
|
return () => { cancelAnimationFrame(animFrameRef.current); isAnimatingRef.current = false; };
|
||||||
|
}, [graph, startAnimLoop]);
|
||||||
|
''',
|
||||||
|
''' useEffect(() => {
|
||||||
|
if (graph) startAnimLoop();
|
||||||
|
return () => {
|
||||||
|
cancelAnimationFrame(animFrameRef.current);
|
||||||
|
cancelAnimationFrame(hoverRafRef.current);
|
||||||
|
isAnimatingRef.current = false;
|
||||||
|
};
|
||||||
|
}, [graph, startAnimLoop]);
|
||||||
|
''')
|
||||||
|
|
||||||
|
# connectedNodes optimization map
|
||||||
|
if 'const nodeById = useMemo(() => {' not in t:
|
||||||
|
t=t.replace(''' const connectionCount = selectedNode && graph
|
||||||
|
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
|
||||||
|
: 0;
|
||||||
|
|
||||||
|
const connectedNodes = useMemo(() => {
|
||||||
|
''',
|
||||||
|
''' const connectionCount = selectedNode && graph
|
||||||
|
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
|
||||||
|
: 0;
|
||||||
|
|
||||||
|
const nodeById = useMemo(() => {
|
||||||
|
const m = new Map<string, GNode>();
|
||||||
|
if (!graph) return m;
|
||||||
|
for (const n of graph.nodes) m.set(n.id, n);
|
||||||
|
return m;
|
||||||
|
}, [graph]);
|
||||||
|
|
||||||
|
const connectedNodes = useMemo(() => {
|
||||||
|
''')
|
||||||
|
|
||||||
|
t=t.replace(''' const n = graph.nodes.find(x => x.id === e.target);
|
||||||
|
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
|
||||||
|
} else if (e.target === selectedNode.id) {
|
||||||
|
const n = graph.nodes.find(x => x.id === e.source);
|
||||||
|
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
|
||||||
|
''',
|
||||||
|
''' const n = nodeById.get(e.target);
|
||||||
|
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
|
||||||
|
} else if (e.target === selectedNode.id) {
|
||||||
|
const n = nodeById.get(e.source);
|
||||||
|
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
|
||||||
|
''')
|
||||||
|
|
||||||
|
t=t.replace(' }, [selectedNode, graph]);\n', ' }, [selectedNode, graph, nodeById]);\n')
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('patched NetworkMap adaptive render + hover throttle')
|
||||||
153
_perf_patch_networkmap2.py
Normal file
153
_perf_patch_networkmap2.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
|
||||||
|
if 'RENDER_SIMPLIFY_NODE_THRESHOLD' not in t:
|
||||||
|
t=t.replace('const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n', 'const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\nconst RENDER_SIMPLIFY_NODE_THRESHOLD = 220;\nconst RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;\nconst EDGE_DRAW_TARGET = 1000;\n')
|
||||||
|
|
||||||
|
t=t.replace('function drawBackground(\n ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,\n) {', 'function drawBackground(\n ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,\n simplify: boolean,\n) {')
|
||||||
|
|
||||||
|
t=t.replace(''' ctx.save();
|
||||||
|
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||||
|
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||||
|
const startX = -vp.x / vp.scale - GRID_SPACING;
|
||||||
|
const startY = -vp.y / vp.scale - GRID_SPACING;
|
||||||
|
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||||
|
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||||
|
ctx.fillStyle = GRID_DOT_COLOR;
|
||||||
|
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
|
||||||
|
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
|
||||||
|
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx.restore();
|
||||||
|
''',''' if (!simplify) {
|
||||||
|
ctx.save();
|
||||||
|
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||||
|
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||||
|
const startX = -vp.x / vp.scale - GRID_SPACING;
|
||||||
|
const startY = -vp.y / vp.scale - GRID_SPACING;
|
||||||
|
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||||
|
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||||
|
ctx.fillStyle = GRID_DOT_COLOR;
|
||||||
|
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
|
||||||
|
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
|
||||||
|
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx.restore();
|
||||||
|
}
|
||||||
|
''')
|
||||||
|
|
||||||
|
t=t.replace('''function drawEdges(
|
||||||
|
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||||
|
hovered: string | null, selected: string | null,
|
||||||
|
nodeMap: Map<string, GNode>, animTime: number,
|
||||||
|
) {
|
||||||
|
for (const e of graph.edges) {
|
||||||
|
''','''function drawEdges(
|
||||||
|
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||||
|
hovered: string | null, selected: string | null,
|
||||||
|
nodeMap: Map<string, GNode>, animTime: number,
|
||||||
|
simplify: boolean,
|
||||||
|
) {
|
||||||
|
const edgeStep = simplify ? Math.max(1, Math.ceil(graph.edges.length / EDGE_DRAW_TARGET)) : 1;
|
||||||
|
for (let ei = 0; ei < graph.edges.length; ei += edgeStep) {
|
||||||
|
const e = graph.edges[ei];
|
||||||
|
''')
|
||||||
|
|
||||||
|
t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);', 'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }')
|
||||||
|
|
||||||
|
t=t.replace('''function drawLabels(
|
||||||
|
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||||
|
hovered: string | null, selected: string | null,
|
||||||
|
search: string, matchSet: Set<string>, vp: Viewport,
|
||||||
|
) {
|
||||||
|
const dimmed = search.length > 0;
|
||||||
|
''','''function drawLabels(
|
||||||
|
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||||
|
hovered: string | null, selected: string | null,
|
||||||
|
search: string, matchSet: Set<string>, vp: Viewport,
|
||||||
|
simplify: boolean,
|
||||||
|
) {
|
||||||
|
const dimmed = search.length > 0;
|
||||||
|
if (simplify && !search && !hovered && !selected) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
''')
|
||||||
|
|
||||||
|
t=t.replace(''' drawBackground(ctx, w, h, vp, dpr);
|
||||||
|
ctx.save();
|
||||||
|
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||||
|
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||||
|
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime);
|
||||||
|
drawNodes(ctx, graph, hovered, selected, search, matchSet);
|
||||||
|
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp);
|
||||||
|
ctx.restore();
|
||||||
|
''',''' const simplify = graph.nodes.length > RENDER_SIMPLIFY_NODE_THRESHOLD || graph.edges.length > RENDER_SIMPLIFY_EDGE_THRESHOLD;
|
||||||
|
drawBackground(ctx, w, h, vp, dpr, simplify);
|
||||||
|
ctx.save();
|
||||||
|
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||||
|
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||||
|
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime, simplify);
|
||||||
|
drawNodes(ctx, graph, hovered, selected, search, matchSet);
|
||||||
|
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify);
|
||||||
|
ctx.restore();
|
||||||
|
''')
|
||||||
|
|
||||||
|
if 'const hoverRafRef = useRef<number>(0);' not in t:
|
||||||
|
t=t.replace('const graphRef = useRef<Graph | null>(null);\n', 'const graphRef = useRef<Graph | null>(null);\n const hoverRafRef = useRef<number>(0);\n')
|
||||||
|
|
||||||
|
t=t.replace(''' const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
|
||||||
|
setHovered(node?.id ?? null);
|
||||||
|
}, [graph, redraw, startAnimLoop]);
|
||||||
|
''',''' cancelAnimationFrame(hoverRafRef.current);
|
||||||
|
const clientX = e.clientX;
|
||||||
|
const clientY = e.clientY;
|
||||||
|
hoverRafRef.current = requestAnimationFrame(() => {
|
||||||
|
const node = hitTest(graph, canvasRef.current as HTMLCanvasElement, clientX, clientY, vpRef.current);
|
||||||
|
setHovered(prev => (prev === (node?.id ?? null) ? prev : (node?.id ?? null)));
|
||||||
|
});
|
||||||
|
}, [graph, redraw, startAnimLoop]);
|
||||||
|
''')
|
||||||
|
|
||||||
|
t=t.replace(''' useEffect(() => {
|
||||||
|
if (graph) startAnimLoop();
|
||||||
|
return () => { cancelAnimationFrame(animFrameRef.current); isAnimatingRef.current = false; };
|
||||||
|
}, [graph, startAnimLoop]);
|
||||||
|
''',''' useEffect(() => {
|
||||||
|
if (graph) startAnimLoop();
|
||||||
|
return () => {
|
||||||
|
cancelAnimationFrame(animFrameRef.current);
|
||||||
|
cancelAnimationFrame(hoverRafRef.current);
|
||||||
|
isAnimatingRef.current = false;
|
||||||
|
};
|
||||||
|
}, [graph, startAnimLoop]);
|
||||||
|
''')
|
||||||
|
|
||||||
|
if 'const nodeById = useMemo(() => {' not in t:
|
||||||
|
t=t.replace(''' const connectionCount = selectedNode && graph
|
||||||
|
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
|
||||||
|
: 0;
|
||||||
|
|
||||||
|
const connectedNodes = useMemo(() => {
|
||||||
|
''',''' const connectionCount = selectedNode && graph
|
||||||
|
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
|
||||||
|
: 0;
|
||||||
|
|
||||||
|
const nodeById = useMemo(() => {
|
||||||
|
const m = new Map<string, GNode>();
|
||||||
|
if (!graph) return m;
|
||||||
|
for (const n of graph.nodes) m.set(n.id, n);
|
||||||
|
return m;
|
||||||
|
}, [graph]);
|
||||||
|
|
||||||
|
const connectedNodes = useMemo(() => {
|
||||||
|
''')
|
||||||
|
|
||||||
|
t=t.replace('const n = graph.nodes.find(x => x.id === e.target);','const n = nodeById.get(e.target);')
|
||||||
|
t=t.replace('const n = graph.nodes.find(x => x.id === e.source);','const n = nodeById.get(e.source);')
|
||||||
|
t=t.replace(' }, [selectedNode, graph]);',' }, [selectedNode, graph, nodeById]);')
|
||||||
|
|
||||||
|
p.write_text(t,encoding='utf-8')
|
||||||
|
print('patched NetworkMap performance')
|
||||||
227
_perf_replace_build_host_inventory.py
Normal file
227
_perf_replace_build_host_inventory.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
|
||||||
|
t=p.read_text(encoding='utf-8')
|
||||||
|
start=t.index('async def build_host_inventory(')
|
||||||
|
# find end of function by locating '\n\n' before EOF after ' }\n'
|
||||||
|
end=t.index('\n\n', start)
|
||||||
|
# need proper end: first double newline after function may occur in docstring? compute by searching for '\n\n' after ' }\n' near end
|
||||||
|
ret_idx=t.rfind(' }')
|
||||||
|
# safer locate end as last occurrence of '\n }\n' after start, then function ends next newline
|
||||||
|
end=t.find('\n\n', ret_idx)
|
||||||
|
if end==-1:
|
||||||
|
end=len(t)
|
||||||
|
new_func='''async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
||||||
|
"""Build a deduplicated host inventory from all datasets in a hunt.
|
||||||
|
|
||||||
|
Returns dict with 'hosts', 'connections', and 'stats'.
|
||||||
|
Each host has: id, hostname, fqdn, client_id, ips, os, users, datasets, row_count.
|
||||||
|
"""
|
||||||
|
ds_result = await db.execute(
|
||||||
|
select(Dataset).where(Dataset.hunt_id == hunt_id)
|
||||||
|
)
|
||||||
|
all_datasets = ds_result.scalars().all()
|
||||||
|
|
||||||
|
if not all_datasets:
|
||||||
|
return {"hosts": [], "connections": [], "stats": {
|
||||||
|
"total_hosts": 0, "total_datasets_scanned": 0,
|
||||||
|
"total_rows_scanned": 0,
|
||||||
|
}}
|
||||||
|
|
||||||
|
hosts: dict[str, dict] = {} # fqdn -> host record
|
||||||
|
ip_to_host: dict[str, str] = {} # local-ip -> fqdn
|
||||||
|
connections: dict[tuple, int] = defaultdict(int)
|
||||||
|
total_rows = 0
|
||||||
|
ds_with_hosts = 0
|
||||||
|
sampled_dataset_count = 0
|
||||||
|
total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
|
||||||
|
max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
|
||||||
|
global_budget_reached = False
|
||||||
|
dropped_connections = 0
|
||||||
|
|
||||||
|
for ds in all_datasets:
|
||||||
|
if total_row_budget and total_rows >= total_row_budget:
|
||||||
|
global_budget_reached = True
|
||||||
|
break
|
||||||
|
|
||||||
|
cols = _identify_columns(ds)
|
||||||
|
if not cols['fqdn'] and not cols['host_id']:
|
||||||
|
continue
|
||||||
|
ds_with_hosts += 1
|
||||||
|
|
||||||
|
batch_size = 5000
|
||||||
|
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
|
||||||
|
rows_scanned_this_dataset = 0
|
||||||
|
sampled_dataset = False
|
||||||
|
last_row_index = -1
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if total_row_budget and total_rows >= total_row_budget:
|
||||||
|
sampled_dataset = True
|
||||||
|
global_budget_reached = True
|
||||||
|
break
|
||||||
|
|
||||||
|
rr = await db.execute(
|
||||||
|
select(DatasetRow)
|
||||||
|
.where(DatasetRow.dataset_id == ds.id)
|
||||||
|
.where(DatasetRow.row_index > last_row_index)
|
||||||
|
.order_by(DatasetRow.row_index)
|
||||||
|
.limit(batch_size)
|
||||||
|
)
|
||||||
|
rows = rr.scalars().all()
|
||||||
|
if not rows:
|
||||||
|
break
|
||||||
|
|
||||||
|
for ro in rows:
|
||||||
|
if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
|
||||||
|
sampled_dataset = True
|
||||||
|
break
|
||||||
|
if total_row_budget and total_rows >= total_row_budget:
|
||||||
|
sampled_dataset = True
|
||||||
|
global_budget_reached = True
|
||||||
|
break
|
||||||
|
|
||||||
|
data = ro.data or {}
|
||||||
|
total_rows += 1
|
||||||
|
rows_scanned_this_dataset += 1
|
||||||
|
|
||||||
|
fqdn = ''
|
||||||
|
for c in cols['fqdn']:
|
||||||
|
fqdn = _clean(data.get(c))
|
||||||
|
if fqdn:
|
||||||
|
break
|
||||||
|
client_id = ''
|
||||||
|
for c in cols['host_id']:
|
||||||
|
client_id = _clean(data.get(c))
|
||||||
|
if client_id:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not fqdn and not client_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
host_key = fqdn or client_id
|
||||||
|
|
||||||
|
if host_key not in hosts:
|
||||||
|
short = fqdn.split('.')[0] if fqdn and '.' in fqdn else fqdn
|
||||||
|
hosts[host_key] = {
|
||||||
|
'id': host_key,
|
||||||
|
'hostname': short or client_id,
|
||||||
|
'fqdn': fqdn,
|
||||||
|
'client_id': client_id,
|
||||||
|
'ips': set(),
|
||||||
|
'os': '',
|
||||||
|
'users': set(),
|
||||||
|
'datasets': set(),
|
||||||
|
'row_count': 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
h = hosts[host_key]
|
||||||
|
h['datasets'].add(ds.name)
|
||||||
|
h['row_count'] += 1
|
||||||
|
if client_id and not h['client_id']:
|
||||||
|
h['client_id'] = client_id
|
||||||
|
|
||||||
|
for c in cols['username']:
|
||||||
|
u = _extract_username(_clean(data.get(c)))
|
||||||
|
if u:
|
||||||
|
h['users'].add(u)
|
||||||
|
|
||||||
|
for c in cols['local_ip']:
|
||||||
|
ip = _clean(data.get(c))
|
||||||
|
if _is_valid_ip(ip):
|
||||||
|
h['ips'].add(ip)
|
||||||
|
ip_to_host[ip] = host_key
|
||||||
|
|
||||||
|
for c in cols['os']:
|
||||||
|
ov = _clean(data.get(c))
|
||||||
|
if ov and not h['os']:
|
||||||
|
h['os'] = ov
|
||||||
|
|
||||||
|
for c in cols['remote_ip']:
|
||||||
|
rip = _clean(data.get(c))
|
||||||
|
if _is_valid_ip(rip):
|
||||||
|
rport = ''
|
||||||
|
for pc in cols['remote_port']:
|
||||||
|
rport = _clean(data.get(pc))
|
||||||
|
if rport:
|
||||||
|
break
|
||||||
|
conn_key = (host_key, rip, rport)
|
||||||
|
if max_connections and len(connections) >= max_connections and conn_key not in connections:
|
||||||
|
dropped_connections += 1
|
||||||
|
continue
|
||||||
|
connections[conn_key] += 1
|
||||||
|
|
||||||
|
if sampled_dataset:
|
||||||
|
sampled_dataset_count += 1
|
||||||
|
logger.info(
|
||||||
|
"Host inventory sampling for dataset %s (%d rows scanned)",
|
||||||
|
ds.id,
|
||||||
|
rows_scanned_this_dataset,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
last_row_index = rows[-1].row_index
|
||||||
|
if len(rows) < batch_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
if global_budget_reached:
|
||||||
|
logger.info(
|
||||||
|
"Host inventory global row budget reached for hunt %s at %d rows",
|
||||||
|
hunt_id,
|
||||||
|
total_rows,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Post-process hosts
|
||||||
|
for h in hosts.values():
|
||||||
|
if not h['os'] and h['fqdn']:
|
||||||
|
h['os'] = _infer_os(h['fqdn'])
|
||||||
|
h['ips'] = sorted(h['ips'])
|
||||||
|
h['users'] = sorted(h['users'])
|
||||||
|
h['datasets'] = sorted(h['datasets'])
|
||||||
|
|
||||||
|
# Build connections, resolving IPs to host keys
|
||||||
|
conn_list = []
|
||||||
|
seen = set()
|
||||||
|
for (src, dst_ip, dst_port), cnt in connections.items():
|
||||||
|
if dst_ip in _IGNORE_IPS:
|
||||||
|
continue
|
||||||
|
dst_host = ip_to_host.get(dst_ip, '')
|
||||||
|
if dst_host == src:
|
||||||
|
continue
|
||||||
|
key = tuple(sorted([src, dst_host or dst_ip]))
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
conn_list.append({
|
||||||
|
'source': src,
|
||||||
|
'target': dst_host or dst_ip,
|
||||||
|
'target_ip': dst_ip,
|
||||||
|
'port': dst_port,
|
||||||
|
'count': cnt,
|
||||||
|
})
|
||||||
|
|
||||||
|
host_list = sorted(hosts.values(), key=lambda x: x['row_count'], reverse=True)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"hosts": host_list,
|
||||||
|
"connections": conn_list,
|
||||||
|
"stats": {
|
||||||
|
"total_hosts": len(host_list),
|
||||||
|
"total_datasets_scanned": len(all_datasets),
|
||||||
|
"datasets_with_hosts": ds_with_hosts,
|
||||||
|
"total_rows_scanned": total_rows,
|
||||||
|
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||||
|
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||||
|
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||||
|
"row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
|
||||||
|
"connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
|
||||||
|
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
|
||||||
|
"sampled_datasets": sampled_dataset_count,
|
||||||
|
"global_budget_reached": global_budget_reached,
|
||||||
|
"dropped_connections": dropped_connections,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
out=t[:start]+new_func+t[end:]
|
||||||
|
p.write_text(out,encoding='utf-8')
|
||||||
|
print('replaced build_host_inventory with hard-budget fast mode')
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
"""add playbooks, playbook_steps, saved_searches tables
|
||||||
|
|
||||||
|
Revision ID: b2c3d4e5f6a7
|
||||||
|
Revises: a1b2c3d4e5f6
|
||||||
|
Create Date: 2026-02-21 10:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
revision: str = "b2c3d4e5f6a7"
|
||||||
|
down_revision: Union[str, Sequence[str], None] = "a1b2c3d4e5f6"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Add display_name to users table
|
||||||
|
with op.batch_alter_table("users") as batch_op:
|
||||||
|
batch_op.add_column(sa.Column("display_name", sa.String(128), nullable=True))
|
||||||
|
|
||||||
|
# Create playbooks table
|
||||||
|
op.create_table(
|
||||||
|
"playbooks",
|
||||||
|
sa.Column("id", sa.String(32), primary_key=True),
|
||||||
|
sa.Column("name", sa.String(256), nullable=False, index=True),
|
||||||
|
sa.Column("description", sa.Text(), nullable=True),
|
||||||
|
sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
|
||||||
|
sa.Column("is_template", sa.Boolean(), server_default="0"),
|
||||||
|
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
|
||||||
|
sa.Column("status", sa.String(20), server_default="active"),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create playbook_steps table
|
||||||
|
op.create_table(
|
||||||
|
"playbook_steps",
|
||||||
|
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||||
|
sa.Column("playbook_id", sa.String(32), sa.ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("order_index", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("title", sa.String(256), nullable=False),
|
||||||
|
sa.Column("description", sa.Text(), nullable=True),
|
||||||
|
sa.Column("step_type", sa.String(32), server_default="manual"),
|
||||||
|
sa.Column("target_route", sa.String(256), nullable=True),
|
||||||
|
sa.Column("is_completed", sa.Boolean(), server_default="0"),
|
||||||
|
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("notes", sa.Text(), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index("ix_playbook_steps_playbook", "playbook_steps", ["playbook_id"])
|
||||||
|
|
||||||
|
# Create saved_searches table
|
||||||
|
op.create_table(
|
||||||
|
"saved_searches",
|
||||||
|
sa.Column("id", sa.String(32), primary_key=True),
|
||||||
|
sa.Column("name", sa.String(256), nullable=False, index=True),
|
||||||
|
sa.Column("description", sa.Text(), nullable=True),
|
||||||
|
sa.Column("search_type", sa.String(32), nullable=False),
|
||||||
|
sa.Column("query_params", sa.JSON(), nullable=False),
|
||||||
|
sa.Column("threshold", sa.Float(), nullable=True),
|
||||||
|
sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
|
||||||
|
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("last_result_count", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
op.create_index("ix_saved_searches_type", "saved_searches", ["search_type"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("saved_searches")
|
||||||
|
op.drop_table("playbook_steps")
|
||||||
|
op.drop_table("playbooks")
|
||||||
|
with op.batch_alter_table("users") as batch_op:
|
||||||
|
batch_op.drop_column("display_name")
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
"""add processing_tasks table
|
||||||
|
|
||||||
|
Revision ID: c3d4e5f6a7b8
|
||||||
|
Revises: b2c3d4e5f6a7
|
||||||
|
Create Date: 2026-02-22 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
revision: str = "c3d4e5f6a7b8"
|
||||||
|
down_revision: Union[str, Sequence[str], None] = "b2c3d4e5f6a7"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"processing_tasks",
|
||||||
|
sa.Column("id", sa.String(32), primary_key=True),
|
||||||
|
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True),
|
||||||
|
sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True),
|
||||||
|
sa.Column("job_id", sa.String(64), nullable=True),
|
||||||
|
sa.Column("stage", sa.String(64), nullable=False),
|
||||||
|
sa.Column("status", sa.String(20), nullable=False, server_default="queued"),
|
||||||
|
sa.Column("progress", sa.Float(), nullable=False, server_default="0.0"),
|
||||||
|
sa.Column("message", sa.Text(), nullable=True),
|
||||||
|
sa.Column("error", sa.Text(), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||||
|
)
|
||||||
|
op.create_index("ix_processing_tasks_hunt_stage", "processing_tasks", ["hunt_id", "stage"])
|
||||||
|
op.create_index("ix_processing_tasks_dataset_stage", "processing_tasks", ["dataset_id", "stage"])
|
||||||
|
op.create_index("ix_processing_tasks_job_id", "processing_tasks", ["job_id"])
|
||||||
|
op.create_index("ix_processing_tasks_status", "processing_tasks", ["status"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_processing_tasks_status", table_name="processing_tasks")
|
||||||
|
op.drop_index("ix_processing_tasks_job_id", table_name="processing_tasks")
|
||||||
|
op.drop_index("ix_processing_tasks_dataset_stage", table_name="processing_tasks")
|
||||||
|
op.drop_index("ix_processing_tasks_hunt_stage", table_name="processing_tasks")
|
||||||
|
op.drop_table("processing_tasks")
|
||||||
@@ -1,16 +1,18 @@
|
|||||||
"""Analyst-assist agent module for ThreatHunt.
|
"""Analyst-assist agent module for ThreatHunt.
|
||||||
|
|
||||||
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
|
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
|
||||||
Agents are advisory only and do not execute actions or modify data.
|
Agents are advisory only and do not execute actions or modify data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .core import ThreatHuntAgent
|
from .core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
|
||||||
from .providers import LLMProvider, LocalProvider, NetworkedProvider, OnlineProvider
|
from .providers_v2 import OllamaProvider, OpenWebUIProvider, EmbeddingProvider
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ThreatHuntAgent",
|
"ThreatHuntAgent",
|
||||||
"LLMProvider",
|
"AgentContext",
|
||||||
"LocalProvider",
|
"AgentResponse",
|
||||||
"NetworkedProvider",
|
"Perspective",
|
||||||
"OnlineProvider",
|
"OllamaProvider",
|
||||||
|
"OpenWebUIProvider",
|
||||||
|
"EmbeddingProvider",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,208 +0,0 @@
|
|||||||
"""Core ThreatHunt analyst-assist agent.
|
|
||||||
|
|
||||||
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
|
|
||||||
Agents are advisory only - no execution, no alerts, no data modifications.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from .providers import LLMProvider, get_provider
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentContext(BaseModel):
|
|
||||||
"""Context for agent guidance requests."""
|
|
||||||
|
|
||||||
query: str = Field(
|
|
||||||
..., description="Analyst question or request for guidance"
|
|
||||||
)
|
|
||||||
dataset_name: Optional[str] = Field(None, description="Name of CSV dataset")
|
|
||||||
artifact_type: Optional[str] = Field(None, description="Artifact type (e.g., file, process, network)")
|
|
||||||
host_identifier: Optional[str] = Field(
|
|
||||||
None, description="Host name, IP, or identifier"
|
|
||||||
)
|
|
||||||
data_summary: Optional[str] = Field(
|
|
||||||
None, description="Brief description of uploaded data"
|
|
||||||
)
|
|
||||||
conversation_history: Optional[list[dict]] = Field(
|
|
||||||
default_factory=list, description="Previous messages in conversation"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentResponse(BaseModel):
|
|
||||||
"""Response from analyst-assist agent."""
|
|
||||||
|
|
||||||
guidance: str = Field(..., description="Advisory guidance for analyst")
|
|
||||||
confidence: float = Field(
|
|
||||||
..., ge=0.0, le=1.0, description="Confidence in guidance (0-1)"
|
|
||||||
)
|
|
||||||
suggested_pivots: list[str] = Field(
|
|
||||||
default_factory=list, description="Suggested analytical directions"
|
|
||||||
)
|
|
||||||
suggested_filters: list[str] = Field(
|
|
||||||
default_factory=list, description="Suggested data filters or queries"
|
|
||||||
)
|
|
||||||
caveats: Optional[str] = Field(
|
|
||||||
None, description="Assumptions, limitations, or caveats"
|
|
||||||
)
|
|
||||||
reasoning: Optional[str] = Field(
|
|
||||||
None, description="Explanation of how guidance was generated"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ThreatHuntAgent:
|
|
||||||
"""Analyst-assist agent for ThreatHunt.
|
|
||||||
|
|
||||||
Provides guidance on:
|
|
||||||
- Interpreting CSV artifact data
|
|
||||||
- Suggesting analytical pivots and filters
|
|
||||||
- Forming and testing hypotheses
|
|
||||||
|
|
||||||
Policy:
|
|
||||||
- Advisory guidance only (no execution)
|
|
||||||
- No database or schema changes
|
|
||||||
- No alert escalation
|
|
||||||
- Transparent reasoning
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, provider: Optional[LLMProvider] = None):
|
|
||||||
"""Initialize agent with LLM provider.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
provider: LLM provider instance. If None, uses get_provider() with auto mode.
|
|
||||||
"""
|
|
||||||
if provider is None:
|
|
||||||
try:
|
|
||||||
provider = get_provider("auto")
|
|
||||||
except RuntimeError as e:
|
|
||||||
logger.warning(f"Could not initialize default provider: {e}")
|
|
||||||
provider = None
|
|
||||||
|
|
||||||
self.provider = provider
|
|
||||||
self.system_prompt = self._build_system_prompt()
|
|
||||||
|
|
||||||
def _build_system_prompt(self) -> str:
|
|
||||||
"""Build the system prompt that governs agent behavior."""
|
|
||||||
return """You are an analyst-assist agent for ThreatHunt, a threat hunting platform.
|
|
||||||
|
|
||||||
Your role:
|
|
||||||
- Interpret and explain CSV artifact data from Velociraptor
|
|
||||||
- Suggest analytical pivots, filters, and hypotheses
|
|
||||||
- Highlight anomalies, patterns, or points of interest
|
|
||||||
- Guide analysts without replacing their judgment
|
|
||||||
|
|
||||||
Your constraints:
|
|
||||||
- You ONLY provide guidance and suggestions
|
|
||||||
- You do NOT execute actions or tools
|
|
||||||
- You do NOT modify data or escalate alerts
|
|
||||||
- You do NOT make autonomous decisions
|
|
||||||
- You ONLY analyze data presented to you
|
|
||||||
- You explain your reasoning transparently
|
|
||||||
- You acknowledge limitations and assumptions
|
|
||||||
- You suggest next investigative steps
|
|
||||||
|
|
||||||
When responding:
|
|
||||||
1. Start with a clear, direct answer to the query
|
|
||||||
2. Explain your reasoning based on the data context provided
|
|
||||||
3. Suggest 2-4 analytical pivots the analyst might explore
|
|
||||||
4. Suggest 2-4 data filters or queries that might be useful
|
|
||||||
5. Include relevant caveats or assumptions
|
|
||||||
6. Be honest about what you cannot determine from the data
|
|
||||||
|
|
||||||
Remember: The analyst is the decision-maker. You are an assistant."""
|
|
||||||
|
|
||||||
async def assist(self, context: AgentContext) -> AgentResponse:
|
|
||||||
"""Provide guidance on artifact data and analysis.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: Request context including query and data context.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Guidance response with suggestions and reasoning.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If no provider is available.
|
|
||||||
"""
|
|
||||||
if not self.provider:
|
|
||||||
raise RuntimeError(
|
|
||||||
"No LLM provider available. Configure at least one of: "
|
|
||||||
"THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
|
|
||||||
"or THREAT_HUNT_ONLINE_API_KEY"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build prompt with context
|
|
||||||
prompt = self._build_prompt(context)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get guidance from LLM provider
|
|
||||||
guidance = await self.provider.generate(prompt, max_tokens=1024)
|
|
||||||
|
|
||||||
# Parse response into structured format
|
|
||||||
response = self._parse_response(guidance, context)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Agent assisted with query: {context.query[:50]}... "
|
|
||||||
f"(dataset: {context.dataset_name})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error generating guidance: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _build_prompt(self, context: AgentContext) -> str:
|
|
||||||
"""Build the prompt for the LLM."""
|
|
||||||
prompt_parts = [
|
|
||||||
f"Analyst query: {context.query}",
|
|
||||||
]
|
|
||||||
|
|
||||||
if context.dataset_name:
|
|
||||||
prompt_parts.append(f"Dataset: {context.dataset_name}")
|
|
||||||
|
|
||||||
if context.artifact_type:
|
|
||||||
prompt_parts.append(f"Artifact type: {context.artifact_type}")
|
|
||||||
|
|
||||||
if context.host_identifier:
|
|
||||||
prompt_parts.append(f"Host: {context.host_identifier}")
|
|
||||||
|
|
||||||
if context.data_summary:
|
|
||||||
prompt_parts.append(f"Data summary: {context.data_summary}")
|
|
||||||
|
|
||||||
if context.conversation_history:
|
|
||||||
prompt_parts.append("\nConversation history:")
|
|
||||||
for msg in context.conversation_history[-5:]: # Last 5 messages for context
|
|
||||||
prompt_parts.append(f" {msg.get('role', 'unknown')}: {msg.get('content', '')}")
|
|
||||||
|
|
||||||
return "\n".join(prompt_parts)
|
|
||||||
|
|
||||||
def _parse_response(self, response_text: str, context: AgentContext) -> AgentResponse:
|
|
||||||
"""Parse LLM response into structured format.
|
|
||||||
|
|
||||||
Note: This is a simplified parser. In production, use structured output
|
|
||||||
from the LLM (JSON mode, function calling, etc.) for better reliability.
|
|
||||||
"""
|
|
||||||
# For now, return a structured response based on the raw guidance
|
|
||||||
# In production, parse JSON or use structured output from LLM
|
|
||||||
return AgentResponse(
|
|
||||||
guidance=response_text,
|
|
||||||
confidence=0.8, # Placeholder
|
|
||||||
suggested_pivots=[
|
|
||||||
"Analyze temporal patterns",
|
|
||||||
"Cross-reference with known indicators",
|
|
||||||
"Examine outliers in the dataset",
|
|
||||||
"Compare with baseline behavior",
|
|
||||||
],
|
|
||||||
suggested_filters=[
|
|
||||||
"Filter by high-risk indicators",
|
|
||||||
"Sort by timestamp for timeline analysis",
|
|
||||||
"Group by host or user",
|
|
||||||
"Filter by anomaly score",
|
|
||||||
],
|
|
||||||
caveats="Guidance is based on available data context. "
|
|
||||||
"Analysts should verify findings with additional sources.",
|
|
||||||
reasoning="Analysis generated based on artifact data patterns and analyst query.",
|
|
||||||
)
|
|
||||||
@@ -1,190 +0,0 @@
|
|||||||
"""Pluggable LLM provider interface for analyst-assist agents.
|
|
||||||
|
|
||||||
Supports three provider types:
|
|
||||||
- Local: On-device or on-prem models
|
|
||||||
- Networked: Shared internal inference services
|
|
||||||
- Online: External hosted APIs
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
class LLMProvider(ABC):
|
|
||||||
"""Abstract base class for LLM providers."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
|
|
||||||
"""Generate a response from the LLM.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt: The input prompt
|
|
||||||
max_tokens: Maximum tokens in response
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Generated text response
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def is_available(self) -> bool:
|
|
||||||
"""Check if provider backend is available."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LocalProvider(LLMProvider):
|
|
||||||
"""Local LLM provider (on-device or on-prem models)."""
|
|
||||||
|
|
||||||
def __init__(self, model_path: Optional[str] = None):
|
|
||||||
"""Initialize local provider.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path: Path to local model. If None, uses THREAT_HUNT_LOCAL_MODEL_PATH env var.
|
|
||||||
"""
|
|
||||||
self.model_path = model_path or os.getenv("THREAT_HUNT_LOCAL_MODEL_PATH")
|
|
||||||
self.model = None
|
|
||||||
|
|
||||||
def is_available(self) -> bool:
|
|
||||||
"""Check if local model is available."""
|
|
||||||
if not self.model_path:
|
|
||||||
return False
|
|
||||||
# In production, would verify model file exists and can be loaded
|
|
||||||
return os.path.exists(str(self.model_path))
|
|
||||||
|
|
||||||
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
|
|
||||||
"""Generate response using local model.
|
|
||||||
|
|
||||||
Note: This is a placeholder. In production, integrate with:
|
|
||||||
- llama-cpp-python for GGML models
|
|
||||||
- Ollama API
|
|
||||||
- vLLM
|
|
||||||
- Other local inference engines
|
|
||||||
"""
|
|
||||||
if not self.is_available():
|
|
||||||
raise RuntimeError("Local model not available")
|
|
||||||
|
|
||||||
# Placeholder implementation
|
|
||||||
return f"[Local model response to: {prompt[:50]}...]"
|
|
||||||
|
|
||||||
|
|
||||||
class NetworkedProvider(LLMProvider):
|
|
||||||
"""Networked LLM provider (shared internal inference services)."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
api_endpoint: Optional[str] = None,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
model_name: str = "default",
|
|
||||||
):
|
|
||||||
"""Initialize networked provider.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_endpoint: URL to inference service. Defaults to env var THREAT_HUNT_NETWORKED_ENDPOINT.
|
|
||||||
api_key: API key for service. Defaults to env var THREAT_HUNT_NETWORKED_KEY.
|
|
||||||
model_name: Model name/ID on the service.
|
|
||||||
"""
|
|
||||||
self.api_endpoint = api_endpoint or os.getenv("THREAT_HUNT_NETWORKED_ENDPOINT")
|
|
||||||
self.api_key = api_key or os.getenv("THREAT_HUNT_NETWORKED_KEY")
|
|
||||||
self.model_name = model_name
|
|
||||||
|
|
||||||
def is_available(self) -> bool:
|
|
||||||
"""Check if networked service is available."""
|
|
||||||
return bool(self.api_endpoint)
|
|
||||||
|
|
||||||
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
|
|
||||||
"""Generate response using networked service.
|
|
||||||
|
|
||||||
Note: This is a placeholder. In production, integrate with:
|
|
||||||
- Internal inference service API
|
|
||||||
- LLM inference container cluster
|
|
||||||
- Enterprise inference gateway
|
|
||||||
"""
|
|
||||||
if not self.is_available():
|
|
||||||
raise RuntimeError("Networked service not available")
|
|
||||||
|
|
||||||
# Placeholder implementation
|
|
||||||
return f"[Networked response from {self.model_name}: {prompt[:50]}...]"
|
|
||||||
|
|
||||||
|
|
||||||
class OnlineProvider(LLMProvider):
|
|
||||||
"""Online LLM provider (external hosted APIs)."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
api_provider: str = "openai",
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
model_name: Optional[str] = None,
|
|
||||||
):
|
|
||||||
"""Initialize online provider.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_provider: Provider name (openai, anthropic, google, etc.)
|
|
||||||
api_key: API key. Defaults to env var THREAT_HUNT_ONLINE_API_KEY.
|
|
||||||
model_name: Model name. Defaults to env var THREAT_HUNT_ONLINE_MODEL.
|
|
||||||
"""
|
|
||||||
self.api_provider = api_provider
|
|
||||||
self.api_key = api_key or os.getenv("THREAT_HUNT_ONLINE_API_KEY")
|
|
||||||
self.model_name = model_name or os.getenv(
|
|
||||||
"THREAT_HUNT_ONLINE_MODEL", f"{api_provider}-default"
|
|
||||||
)
|
|
||||||
|
|
||||||
def is_available(self) -> bool:
|
|
||||||
"""Check if online API is available."""
|
|
||||||
return bool(self.api_key)
|
|
||||||
|
|
||||||
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
|
|
||||||
"""Generate response using online API.
|
|
||||||
|
|
||||||
Note: This is a placeholder. In production, integrate with:
|
|
||||||
- OpenAI API (GPT-3.5, GPT-4, etc.)
|
|
||||||
- Anthropic Claude API
|
|
||||||
- Google Gemini API
|
|
||||||
- Other hosted LLM services
|
|
||||||
"""
|
|
||||||
if not self.is_available():
|
|
||||||
raise RuntimeError("Online API not available or API key not set")
|
|
||||||
|
|
||||||
# Placeholder implementation
|
|
||||||
return f"[Online {self.api_provider} response: {prompt[:50]}...]"
|
|
||||||
|
|
||||||
|
|
||||||
def get_provider(provider_type: str = "auto") -> LLMProvider:
|
|
||||||
"""Get an LLM provider based on configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
provider_type: Type of provider to use: 'local', 'networked', 'online', or 'auto'.
|
|
||||||
'auto' attempts to use the first available provider in order:
|
|
||||||
local -> networked -> online.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured LLM provider instance.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If no provider is available.
|
|
||||||
"""
|
|
||||||
# Explicit provider selection
|
|
||||||
if provider_type == "local":
|
|
||||||
provider = LocalProvider()
|
|
||||||
elif provider_type == "networked":
|
|
||||||
provider = NetworkedProvider()
|
|
||||||
elif provider_type == "online":
|
|
||||||
provider = OnlineProvider()
|
|
||||||
elif provider_type == "auto":
|
|
||||||
# Try providers in order of preference
|
|
||||||
for Provider in [LocalProvider, NetworkedProvider, OnlineProvider]:
|
|
||||||
provider = Provider()
|
|
||||||
if provider.is_available():
|
|
||||||
return provider
|
|
||||||
raise RuntimeError(
|
|
||||||
"No LLM provider available. Configure at least one of: "
|
|
||||||
"THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
|
|
||||||
"or THREAT_HUNT_ONLINE_API_KEY"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
|
||||||
|
|
||||||
if not provider.is_available():
|
|
||||||
raise RuntimeError(f"{provider_type} provider not available")
|
|
||||||
|
|
||||||
return provider
|
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
"""API routes for analyst-assist agent."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from app.agents.core import ThreatHuntAgent, AgentContext, AgentResponse
|
|
||||||
from app.agents.config import AgentConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
|
||||||
|
|
||||||
# Global agent instance (lazy-loaded)
|
|
||||||
_agent: ThreatHuntAgent | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent() -> ThreatHuntAgent:
|
|
||||||
"""Get or create the agent instance."""
|
|
||||||
global _agent
|
|
||||||
if _agent is None:
|
|
||||||
if not AgentConfig.is_agent_enabled():
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503,
|
|
||||||
detail="Analyst-assist agent is not configured. "
|
|
||||||
"Please configure an LLM provider.",
|
|
||||||
)
|
|
||||||
_agent = ThreatHuntAgent()
|
|
||||||
return _agent
|
|
||||||
|
|
||||||
|
|
||||||
class AssistRequest(BaseModel):
|
|
||||||
"""Request for agent assistance."""
|
|
||||||
|
|
||||||
query: str = Field(
|
|
||||||
..., description="Analyst question or request for guidance"
|
|
||||||
)
|
|
||||||
dataset_name: str | None = Field(
|
|
||||||
None, description="Name of CSV dataset being analyzed"
|
|
||||||
)
|
|
||||||
artifact_type: str | None = Field(
|
|
||||||
None, description="Type of artifact (e.g., FileList, ProcessList, NetworkConnections)"
|
|
||||||
)
|
|
||||||
host_identifier: str | None = Field(
|
|
||||||
None, description="Host name, IP address, or identifier"
|
|
||||||
)
|
|
||||||
data_summary: str | None = Field(
|
|
||||||
None, description="Brief summary or context about the uploaded data"
|
|
||||||
)
|
|
||||||
conversation_history: list[dict] | None = Field(
|
|
||||||
None, description="Previous messages for context"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AssistResponse(BaseModel):
|
|
||||||
"""Response with agent guidance."""
|
|
||||||
|
|
||||||
guidance: str
|
|
||||||
confidence: float
|
|
||||||
suggested_pivots: list[str]
|
|
||||||
suggested_filters: list[str]
|
|
||||||
caveats: str | None = None
|
|
||||||
reasoning: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/assist",
|
|
||||||
response_model=AssistResponse,
|
|
||||||
summary="Get analyst-assist guidance",
|
|
||||||
description="Request guidance on CSV artifact data, analytical pivots, and hypotheses. "
|
|
||||||
"Agent provides advisory guidance only - no execution.",
|
|
||||||
)
|
|
||||||
async def agent_assist(request: AssistRequest) -> AssistResponse:
|
|
||||||
"""Provide analyst-assist guidance on artifact data.
|
|
||||||
|
|
||||||
The agent will:
|
|
||||||
- Explain and interpret the provided data context
|
|
||||||
- Suggest analytical pivots the analyst might explore
|
|
||||||
- Suggest data filters or queries that might be useful
|
|
||||||
- Highlight assumptions, limitations, and caveats
|
|
||||||
|
|
||||||
The agent will NOT:
|
|
||||||
- Execute any tools or actions
|
|
||||||
- Escalate findings to alerts
|
|
||||||
- Modify any data or schema
|
|
||||||
- Make autonomous decisions
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: Assistance request with query and context
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Guidance response with suggestions and reasoning
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If agent is not configured (503) or request fails
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent = get_agent()
|
|
||||||
|
|
||||||
# Build context
|
|
||||||
context = AgentContext(
|
|
||||||
query=request.query,
|
|
||||||
dataset_name=request.dataset_name,
|
|
||||||
artifact_type=request.artifact_type,
|
|
||||||
host_identifier=request.host_identifier,
|
|
||||||
data_summary=request.data_summary,
|
|
||||||
conversation_history=request.conversation_history or [],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get guidance
|
|
||||||
response = await agent.assist(context)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Agent assisted analyst with query: {request.query[:50]}... "
|
|
||||||
f"(host: {request.host_identifier}, artifact: {request.artifact_type})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return AssistResponse(
|
|
||||||
guidance=response.guidance,
|
|
||||||
confidence=response.confidence,
|
|
||||||
suggested_pivots=response.suggested_pivots,
|
|
||||||
suggested_filters=response.suggested_filters,
|
|
||||||
caveats=response.caveats,
|
|
||||||
reasoning=response.reasoning,
|
|
||||||
)
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
logger.error(f"Agent error: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=503,
|
|
||||||
detail=f"Agent unavailable: {str(e)}",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Unexpected error in agent_assist: {e}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="Error generating guidance. Please try again.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/health",
|
|
||||||
summary="Check agent health",
|
|
||||||
description="Check if agent is configured and ready to assist.",
|
|
||||||
)
|
|
||||||
async def agent_health() -> dict:
|
|
||||||
"""Check agent availability and configuration.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Health status with configuration details
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
agent = get_agent()
|
|
||||||
provider_type = agent.provider.__class__.__name__ if agent.provider else "None"
|
|
||||||
return {
|
|
||||||
"status": "healthy",
|
|
||||||
"provider": provider_type,
|
|
||||||
"max_tokens": AgentConfig.MAX_RESPONSE_TOKENS,
|
|
||||||
"reasoning_enabled": AgentConfig.ENABLE_REASONING,
|
|
||||||
}
|
|
||||||
except HTTPException:
|
|
||||||
return {
|
|
||||||
"status": "unavailable",
|
|
||||||
"reason": "No LLM provider configured",
|
|
||||||
"configured_providers": {
|
|
||||||
"local": bool(AgentConfig.LOCAL_MODEL_PATH),
|
|
||||||
"networked": bool(AgentConfig.NETWORKED_ENDPOINT),
|
|
||||||
"online": bool(AgentConfig.ONLINE_API_KEY),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
"""API routes for analyst-assist agent — v2.
|
"""API routes for analyst-assist agent v2.
|
||||||
|
|
||||||
Supports quick, deep, and debate modes with streaming.
|
Supports quick, deep, and debate modes with streaming.
|
||||||
Conversations are persisted to the database.
|
Conversations are persisted to the database.
|
||||||
@@ -6,19 +6,25 @@ Conversations are persisted to the database.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from collections import Counter
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.db.models import Conversation, Message
|
from app.db.models import Conversation, Message, Dataset, KeywordTheme
|
||||||
from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
|
from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
|
||||||
from app.agents.providers_v2 import check_all_nodes
|
from app.agents.providers_v2 import check_all_nodes
|
||||||
from app.agents.registry import registry
|
from app.agents.registry import registry
|
||||||
from app.services.sans_rag import sans_rag
|
from app.services.sans_rag import sans_rag
|
||||||
|
from app.services.scanner import KeywordScanner
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -35,7 +41,7 @@ def get_agent() -> ThreatHuntAgent:
|
|||||||
return _agent
|
return _agent
|
||||||
|
|
||||||
|
|
||||||
# ── Request / Response models ─────────────────────────────────────────
|
# Request / Response models
|
||||||
|
|
||||||
|
|
||||||
class AssistRequest(BaseModel):
|
class AssistRequest(BaseModel):
|
||||||
@@ -52,6 +58,8 @@ class AssistRequest(BaseModel):
|
|||||||
model_override: str | None = None
|
model_override: str | None = None
|
||||||
conversation_id: str | None = Field(None, description="Persist messages to this conversation")
|
conversation_id: str | None = Field(None, description="Persist messages to this conversation")
|
||||||
hunt_id: str | None = None
|
hunt_id: str | None = None
|
||||||
|
execution_preference: str = Field(default="auto", description="auto | force | off")
|
||||||
|
learning_mode: bool = False
|
||||||
|
|
||||||
|
|
||||||
class AssistResponseModel(BaseModel):
|
class AssistResponseModel(BaseModel):
|
||||||
@@ -66,10 +74,170 @@ class AssistResponseModel(BaseModel):
|
|||||||
node_used: str = ""
|
node_used: str = ""
|
||||||
latency_ms: int = 0
|
latency_ms: int = 0
|
||||||
perspectives: list[dict] | None = None
|
perspectives: list[dict] | None = None
|
||||||
|
execution: dict | None = None
|
||||||
conversation_id: str | None = None
|
conversation_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
POLICY_THEME_NAMES = {"Adult Content", "Gambling", "Downloads / Piracy"}
|
||||||
|
POLICY_QUERY_TERMS = {
|
||||||
|
"policy", "violating", "violation", "browser history", "web history",
|
||||||
|
"domain", "domains", "adult", "gambling", "piracy", "aup",
|
||||||
|
}
|
||||||
|
WEB_DATASET_HINTS = {
|
||||||
|
"web", "history", "browser", "url", "visited_url", "domain", "title",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_policy_domain_query(query: str) -> bool:
|
||||||
|
q = (query or "").lower()
|
||||||
|
if not q:
|
||||||
|
return False
|
||||||
|
score = sum(1 for t in POLICY_QUERY_TERMS if t in q)
|
||||||
|
return score >= 2 and ("domain" in q or "history" in q or "policy" in q)
|
||||||
|
|
||||||
|
def _should_execute_policy_scan(request: AssistRequest) -> bool:
|
||||||
|
pref = (request.execution_preference or "auto").strip().lower()
|
||||||
|
if pref == "off":
|
||||||
|
return False
|
||||||
|
if pref == "force":
|
||||||
|
return True
|
||||||
|
return _is_policy_domain_query(request.query)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_domain(value: str | None) -> str | None:
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
text = value.strip()
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = urlparse(text)
|
||||||
|
if parsed.netloc:
|
||||||
|
return parsed.netloc.lower()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
m = re.search(r"([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}", text)
|
||||||
|
return m.group(0).lower() if m else None
|
||||||
|
|
||||||
|
|
||||||
|
def _dataset_score(ds: Dataset) -> int:
|
||||||
|
score = 0
|
||||||
|
name = (ds.name or "").lower()
|
||||||
|
cols_l = {c.lower() for c in (ds.column_schema or {}).keys()}
|
||||||
|
norm_vals_l = {str(v).lower() for v in (ds.normalized_columns or {}).values()}
|
||||||
|
|
||||||
|
for h in WEB_DATASET_HINTS:
|
||||||
|
if h in name:
|
||||||
|
score += 2
|
||||||
|
if h in cols_l:
|
||||||
|
score += 3
|
||||||
|
if h in norm_vals_l:
|
||||||
|
score += 3
|
||||||
|
|
||||||
|
if "visited_url" in cols_l or "url" in cols_l:
|
||||||
|
score += 8
|
||||||
|
if "user" in cols_l or "username" in cols_l:
|
||||||
|
score += 2
|
||||||
|
if "clientid" in cols_l or "fqdn" in cols_l:
|
||||||
|
score += 2
|
||||||
|
if (ds.row_count or 0) > 0:
|
||||||
|
score += 1
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_policy_domain_execution(request: AssistRequest, db: AsyncSession) -> dict:
|
||||||
|
scanner = KeywordScanner(db)
|
||||||
|
|
||||||
|
theme_result = await db.execute(
|
||||||
|
select(KeywordTheme).where(
|
||||||
|
KeywordTheme.enabled == True, # noqa: E712
|
||||||
|
KeywordTheme.name.in_(list(POLICY_THEME_NAMES)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
themes = list(theme_result.scalars().all())
|
||||||
|
theme_ids = [t.id for t in themes]
|
||||||
|
theme_names = [t.name for t in themes] or sorted(POLICY_THEME_NAMES)
|
||||||
|
|
||||||
|
ds_query = select(Dataset).where(Dataset.processing_status.in_(["completed", "ready", "processing"]))
|
||||||
|
if request.hunt_id:
|
||||||
|
ds_query = ds_query.where(Dataset.hunt_id == request.hunt_id)
|
||||||
|
ds_result = await db.execute(ds_query)
|
||||||
|
candidates = list(ds_result.scalars().all())
|
||||||
|
|
||||||
|
if request.dataset_name:
|
||||||
|
needle = request.dataset_name.lower().strip()
|
||||||
|
candidates = [d for d in candidates if needle in (d.name or "").lower()]
|
||||||
|
|
||||||
|
scored = sorted(
|
||||||
|
((d, _dataset_score(d)) for d in candidates),
|
||||||
|
key=lambda x: x[1],
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
selected = [d for d, s in scored if s > 0][:8]
|
||||||
|
dataset_ids = [d.id for d in selected]
|
||||||
|
|
||||||
|
if not dataset_ids:
|
||||||
|
return {
|
||||||
|
"mode": "policy_scan",
|
||||||
|
"themes": theme_names,
|
||||||
|
"datasets_scanned": 0,
|
||||||
|
"dataset_names": [],
|
||||||
|
"total_hits": 0,
|
||||||
|
"policy_hits": 0,
|
||||||
|
"top_user_hosts": [],
|
||||||
|
"top_domains": [],
|
||||||
|
"sample_hits": [],
|
||||||
|
"note": "No suitable browser/web-history datasets found in current scope.",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await scanner.scan(
|
||||||
|
dataset_ids=dataset_ids,
|
||||||
|
theme_ids=theme_ids or None,
|
||||||
|
scan_hunts=False,
|
||||||
|
scan_annotations=False,
|
||||||
|
scan_messages=False,
|
||||||
|
)
|
||||||
|
hits = result.get("hits", [])
|
||||||
|
|
||||||
|
user_host_counter = Counter()
|
||||||
|
domain_counter = Counter()
|
||||||
|
|
||||||
|
for h in hits:
|
||||||
|
user = h.get("username") or "(unknown-user)"
|
||||||
|
host = h.get("hostname") or "(unknown-host)"
|
||||||
|
user_host_counter[f"{user}|{host}"] += 1
|
||||||
|
|
||||||
|
dom = _extract_domain(h.get("matched_value"))
|
||||||
|
if dom:
|
||||||
|
domain_counter[dom] += 1
|
||||||
|
|
||||||
|
top_user_hosts = [
|
||||||
|
{"user_host": k, "count": v}
|
||||||
|
for k, v in user_host_counter.most_common(10)
|
||||||
|
]
|
||||||
|
top_domains = [
|
||||||
|
{"domain": k, "count": v}
|
||||||
|
for k, v in domain_counter.most_common(10)
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"mode": "policy_scan",
|
||||||
|
"themes": theme_names,
|
||||||
|
"datasets_scanned": len(dataset_ids),
|
||||||
|
"dataset_names": [d.name for d in selected],
|
||||||
|
"total_hits": int(result.get("total_hits", 0)),
|
||||||
|
"policy_hits": int(result.get("total_hits", 0)),
|
||||||
|
"rows_scanned": int(result.get("rows_scanned", 0)),
|
||||||
|
"top_user_hosts": top_user_hosts,
|
||||||
|
"top_domains": top_domains,
|
||||||
|
"sample_hits": hits[:20],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Routes
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -84,6 +252,76 @@ async def agent_assist(
|
|||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
) -> AssistResponseModel:
|
) -> AssistResponseModel:
|
||||||
try:
|
try:
|
||||||
|
# Deterministic execution mode for policy-domain investigations.
|
||||||
|
if _should_execute_policy_scan(request):
|
||||||
|
t0 = time.monotonic()
|
||||||
|
exec_payload = await _run_policy_domain_execution(request, db)
|
||||||
|
latency_ms = int((time.monotonic() - t0) * 1000)
|
||||||
|
|
||||||
|
policy_hits = exec_payload.get("policy_hits", 0)
|
||||||
|
datasets_scanned = exec_payload.get("datasets_scanned", 0)
|
||||||
|
|
||||||
|
if policy_hits > 0:
|
||||||
|
guidance = (
|
||||||
|
f"Policy-violation scan complete: {policy_hits} hits across "
|
||||||
|
f"{datasets_scanned} dataset(s). Top user/host pairs and domains are included "
|
||||||
|
f"in execution results for triage."
|
||||||
|
)
|
||||||
|
confidence = 0.95
|
||||||
|
caveats = "Keyword-based matching can include false positives; validate with full URL context."
|
||||||
|
else:
|
||||||
|
guidance = (
|
||||||
|
f"No policy-violation hits found in current scope "
|
||||||
|
f"({datasets_scanned} dataset(s) scanned)."
|
||||||
|
)
|
||||||
|
confidence = 0.9
|
||||||
|
caveats = exec_payload.get("note") or "Try expanding scope to additional hunts/datasets."
|
||||||
|
|
||||||
|
response = AssistResponseModel(
|
||||||
|
guidance=guidance,
|
||||||
|
confidence=confidence,
|
||||||
|
suggested_pivots=["username", "hostname", "domain", "dataset_name"],
|
||||||
|
suggested_filters=[
|
||||||
|
"theme_name in ['Adult Content','Gambling','Downloads / Piracy']",
|
||||||
|
"username != null",
|
||||||
|
"hostname != null",
|
||||||
|
],
|
||||||
|
caveats=caveats,
|
||||||
|
reasoning=(
|
||||||
|
"Intent matched policy-domain investigation; executed local keyword scan pipeline."
|
||||||
|
if _is_policy_domain_query(request.query)
|
||||||
|
else "Execution mode was forced by user preference; ran policy-domain scan pipeline."
|
||||||
|
),
|
||||||
|
sans_references=["SANS FOR508", "SANS SEC504"],
|
||||||
|
model_used="execution:keyword_scanner",
|
||||||
|
node_used="local",
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
execution=exec_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
conv_id = request.conversation_id
|
||||||
|
if conv_id or request.hunt_id:
|
||||||
|
conv_id = await _persist_conversation(
|
||||||
|
db,
|
||||||
|
conv_id,
|
||||||
|
request,
|
||||||
|
AgentResponse(
|
||||||
|
guidance=response.guidance,
|
||||||
|
confidence=response.confidence,
|
||||||
|
suggested_pivots=response.suggested_pivots,
|
||||||
|
suggested_filters=response.suggested_filters,
|
||||||
|
caveats=response.caveats,
|
||||||
|
reasoning=response.reasoning,
|
||||||
|
sans_references=response.sans_references,
|
||||||
|
model_used=response.model_used,
|
||||||
|
node_used=response.node_used,
|
||||||
|
latency_ms=response.latency_ms,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
response.conversation_id = conv_id
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
agent = get_agent()
|
agent = get_agent()
|
||||||
context = AgentContext(
|
context = AgentContext(
|
||||||
query=request.query,
|
query=request.query,
|
||||||
@@ -97,6 +335,7 @@ async def agent_assist(
|
|||||||
enrichment_summary=request.enrichment_summary,
|
enrichment_summary=request.enrichment_summary,
|
||||||
mode=request.mode,
|
mode=request.mode,
|
||||||
model_override=request.model_override,
|
model_override=request.model_override,
|
||||||
|
learning_mode=request.learning_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await agent.assist(context)
|
response = await agent.assist(context)
|
||||||
@@ -129,6 +368,7 @@ async def agent_assist(
|
|||||||
}
|
}
|
||||||
for p in response.perspectives
|
for p in response.perspectives
|
||||||
] if response.perspectives else None,
|
] if response.perspectives else None,
|
||||||
|
execution=None,
|
||||||
conversation_id=conv_id,
|
conversation_id=conv_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -208,7 +448,7 @@ async def list_models():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# ── Conversation persistence ──────────────────────────────────────────
|
# Conversation persistence
|
||||||
|
|
||||||
|
|
||||||
async def _persist_conversation(
|
async def _persist_conversation(
|
||||||
@@ -263,3 +503,4 @@ async def _persist_conversation(
|
|||||||
await db.flush()
|
await db.flush()
|
||||||
|
|
||||||
return conv.id
|
return conv.id
|
||||||
|
|
||||||
|
|||||||
@@ -381,6 +381,10 @@ async def submit_job(
|
|||||||
detail=f"Invalid job_type: {job_type}. Valid: {[t.value for t in JobType]}",
|
detail=f"Invalid job_type: {job_type}. Valid: {[t.value for t in JobType]}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not job_queue.can_accept():
|
||||||
|
raise HTTPException(status_code=429, detail="Job queue is busy. Retry shortly.")
|
||||||
|
if not job_queue.can_accept():
|
||||||
|
raise HTTPException(status_code=429, detail="Job queue is busy. Retry shortly.")
|
||||||
job = job_queue.submit(jt, **params)
|
job = job_queue.submit(jt, **params)
|
||||||
return {"job_id": job.id, "status": job.status.value, "job_type": job_type}
|
return {"job_id": job.id, "status": job.status.value, "job_type": job_type}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""API routes for authentication — register, login, refresh, profile."""
|
"""API routes for authentication — register, login, refresh, profile."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
|
|||||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
# ── Request / Response models ─────────────────────────────────────────
|
# ── Request / Response models ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class RegisterRequest(BaseModel):
|
class RegisterRequest(BaseModel):
|
||||||
@@ -57,7 +57,7 @@ class AuthResponse(BaseModel):
|
|||||||
tokens: TokenPair
|
tokens: TokenPair
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -86,7 +86,7 @@ async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
|||||||
user = User(
|
user = User(
|
||||||
username=body.username,
|
username=body.username,
|
||||||
email=body.email,
|
email=body.email,
|
||||||
password_hash=hash_password(body.password),
|
hashed_password=hash_password(body.password),
|
||||||
display_name=body.display_name or body.username,
|
display_name=body.display_name or body.username,
|
||||||
role="analyst", # Default role
|
role="analyst", # Default role
|
||||||
)
|
)
|
||||||
@@ -120,13 +120,13 @@ async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
|
|||||||
result = await db.execute(select(User).where(User.username == body.username))
|
result = await db.execute(select(User).where(User.username == body.username))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not user or not user.password_hash:
|
if not user or not user.hashed_password:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid username or password",
|
detail="Invalid username or password",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not verify_password(body.password, user.password_hash):
|
if not verify_password(body.password, user.hashed_password):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid username or password",
|
detail="Invalid username or password",
|
||||||
@@ -165,7 +165,7 @@ async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db)
|
|||||||
if token_data.type != "refresh":
|
if token_data.type != "refresh":
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Invalid token type — use refresh token",
|
detail="Invalid token type — use refresh token",
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await db.execute(select(User).where(User.id == token_data.sub))
|
result = await db.execute(select(User).where(User.id == token_data.sub))
|
||||||
@@ -195,3 +195,4 @@ async def get_profile(user: User = Depends(get_current_user)):
|
|||||||
is_active=user.is_active,
|
is_active=user.is_active,
|
||||||
created_at=user.created_at.isoformat() if hasattr(user.created_at, 'isoformat') else str(user.created_at),
|
created_at=user.created_at.isoformat() if hasattr(user.created_at, 'isoformat') else str(user.created_at),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
|
from app.db.models import ProcessingTask
|
||||||
from app.db.repositories.datasets import DatasetRepository
|
from app.db.repositories.datasets import DatasetRepository
|
||||||
from app.services.csv_parser import parse_csv_bytes, infer_column_types
|
from app.services.csv_parser import parse_csv_bytes, infer_column_types
|
||||||
from app.services.normalizer import (
|
from app.services.normalizer import (
|
||||||
@@ -18,15 +19,20 @@ from app.services.normalizer import (
|
|||||||
detect_ioc_columns,
|
detect_ioc_columns,
|
||||||
detect_time_range,
|
detect_time_range,
|
||||||
)
|
)
|
||||||
|
from app.services.artifact_classifier import classify_artifact, get_artifact_category
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from app.services.job_queue import job_queue, JobType
|
||||||
|
from app.services.host_inventory import inventory_cache
|
||||||
|
from app.services.scanner import keyword_scan_cache
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/datasets", tags=["datasets"])
|
router = APIRouter(prefix="/api/datasets", tags=["datasets"])
|
||||||
|
|
||||||
ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"}
|
ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"}
|
||||||
|
|
||||||
|
|
||||||
# ── Response models ───────────────────────────────────────────────────
|
# -- Response models --
|
||||||
|
|
||||||
|
|
||||||
class DatasetSummary(BaseModel):
|
class DatasetSummary(BaseModel):
|
||||||
@@ -43,6 +49,8 @@ class DatasetSummary(BaseModel):
|
|||||||
delimiter: str | None = None
|
delimiter: str | None = None
|
||||||
time_range_start: str | None = None
|
time_range_start: str | None = None
|
||||||
time_range_end: str | None = None
|
time_range_end: str | None = None
|
||||||
|
artifact_type: str | None = None
|
||||||
|
processing_status: str | None = None
|
||||||
hunt_id: str | None = None
|
hunt_id: str | None = None
|
||||||
created_at: str
|
created_at: str
|
||||||
|
|
||||||
@@ -67,10 +75,13 @@ class UploadResponse(BaseModel):
|
|||||||
column_types: dict
|
column_types: dict
|
||||||
normalized_columns: dict
|
normalized_columns: dict
|
||||||
ioc_columns: dict
|
ioc_columns: dict
|
||||||
|
artifact_type: str | None = None
|
||||||
|
processing_status: str
|
||||||
|
jobs_queued: list[str]
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
# -- Routes --
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -78,7 +89,7 @@ class UploadResponse(BaseModel):
|
|||||||
response_model=UploadResponse,
|
response_model=UploadResponse,
|
||||||
summary="Upload a CSV dataset",
|
summary="Upload a CSV dataset",
|
||||||
description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, "
|
description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, "
|
||||||
"IOCs auto-detected, and rows stored in the database.",
|
"IOCs auto-detected, artifact type classified, and all processing jobs queued automatically.",
|
||||||
)
|
)
|
||||||
async def upload_dataset(
|
async def upload_dataset(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
@@ -87,7 +98,7 @@ async def upload_dataset(
|
|||||||
hunt_id: str | None = Query(None, description="Hunt ID to associate with"),
|
hunt_id: str | None = Query(None, description="Hunt ID to associate with"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Upload and parse a CSV dataset."""
|
"""Upload and parse a CSV dataset, then trigger full processing pipeline."""
|
||||||
# Validate file
|
# Validate file
|
||||||
if not file.filename:
|
if not file.filename:
|
||||||
raise HTTPException(status_code=400, detail="No filename provided")
|
raise HTTPException(status_code=400, detail="No filename provided")
|
||||||
@@ -136,7 +147,12 @@ async def upload_dataset(
|
|||||||
# Detect time range
|
# Detect time range
|
||||||
time_start, time_end = detect_time_range(rows, column_mapping)
|
time_start, time_end = detect_time_range(rows, column_mapping)
|
||||||
|
|
||||||
# Store in DB
|
# Classify artifact type from column headers
|
||||||
|
artifact_type = classify_artifact(columns)
|
||||||
|
artifact_category = get_artifact_category(artifact_type)
|
||||||
|
logger.info(f"Artifact classification: {artifact_type} (category: {artifact_category})")
|
||||||
|
|
||||||
|
# Store in DB with processing_status = "processing"
|
||||||
repo = DatasetRepository(db)
|
repo = DatasetRepository(db)
|
||||||
dataset = await repo.create_dataset(
|
dataset = await repo.create_dataset(
|
||||||
name=name or Path(file.filename).stem,
|
name=name or Path(file.filename).stem,
|
||||||
@@ -152,6 +168,8 @@ async def upload_dataset(
|
|||||||
time_range_start=time_start,
|
time_range_start=time_start,
|
||||||
time_range_end=time_end,
|
time_range_end=time_end,
|
||||||
hunt_id=hunt_id,
|
hunt_id=hunt_id,
|
||||||
|
artifact_type=artifact_type,
|
||||||
|
processing_status="processing",
|
||||||
)
|
)
|
||||||
|
|
||||||
await repo.bulk_insert_rows(
|
await repo.bulk_insert_rows(
|
||||||
@@ -162,9 +180,88 @@ async def upload_dataset(
|
|||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Uploaded dataset '{dataset.name}': {len(rows)} rows, "
|
f"Uploaded dataset '{dataset.name}': {len(rows)} rows, "
|
||||||
f"{len(columns)} columns, {len(ioc_columns)} IOC columns detected"
|
f"{len(columns)} columns, {len(ioc_columns)} IOC columns, "
|
||||||
|
f"artifact={artifact_type}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# -- Queue full processing pipeline --
|
||||||
|
jobs_queued = []
|
||||||
|
|
||||||
|
task_rows: list[ProcessingTask] = []
|
||||||
|
|
||||||
|
# 1. AI Triage (chains to HOST_PROFILE automatically on completion)
|
||||||
|
triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
|
||||||
|
jobs_queued.append("triage")
|
||||||
|
task_rows.append(ProcessingTask(
|
||||||
|
hunt_id=hunt_id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
job_id=triage_job.id,
|
||||||
|
stage="triage",
|
||||||
|
status="queued",
|
||||||
|
progress=0.0,
|
||||||
|
message="Queued",
|
||||||
|
))
|
||||||
|
|
||||||
|
# 2. Anomaly detection (embedding-based outlier detection)
|
||||||
|
anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
|
||||||
|
jobs_queued.append("anomaly")
|
||||||
|
task_rows.append(ProcessingTask(
|
||||||
|
hunt_id=hunt_id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
job_id=anomaly_job.id,
|
||||||
|
stage="anomaly",
|
||||||
|
status="queued",
|
||||||
|
progress=0.0,
|
||||||
|
message="Queued",
|
||||||
|
))
|
||||||
|
|
||||||
|
# 3. AUP keyword scan
|
||||||
|
kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
|
||||||
|
jobs_queued.append("keyword_scan")
|
||||||
|
task_rows.append(ProcessingTask(
|
||||||
|
hunt_id=hunt_id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
job_id=kw_job.id,
|
||||||
|
stage="keyword_scan",
|
||||||
|
status="queued",
|
||||||
|
progress=0.0,
|
||||||
|
message="Queued",
|
||||||
|
))
|
||||||
|
|
||||||
|
# 4. IOC extraction
|
||||||
|
ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
|
||||||
|
jobs_queued.append("ioc_extract")
|
||||||
|
task_rows.append(ProcessingTask(
|
||||||
|
hunt_id=hunt_id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
job_id=ioc_job.id,
|
||||||
|
stage="ioc_extract",
|
||||||
|
status="queued",
|
||||||
|
progress=0.0,
|
||||||
|
message="Queued",
|
||||||
|
))
|
||||||
|
|
||||||
|
# 5. Host inventory (network map) - requires hunt_id
|
||||||
|
if hunt_id:
|
||||||
|
inventory_cache.invalidate(hunt_id)
|
||||||
|
inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||||
|
jobs_queued.append("host_inventory")
|
||||||
|
task_rows.append(ProcessingTask(
|
||||||
|
hunt_id=hunt_id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
job_id=inv_job.id,
|
||||||
|
stage="host_inventory",
|
||||||
|
status="queued",
|
||||||
|
progress=0.0,
|
||||||
|
message="Queued",
|
||||||
|
))
|
||||||
|
|
||||||
|
if task_rows:
|
||||||
|
db.add_all(task_rows)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
logger.info(f"Queued {len(jobs_queued)} processing jobs for dataset {dataset.id}: {jobs_queued}")
|
||||||
|
|
||||||
return UploadResponse(
|
return UploadResponse(
|
||||||
id=dataset.id,
|
id=dataset.id,
|
||||||
name=dataset.name,
|
name=dataset.name,
|
||||||
@@ -173,7 +270,10 @@ async def upload_dataset(
|
|||||||
column_types=column_types,
|
column_types=column_types,
|
||||||
normalized_columns=column_mapping,
|
normalized_columns=column_mapping,
|
||||||
ioc_columns=ioc_columns,
|
ioc_columns=ioc_columns,
|
||||||
message=f"Successfully uploaded {len(rows)} rows with {len(ioc_columns)} IOC columns detected",
|
artifact_type=artifact_type,
|
||||||
|
processing_status="processing",
|
||||||
|
jobs_queued=jobs_queued,
|
||||||
|
message=f"Successfully uploaded {len(rows)} rows. {len(jobs_queued)} processing jobs queued.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -208,6 +308,8 @@ async def list_datasets(
|
|||||||
delimiter=ds.delimiter,
|
delimiter=ds.delimiter,
|
||||||
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
|
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
|
||||||
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
|
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
|
||||||
|
artifact_type=ds.artifact_type,
|
||||||
|
processing_status=ds.processing_status,
|
||||||
hunt_id=ds.hunt_id,
|
hunt_id=ds.hunt_id,
|
||||||
created_at=ds.created_at.isoformat(),
|
created_at=ds.created_at.isoformat(),
|
||||||
)
|
)
|
||||||
@@ -244,6 +346,8 @@ async def get_dataset(
|
|||||||
delimiter=ds.delimiter,
|
delimiter=ds.delimiter,
|
||||||
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
|
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
|
||||||
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
|
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
|
||||||
|
artifact_type=ds.artifact_type,
|
||||||
|
processing_status=ds.processing_status,
|
||||||
hunt_id=ds.hunt_id,
|
hunt_id=ds.hunt_id,
|
||||||
created_at=ds.created_at.isoformat(),
|
created_at=ds.created_at.isoformat(),
|
||||||
)
|
)
|
||||||
@@ -292,4 +396,5 @@ async def delete_dataset(
|
|||||||
deleted = await repo.delete_dataset(dataset_id)
|
deleted = await repo.delete_dataset(dataset_id)
|
||||||
if not deleted:
|
if not deleted:
|
||||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||||
|
keyword_scan_cache.invalidate_dataset(dataset_id)
|
||||||
return {"message": "Dataset deleted", "id": dataset_id}
|
return {"message": "Dataset deleted", "id": dataset_id}
|
||||||
|
|||||||
@@ -8,16 +8,15 @@ from sqlalchemy import select, func
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.db.models import Hunt, Conversation, Message
|
from app.db.models import Hunt, Dataset, ProcessingTask
|
||||||
|
from app.services.job_queue import job_queue
|
||||||
|
from app.services.host_inventory import inventory_cache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/hunts", tags=["hunts"])
|
router = APIRouter(prefix="/api/hunts", tags=["hunts"])
|
||||||
|
|
||||||
|
|
||||||
# ── Models ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class HuntCreate(BaseModel):
|
class HuntCreate(BaseModel):
|
||||||
name: str = Field(..., max_length=256)
|
name: str = Field(..., max_length=256)
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
@@ -26,7 +25,7 @@ class HuntCreate(BaseModel):
|
|||||||
class HuntUpdate(BaseModel):
|
class HuntUpdate(BaseModel):
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
status: str | None = None # active | closed | archived
|
status: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class HuntResponse(BaseModel):
|
class HuntResponse(BaseModel):
|
||||||
@@ -46,7 +45,18 @@ class HuntListResponse(BaseModel):
|
|||||||
total: int
|
total: int
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
class HuntProgressResponse(BaseModel):
|
||||||
|
hunt_id: str
|
||||||
|
status: str
|
||||||
|
progress_percent: float
|
||||||
|
dataset_total: int
|
||||||
|
dataset_completed: int
|
||||||
|
dataset_processing: int
|
||||||
|
dataset_errors: int
|
||||||
|
active_jobs: int
|
||||||
|
queued_jobs: int
|
||||||
|
network_status: str
|
||||||
|
stages: dict
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=HuntResponse, summary="Create a new hunt")
|
@router.post("", response_model=HuntResponse, summary="Create a new hunt")
|
||||||
@@ -122,6 +132,125 @@ async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{hunt_id}/progress", response_model=HuntProgressResponse, summary="Get hunt processing progress")
|
||||||
|
async def get_hunt_progress(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
|
hunt = await db.get(Hunt, hunt_id)
|
||||||
|
if not hunt:
|
||||||
|
raise HTTPException(status_code=404, detail="Hunt not found")
|
||||||
|
|
||||||
|
ds_rows = await db.execute(
|
||||||
|
select(Dataset.id, Dataset.processing_status)
|
||||||
|
.where(Dataset.hunt_id == hunt_id)
|
||||||
|
)
|
||||||
|
datasets = ds_rows.all()
|
||||||
|
dataset_ids = {row[0] for row in datasets}
|
||||||
|
|
||||||
|
dataset_total = len(datasets)
|
||||||
|
dataset_completed = sum(1 for _, st in datasets if st == "completed")
|
||||||
|
dataset_errors = sum(1 for _, st in datasets if st == "completed_with_errors")
|
||||||
|
dataset_processing = max(0, dataset_total - dataset_completed - dataset_errors)
|
||||||
|
|
||||||
|
jobs = job_queue.list_jobs(limit=5000)
|
||||||
|
relevant_jobs = [
|
||||||
|
j for j in jobs
|
||||||
|
if j.get("params", {}).get("hunt_id") == hunt_id
|
||||||
|
or j.get("params", {}).get("dataset_id") in dataset_ids
|
||||||
|
]
|
||||||
|
active_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "running")
|
||||||
|
queued_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "queued")
|
||||||
|
|
||||||
|
task_rows = await db.execute(
|
||||||
|
select(ProcessingTask.stage, ProcessingTask.status, ProcessingTask.progress)
|
||||||
|
.where(ProcessingTask.hunt_id == hunt_id)
|
||||||
|
)
|
||||||
|
tasks = task_rows.all()
|
||||||
|
|
||||||
|
task_total = len(tasks)
|
||||||
|
task_done = sum(1 for _, st, _ in tasks if st in ("completed", "failed", "cancelled"))
|
||||||
|
task_running = sum(1 for _, st, _ in tasks if st == "running")
|
||||||
|
task_queued = sum(1 for _, st, _ in tasks if st == "queued")
|
||||||
|
task_ratio = (task_done / task_total) if task_total > 0 else None
|
||||||
|
|
||||||
|
active_jobs = max(active_jobs_mem, task_running)
|
||||||
|
queued_jobs = max(queued_jobs_mem, task_queued)
|
||||||
|
|
||||||
|
stage_rollup: dict[str, dict] = {}
|
||||||
|
for stage, status, progress in tasks:
|
||||||
|
bucket = stage_rollup.setdefault(stage, {"total": 0, "done": 0, "running": 0, "queued": 0, "progress_sum": 0.0})
|
||||||
|
bucket["total"] += 1
|
||||||
|
if status in ("completed", "failed", "cancelled"):
|
||||||
|
bucket["done"] += 1
|
||||||
|
elif status == "running":
|
||||||
|
bucket["running"] += 1
|
||||||
|
elif status == "queued":
|
||||||
|
bucket["queued"] += 1
|
||||||
|
bucket["progress_sum"] += float(progress or 0.0)
|
||||||
|
|
||||||
|
for stage_name, bucket in stage_rollup.items():
|
||||||
|
total = max(1, bucket["total"])
|
||||||
|
bucket["percent"] = round(bucket["progress_sum"] / total, 1)
|
||||||
|
|
||||||
|
if inventory_cache.get(hunt_id) is not None:
|
||||||
|
network_status = "ready"
|
||||||
|
network_ratio = 1.0
|
||||||
|
elif inventory_cache.is_building(hunt_id):
|
||||||
|
network_status = "building"
|
||||||
|
network_ratio = 0.5
|
||||||
|
else:
|
||||||
|
network_status = "none"
|
||||||
|
network_ratio = 0.0
|
||||||
|
|
||||||
|
dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
|
||||||
|
if task_ratio is None:
|
||||||
|
overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
|
||||||
|
else:
|
||||||
|
overall_ratio = min(1.0, (dataset_ratio * 0.50) + (task_ratio * 0.35) + (network_ratio * 0.15))
|
||||||
|
progress_percent = round(overall_ratio * 100.0, 1)
|
||||||
|
|
||||||
|
status = "ready"
|
||||||
|
if dataset_total == 0:
|
||||||
|
status = "idle"
|
||||||
|
elif progress_percent < 100:
|
||||||
|
status = "processing"
|
||||||
|
|
||||||
|
stages = {
|
||||||
|
"datasets": {
|
||||||
|
"total": dataset_total,
|
||||||
|
"completed": dataset_completed,
|
||||||
|
"processing": dataset_processing,
|
||||||
|
"errors": dataset_errors,
|
||||||
|
"percent": round(dataset_ratio * 100.0, 1),
|
||||||
|
},
|
||||||
|
"network": {
|
||||||
|
"status": network_status,
|
||||||
|
"percent": round(network_ratio * 100.0, 1),
|
||||||
|
},
|
||||||
|
"jobs": {
|
||||||
|
"active": active_jobs,
|
||||||
|
"queued": queued_jobs,
|
||||||
|
"total_seen": len(relevant_jobs),
|
||||||
|
"task_total": task_total,
|
||||||
|
"task_done": task_done,
|
||||||
|
"task_percent": round((task_ratio or 0.0) * 100.0, 1) if task_total else None,
|
||||||
|
},
|
||||||
|
"task_stages": stage_rollup,
|
||||||
|
}
|
||||||
|
|
||||||
|
return HuntProgressResponse(
|
||||||
|
hunt_id=hunt_id,
|
||||||
|
status=status,
|
||||||
|
progress_percent=progress_percent,
|
||||||
|
dataset_total=dataset_total,
|
||||||
|
dataset_completed=dataset_completed,
|
||||||
|
dataset_processing=dataset_processing,
|
||||||
|
dataset_errors=dataset_errors,
|
||||||
|
active_jobs=active_jobs,
|
||||||
|
queued_jobs=queued_jobs,
|
||||||
|
network_status=network_status,
|
||||||
|
stages=stages,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
|
@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
|
||||||
async def update_hunt(
|
async def update_hunt(
|
||||||
hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)
|
hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)
|
||||||
|
|||||||
@@ -1,25 +1,21 @@
|
|||||||
"""API routes for AUP keyword themes, keyword CRUD, and scanning."""
|
"""API routes for AUP keyword themes, keyword CRUD, and scanning."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select, func, delete
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.db.models import KeywordTheme, Keyword
|
from app.db.models import KeywordTheme, Keyword
|
||||||
from app.services.scanner import KeywordScanner
|
from app.services.scanner import KeywordScanner, keyword_scan_cache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/keywords", tags=["keywords"])
|
router = APIRouter(prefix="/api/keywords", tags=["keywords"])
|
||||||
|
|
||||||
|
|
||||||
# ── Pydantic schemas ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class ThemeCreate(BaseModel):
|
class ThemeCreate(BaseModel):
|
||||||
name: str = Field(..., min_length=1, max_length=128)
|
name: str = Field(..., min_length=1, max_length=128)
|
||||||
color: str = Field(default="#9e9e9e", max_length=16)
|
color: str = Field(default="#9e9e9e", max_length=16)
|
||||||
@@ -67,23 +63,27 @@ class KeywordBulkCreate(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ScanRequest(BaseModel):
|
class ScanRequest(BaseModel):
|
||||||
dataset_ids: list[str] | None = None # None → all datasets
|
dataset_ids: list[str] | None = None
|
||||||
theme_ids: list[str] | None = None # None → all enabled themes
|
theme_ids: list[str] | None = None
|
||||||
scan_hunts: bool = True
|
scan_hunts: bool = False
|
||||||
scan_annotations: bool = True
|
scan_annotations: bool = False
|
||||||
scan_messages: bool = True
|
scan_messages: bool = False
|
||||||
|
prefer_cache: bool = True
|
||||||
|
force_rescan: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ScanHit(BaseModel):
|
class ScanHit(BaseModel):
|
||||||
theme_name: str
|
theme_name: str
|
||||||
theme_color: str
|
theme_color: str
|
||||||
keyword: str
|
keyword: str
|
||||||
source_type: str # dataset_row | hunt | annotation | message
|
source_type: str
|
||||||
source_id: str | int
|
source_id: str | int
|
||||||
field: str
|
field: str
|
||||||
matched_value: str
|
matched_value: str
|
||||||
row_index: int | None = None
|
row_index: int | None = None
|
||||||
dataset_name: str | None = None
|
dataset_name: str | None = None
|
||||||
|
hostname: str | None = None
|
||||||
|
username: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ScanResponse(BaseModel):
|
class ScanResponse(BaseModel):
|
||||||
@@ -92,9 +92,9 @@ class ScanResponse(BaseModel):
|
|||||||
themes_scanned: int
|
themes_scanned: int
|
||||||
keywords_scanned: int
|
keywords_scanned: int
|
||||||
rows_scanned: int
|
rows_scanned: int
|
||||||
|
cache_used: bool = False
|
||||||
|
cache_status: str = "miss"
|
||||||
# ── Helpers ───────────────────────────────────────────────────────────
|
cached_at: str | None = None
|
||||||
|
|
||||||
|
|
||||||
def _theme_to_out(t: KeywordTheme) -> ThemeOut:
|
def _theme_to_out(t: KeywordTheme) -> ThemeOut:
|
||||||
@@ -119,49 +119,58 @@ def _theme_to_out(t: KeywordTheme) -> ThemeOut:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Theme CRUD ────────────────────────────────────────────────────────
|
def _merge_cached_results(entries: list[dict], allowed_theme_names: set[str] | None = None) -> dict:
|
||||||
|
hits: list[dict] = []
|
||||||
|
total_rows = 0
|
||||||
|
cached_at: str | None = None
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
result = entry["result"]
|
||||||
|
total_rows += int(result.get("rows_scanned", 0) or 0)
|
||||||
|
if entry.get("built_at"):
|
||||||
|
if not cached_at or entry["built_at"] > cached_at:
|
||||||
|
cached_at = entry["built_at"]
|
||||||
|
for h in result.get("hits", []):
|
||||||
|
if allowed_theme_names is not None and h.get("theme_name") not in allowed_theme_names:
|
||||||
|
continue
|
||||||
|
hits.append(h)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_hits": len(hits),
|
||||||
|
"hits": hits,
|
||||||
|
"rows_scanned": total_rows,
|
||||||
|
"cached_at": cached_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/themes", response_model=ThemeListResponse)
|
@router.get("/themes", response_model=ThemeListResponse)
|
||||||
async def list_themes(db: AsyncSession = Depends(get_db)):
|
async def list_themes(db: AsyncSession = Depends(get_db)):
|
||||||
"""List all keyword themes with their keywords."""
|
result = await db.execute(select(KeywordTheme).order_by(KeywordTheme.name))
|
||||||
result = await db.execute(
|
|
||||||
select(KeywordTheme).order_by(KeywordTheme.name)
|
|
||||||
)
|
|
||||||
themes = result.scalars().all()
|
themes = result.scalars().all()
|
||||||
return ThemeListResponse(
|
return ThemeListResponse(themes=[_theme_to_out(t) for t in themes], total=len(themes))
|
||||||
themes=[_theme_to_out(t) for t in themes],
|
|
||||||
total=len(themes),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/themes", response_model=ThemeOut, status_code=201)
|
@router.post("/themes", response_model=ThemeOut, status_code=201)
|
||||||
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
|
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
|
||||||
"""Create a new keyword theme."""
|
exists = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == body.name))
|
||||||
exists = await db.scalar(
|
|
||||||
select(KeywordTheme.id).where(KeywordTheme.name == body.name)
|
|
||||||
)
|
|
||||||
if exists:
|
if exists:
|
||||||
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
||||||
theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
|
theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
|
||||||
db.add(theme)
|
db.add(theme)
|
||||||
await db.flush()
|
await db.flush()
|
||||||
await db.refresh(theme)
|
await db.refresh(theme)
|
||||||
|
keyword_scan_cache.clear()
|
||||||
return _theme_to_out(theme)
|
return _theme_to_out(theme)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/themes/{theme_id}", response_model=ThemeOut)
|
@router.put("/themes/{theme_id}", response_model=ThemeOut)
|
||||||
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
|
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
|
||||||
"""Update theme name, color, or enabled status."""
|
|
||||||
theme = await db.get(KeywordTheme, theme_id)
|
theme = await db.get(KeywordTheme, theme_id)
|
||||||
if not theme:
|
if not theme:
|
||||||
raise HTTPException(404, "Theme not found")
|
raise HTTPException(404, "Theme not found")
|
||||||
if body.name is not None:
|
if body.name is not None:
|
||||||
# check uniqueness
|
|
||||||
dup = await db.scalar(
|
dup = await db.scalar(
|
||||||
select(KeywordTheme.id).where(
|
select(KeywordTheme.id).where(KeywordTheme.name == body.name, KeywordTheme.id != theme_id)
|
||||||
KeywordTheme.name == body.name, KeywordTheme.id != theme_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if dup:
|
if dup:
|
||||||
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
||||||
@@ -172,24 +181,21 @@ async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depe
|
|||||||
theme.enabled = body.enabled
|
theme.enabled = body.enabled
|
||||||
await db.flush()
|
await db.flush()
|
||||||
await db.refresh(theme)
|
await db.refresh(theme)
|
||||||
|
keyword_scan_cache.clear()
|
||||||
return _theme_to_out(theme)
|
return _theme_to_out(theme)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/themes/{theme_id}", status_code=204)
|
@router.delete("/themes/{theme_id}", status_code=204)
|
||||||
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
|
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
"""Delete a theme and all its keywords."""
|
|
||||||
theme = await db.get(KeywordTheme, theme_id)
|
theme = await db.get(KeywordTheme, theme_id)
|
||||||
if not theme:
|
if not theme:
|
||||||
raise HTTPException(404, "Theme not found")
|
raise HTTPException(404, "Theme not found")
|
||||||
await db.delete(theme)
|
await db.delete(theme)
|
||||||
|
keyword_scan_cache.clear()
|
||||||
|
|
||||||
# ── Keyword CRUD ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
|
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
|
||||||
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
|
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
|
||||||
"""Add a single keyword to a theme."""
|
|
||||||
theme = await db.get(KeywordTheme, theme_id)
|
theme = await db.get(KeywordTheme, theme_id)
|
||||||
if not theme:
|
if not theme:
|
||||||
raise HTTPException(404, "Theme not found")
|
raise HTTPException(404, "Theme not found")
|
||||||
@@ -197,6 +203,7 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
|
|||||||
db.add(kw)
|
db.add(kw)
|
||||||
await db.flush()
|
await db.flush()
|
||||||
await db.refresh(kw)
|
await db.refresh(kw)
|
||||||
|
keyword_scan_cache.clear()
|
||||||
return KeywordOut(
|
return KeywordOut(
|
||||||
id=kw.id, theme_id=kw.theme_id, value=kw.value,
|
id=kw.id, theme_id=kw.theme_id, value=kw.value,
|
||||||
is_regex=kw.is_regex, created_at=kw.created_at.isoformat(),
|
is_regex=kw.is_regex, created_at=kw.created_at.isoformat(),
|
||||||
@@ -205,7 +212,6 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
|
|||||||
|
|
||||||
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
|
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
|
||||||
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
|
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
|
||||||
"""Add multiple keywords to a theme at once."""
|
|
||||||
theme = await db.get(KeywordTheme, theme_id)
|
theme = await db.get(KeywordTheme, theme_id)
|
||||||
if not theme:
|
if not theme:
|
||||||
raise HTTPException(404, "Theme not found")
|
raise HTTPException(404, "Theme not found")
|
||||||
@@ -217,25 +223,88 @@ async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSes
|
|||||||
db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex))
|
db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex))
|
||||||
added += 1
|
added += 1
|
||||||
await db.flush()
|
await db.flush()
|
||||||
|
keyword_scan_cache.clear()
|
||||||
return {"added": added, "theme_id": theme_id}
|
return {"added": added, "theme_id": theme_id}
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/keywords/{keyword_id}", status_code=204)
|
@router.delete("/keywords/{keyword_id}", status_code=204)
|
||||||
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
|
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
|
||||||
"""Delete a single keyword."""
|
|
||||||
kw = await db.get(Keyword, keyword_id)
|
kw = await db.get(Keyword, keyword_id)
|
||||||
if not kw:
|
if not kw:
|
||||||
raise HTTPException(404, "Keyword not found")
|
raise HTTPException(404, "Keyword not found")
|
||||||
await db.delete(kw)
|
await db.delete(kw)
|
||||||
|
keyword_scan_cache.clear()
|
||||||
|
|
||||||
# ── Scan endpoints ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/scan", response_model=ScanResponse)
|
@router.post("/scan", response_model=ScanResponse)
|
||||||
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
||||||
"""Run AUP keyword scan across selected data sources."""
|
|
||||||
scanner = KeywordScanner(db)
|
scanner = KeywordScanner(db)
|
||||||
|
|
||||||
|
if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
|
||||||
|
return {
|
||||||
|
"total_hits": 0,
|
||||||
|
"hits": [],
|
||||||
|
"themes_scanned": 0,
|
||||||
|
"keywords_scanned": 0,
|
||||||
|
"rows_scanned": 0,
|
||||||
|
"cache_used": False,
|
||||||
|
"cache_status": "miss",
|
||||||
|
"cached_at": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
can_use_cache = (
|
||||||
|
body.prefer_cache
|
||||||
|
and not body.force_rescan
|
||||||
|
and bool(body.dataset_ids)
|
||||||
|
and not body.scan_hunts
|
||||||
|
and not body.scan_annotations
|
||||||
|
and not body.scan_messages
|
||||||
|
)
|
||||||
|
|
||||||
|
if can_use_cache:
|
||||||
|
themes = await scanner._load_themes(body.theme_ids)
|
||||||
|
allowed_theme_names = {t.name for t in themes}
|
||||||
|
keywords_scanned = sum(len(theme.keywords) for theme in themes)
|
||||||
|
|
||||||
|
cached_entries: list[dict] = []
|
||||||
|
missing: list[str] = []
|
||||||
|
for dataset_id in (body.dataset_ids or []):
|
||||||
|
entry = keyword_scan_cache.get(dataset_id)
|
||||||
|
if not entry:
|
||||||
|
missing.append(dataset_id)
|
||||||
|
continue
|
||||||
|
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
|
||||||
|
|
||||||
|
if not missing and cached_entries:
|
||||||
|
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
|
||||||
|
return {
|
||||||
|
"total_hits": merged["total_hits"],
|
||||||
|
"hits": merged["hits"],
|
||||||
|
"themes_scanned": len(themes),
|
||||||
|
"keywords_scanned": keywords_scanned,
|
||||||
|
"rows_scanned": merged["rows_scanned"],
|
||||||
|
"cache_used": True,
|
||||||
|
"cache_status": "hit",
|
||||||
|
"cached_at": merged["cached_at"],
|
||||||
|
}
|
||||||
|
|
||||||
|
if missing:
|
||||||
|
partial = await scanner.scan(dataset_ids=missing, theme_ids=body.theme_ids)
|
||||||
|
merged = _merge_cached_results(
|
||||||
|
cached_entries + [{"result": partial, "built_at": None}],
|
||||||
|
allowed_theme_names if body.theme_ids else None,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"total_hits": merged["total_hits"],
|
||||||
|
"hits": merged["hits"],
|
||||||
|
"themes_scanned": len(themes),
|
||||||
|
"keywords_scanned": keywords_scanned,
|
||||||
|
"rows_scanned": merged["rows_scanned"],
|
||||||
|
"cache_used": len(cached_entries) > 0,
|
||||||
|
"cache_status": "partial" if cached_entries else "miss",
|
||||||
|
"cached_at": merged["cached_at"],
|
||||||
|
}
|
||||||
|
|
||||||
result = await scanner.scan(
|
result = await scanner.scan(
|
||||||
dataset_ids=body.dataset_ids,
|
dataset_ids=body.dataset_ids,
|
||||||
theme_ids=body.theme_ids,
|
theme_ids=body.theme_ids,
|
||||||
@@ -243,7 +312,13 @@ async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
|||||||
scan_annotations=body.scan_annotations,
|
scan_annotations=body.scan_annotations,
|
||||||
scan_messages=body.scan_messages,
|
scan_messages=body.scan_messages,
|
||||||
)
|
)
|
||||||
return result
|
|
||||||
|
return {
|
||||||
|
**result,
|
||||||
|
"cache_used": False,
|
||||||
|
"cache_status": "miss",
|
||||||
|
"cached_at": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/scan/quick", response_model=ScanResponse)
|
@router.get("/scan/quick", response_model=ScanResponse)
|
||||||
@@ -251,7 +326,22 @@ async def quick_scan(
|
|||||||
dataset_id: str = Query(..., description="Dataset to scan"),
|
dataset_id: str = Query(..., description="Dataset to scan"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Quick scan a single dataset with all enabled themes."""
|
entry = keyword_scan_cache.get(dataset_id)
|
||||||
|
if entry is not None:
|
||||||
|
result = entry.result
|
||||||
|
return {
|
||||||
|
**result,
|
||||||
|
"cache_used": True,
|
||||||
|
"cache_status": "hit",
|
||||||
|
"cached_at": entry.built_at,
|
||||||
|
}
|
||||||
|
|
||||||
scanner = KeywordScanner(db)
|
scanner = KeywordScanner(db)
|
||||||
result = await scanner.scan(dataset_ids=[dataset_id])
|
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||||
return result
|
keyword_scan_cache.put(dataset_id, result)
|
||||||
|
return {
|
||||||
|
**result,
|
||||||
|
"cache_used": False,
|
||||||
|
"cache_status": "miss",
|
||||||
|
"cached_at": None,
|
||||||
|
}
|
||||||
|
|||||||
146
backend/app/api/routes/mitre.py
Normal file
146
backend/app/api/routes/mitre.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""API routes for MITRE ATT&CK coverage visualization."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.db.models import (
|
||||||
|
TriageResult, HostProfile, Hypothesis, HuntReport, Dataset, Hunt
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/mitre", tags=["mitre"])
|
||||||
|
|
||||||
|
# Canonical MITRE ATT&CK tactics in kill-chain order
|
||||||
|
TACTICS = [
|
||||||
|
"Reconnaissance", "Resource Development", "Initial Access",
|
||||||
|
"Execution", "Persistence", "Privilege Escalation",
|
||||||
|
"Defense Evasion", "Credential Access", "Discovery",
|
||||||
|
"Lateral Movement", "Collection", "Command and Control",
|
||||||
|
"Exfiltration", "Impact",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Simplified technique-to-tactic mapping (top techniques)
|
||||||
|
TECHNIQUE_TACTIC: dict[str, str] = {
|
||||||
|
"T1059": "Execution", "T1059.001": "Execution", "T1059.003": "Execution",
|
||||||
|
"T1059.005": "Execution", "T1059.006": "Execution", "T1059.007": "Execution",
|
||||||
|
"T1053": "Persistence", "T1053.005": "Persistence",
|
||||||
|
"T1547": "Persistence", "T1547.001": "Persistence",
|
||||||
|
"T1543": "Persistence", "T1543.003": "Persistence",
|
||||||
|
"T1078": "Privilege Escalation", "T1078.001": "Privilege Escalation",
|
||||||
|
"T1078.002": "Privilege Escalation", "T1078.003": "Privilege Escalation",
|
||||||
|
"T1055": "Privilege Escalation", "T1055.001": "Privilege Escalation",
|
||||||
|
"T1548": "Privilege Escalation", "T1548.002": "Privilege Escalation",
|
||||||
|
"T1070": "Defense Evasion", "T1070.001": "Defense Evasion",
|
||||||
|
"T1070.004": "Defense Evasion",
|
||||||
|
"T1036": "Defense Evasion", "T1036.005": "Defense Evasion",
|
||||||
|
"T1027": "Defense Evasion", "T1140": "Defense Evasion",
|
||||||
|
"T1218": "Defense Evasion", "T1218.011": "Defense Evasion",
|
||||||
|
"T1003": "Credential Access", "T1003.001": "Credential Access",
|
||||||
|
"T1110": "Credential Access", "T1558": "Credential Access",
|
||||||
|
"T1087": "Discovery", "T1087.001": "Discovery", "T1087.002": "Discovery",
|
||||||
|
"T1082": "Discovery", "T1083": "Discovery", "T1057": "Discovery",
|
||||||
|
"T1018": "Discovery", "T1049": "Discovery", "T1016": "Discovery",
|
||||||
|
"T1021": "Lateral Movement", "T1021.001": "Lateral Movement",
|
||||||
|
"T1021.002": "Lateral Movement", "T1021.006": "Lateral Movement",
|
||||||
|
"T1570": "Lateral Movement",
|
||||||
|
"T1560": "Collection", "T1074": "Collection", "T1005": "Collection",
|
||||||
|
"T1071": "Command and Control", "T1071.001": "Command and Control",
|
||||||
|
"T1105": "Command and Control", "T1572": "Command and Control",
|
||||||
|
"T1095": "Command and Control",
|
||||||
|
"T1048": "Exfiltration", "T1041": "Exfiltration",
|
||||||
|
"T1486": "Impact", "T1490": "Impact", "T1489": "Impact",
|
||||||
|
"T1566": "Initial Access", "T1566.001": "Initial Access",
|
||||||
|
"T1566.002": "Initial Access",
|
||||||
|
"T1190": "Initial Access", "T1133": "Initial Access",
|
||||||
|
"T1195": "Initial Access", "T1195.002": "Initial Access",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tactic(technique_id: str) -> str:
|
||||||
|
"""Map a technique ID to its tactic."""
|
||||||
|
tech = technique_id.strip().upper()
|
||||||
|
if tech in TECHNIQUE_TACTIC:
|
||||||
|
return TECHNIQUE_TACTIC[tech]
|
||||||
|
# Try parent technique
|
||||||
|
if "." in tech:
|
||||||
|
parent = tech.split(".")[0]
|
||||||
|
if parent in TECHNIQUE_TACTIC:
|
||||||
|
return TECHNIQUE_TACTIC[parent]
|
||||||
|
return "Unknown"
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/coverage")
|
||||||
|
async def get_mitre_coverage(
|
||||||
|
hunt_id: str | None = None,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Aggregate all MITRE techniques from triage, host profiles, hypotheses, and reports."""
|
||||||
|
|
||||||
|
techniques: dict[str, dict] = {}
|
||||||
|
|
||||||
|
# Collect from triage results
|
||||||
|
triage_q = select(TriageResult)
|
||||||
|
if hunt_id:
|
||||||
|
triage_q = triage_q.join(Dataset).where(Dataset.hunt_id == hunt_id)
|
||||||
|
result = await db.execute(triage_q.limit(500))
|
||||||
|
for t in result.scalars().all():
|
||||||
|
for tech in (t.mitre_techniques or []):
|
||||||
|
if tech not in techniques:
|
||||||
|
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
|
||||||
|
techniques[tech]["count"] += 1
|
||||||
|
techniques[tech]["sources"].append({"type": "triage", "risk_score": t.risk_score})
|
||||||
|
|
||||||
|
# Collect from host profiles
|
||||||
|
profile_q = select(HostProfile)
|
||||||
|
if hunt_id:
|
||||||
|
profile_q = profile_q.where(HostProfile.hunt_id == hunt_id)
|
||||||
|
result = await db.execute(profile_q.limit(200))
|
||||||
|
for p in result.scalars().all():
|
||||||
|
for tech in (p.mitre_techniques or []):
|
||||||
|
if tech not in techniques:
|
||||||
|
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
|
||||||
|
techniques[tech]["count"] += 1
|
||||||
|
techniques[tech]["sources"].append({"type": "host_profile", "hostname": p.hostname})
|
||||||
|
|
||||||
|
# Collect from hypotheses
|
||||||
|
hyp_q = select(Hypothesis)
|
||||||
|
if hunt_id:
|
||||||
|
hyp_q = hyp_q.where(Hypothesis.hunt_id == hunt_id)
|
||||||
|
result = await db.execute(hyp_q.limit(200))
|
||||||
|
for h in result.scalars().all():
|
||||||
|
tech = h.mitre_technique
|
||||||
|
if tech:
|
||||||
|
if tech not in techniques:
|
||||||
|
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
|
||||||
|
techniques[tech]["count"] += 1
|
||||||
|
techniques[tech]["sources"].append({"type": "hypothesis", "title": h.title})
|
||||||
|
|
||||||
|
# Build tactic-grouped response
|
||||||
|
tactic_groups: dict[str, list] = {t: [] for t in TACTICS}
|
||||||
|
tactic_groups["Unknown"] = []
|
||||||
|
for tech in techniques.values():
|
||||||
|
tactic = tech["tactic"]
|
||||||
|
if tactic not in tactic_groups:
|
||||||
|
tactic_groups[tactic] = []
|
||||||
|
tactic_groups[tactic].append(tech)
|
||||||
|
|
||||||
|
total_techniques = len(techniques)
|
||||||
|
total_detections = sum(t["count"] for t in techniques.values())
|
||||||
|
|
||||||
|
return {
|
||||||
|
"tactics": TACTICS,
|
||||||
|
"technique_count": total_techniques,
|
||||||
|
"detection_count": total_detections,
|
||||||
|
"tactic_coverage": {
|
||||||
|
t: {"techniques": techs, "count": len(techs)}
|
||||||
|
for t, techs in tactic_groups.items()
|
||||||
|
if techs
|
||||||
|
},
|
||||||
|
"all_techniques": list(techniques.values()),
|
||||||
|
}
|
||||||
@@ -1,12 +1,15 @@
|
|||||||
"""Network topology API - host inventory endpoint."""
|
"""Network topology API - host inventory endpoint with background caching."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.services.host_inventory import build_host_inventory
|
from app.services.host_inventory import build_host_inventory, inventory_cache
|
||||||
|
from app.services.job_queue import job_queue, JobType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/network", tags=["network"])
|
router = APIRouter(prefix="/api/network", tags=["network"])
|
||||||
@@ -15,14 +18,158 @@ router = APIRouter(prefix="/api/network", tags=["network"])
|
|||||||
@router.get("/host-inventory")
|
@router.get("/host-inventory")
|
||||||
async def get_host_inventory(
|
async def get_host_inventory(
|
||||||
hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
|
hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
|
||||||
|
force: bool = Query(False, description="Force rebuild, ignoring cache"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""Build a deduplicated host inventory from all datasets in a hunt.
|
"""Return a deduplicated host inventory for the hunt.
|
||||||
|
|
||||||
Returns unique hosts with hostname, IPs, OS, logged-in users, and
|
Returns instantly from cache if available (pre-built after upload or on startup).
|
||||||
network connections derived from netstat/connection data.
|
If cache is cold, triggers a background build and returns 202 so the
|
||||||
|
frontend can poll /inventory-status and re-request when ready.
|
||||||
"""
|
"""
|
||||||
result = await build_host_inventory(hunt_id, db)
|
# Force rebuild: invalidate cache, queue background job, return 202
|
||||||
if result["stats"]["total_hosts"] == 0:
|
if force:
|
||||||
return result
|
inventory_cache.invalidate(hunt_id)
|
||||||
return result
|
if not inventory_cache.is_building(hunt_id):
|
||||||
|
if job_queue.is_backlogged():
|
||||||
|
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
|
||||||
|
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=202,
|
||||||
|
content={"status": "building", "message": "Rebuild queued"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try cache first
|
||||||
|
cached = inventory_cache.get(hunt_id)
|
||||||
|
if cached is not None:
|
||||||
|
logger.info(f"Serving cached host inventory for {hunt_id}")
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# Cache miss: trigger background build instead of blocking for 90+ seconds
|
||||||
|
if not inventory_cache.is_building(hunt_id):
|
||||||
|
logger.info(f"Cache miss for {hunt_id}, triggering background build")
|
||||||
|
if job_queue.is_backlogged():
|
||||||
|
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
|
||||||
|
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=202,
|
||||||
|
content={"status": "building", "message": "Inventory is being built in the background"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_summary(inv: dict, top_n: int = 20) -> dict:
|
||||||
|
hosts = inv.get("hosts", [])
|
||||||
|
conns = inv.get("connections", [])
|
||||||
|
top_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:top_n]
|
||||||
|
top_edges = sorted(conns, key=lambda c: c.get("count", 0), reverse=True)[:top_n]
|
||||||
|
return {
|
||||||
|
"stats": inv.get("stats", {}),
|
||||||
|
"top_hosts": [
|
||||||
|
{
|
||||||
|
"id": h.get("id"),
|
||||||
|
"hostname": h.get("hostname"),
|
||||||
|
"row_count": h.get("row_count", 0),
|
||||||
|
"ip_count": len(h.get("ips", [])),
|
||||||
|
"user_count": len(h.get("users", [])),
|
||||||
|
}
|
||||||
|
for h in top_hosts
|
||||||
|
],
|
||||||
|
"top_edges": top_edges,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_subgraph(inv: dict, node_id: str | None, max_hosts: int, max_edges: int) -> dict:
|
||||||
|
hosts = inv.get("hosts", [])
|
||||||
|
conns = inv.get("connections", [])
|
||||||
|
|
||||||
|
max_hosts = max(1, min(max_hosts, settings.NETWORK_SUBGRAPH_MAX_HOSTS))
|
||||||
|
max_edges = max(1, min(max_edges, settings.NETWORK_SUBGRAPH_MAX_EDGES))
|
||||||
|
|
||||||
|
if node_id:
|
||||||
|
rel_edges = [c for c in conns if c.get("source") == node_id or c.get("target") == node_id]
|
||||||
|
rel_edges = sorted(rel_edges, key=lambda c: c.get("count", 0), reverse=True)[:max_edges]
|
||||||
|
ids = {node_id}
|
||||||
|
for c in rel_edges:
|
||||||
|
ids.add(c.get("source"))
|
||||||
|
ids.add(c.get("target"))
|
||||||
|
rel_hosts = [h for h in hosts if h.get("id") in ids][:max_hosts]
|
||||||
|
else:
|
||||||
|
rel_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:max_hosts]
|
||||||
|
allowed = {h.get("id") for h in rel_hosts}
|
||||||
|
rel_edges = [
|
||||||
|
c for c in sorted(conns, key=lambda c: c.get("count", 0), reverse=True)
|
||||||
|
if c.get("source") in allowed and c.get("target") in allowed
|
||||||
|
][:max_edges]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"hosts": rel_hosts,
|
||||||
|
"connections": rel_edges,
|
||||||
|
"stats": {
|
||||||
|
**inv.get("stats", {}),
|
||||||
|
"subgraph_hosts": len(rel_hosts),
|
||||||
|
"subgraph_connections": len(rel_edges),
|
||||||
|
"truncated": len(rel_hosts) < len(hosts) or len(rel_edges) < len(conns),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/summary")
|
||||||
|
async def get_inventory_summary(
|
||||||
|
hunt_id: str = Query(..., description="Hunt ID"),
|
||||||
|
top_n: int = Query(20, ge=1, le=200),
|
||||||
|
):
|
||||||
|
"""Return a lightweight summary view for large hunts."""
|
||||||
|
cached = inventory_cache.get(hunt_id)
|
||||||
|
if cached is None:
|
||||||
|
if not inventory_cache.is_building(hunt_id):
|
||||||
|
if job_queue.is_backlogged():
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=202,
|
||||||
|
content={"status": "deferred", "message": "Queue busy, retry shortly"},
|
||||||
|
)
|
||||||
|
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||||
|
return JSONResponse(status_code=202, content={"status": "building"})
|
||||||
|
return _build_summary(cached, top_n=top_n)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/subgraph")
|
||||||
|
async def get_inventory_subgraph(
|
||||||
|
hunt_id: str = Query(..., description="Hunt ID"),
|
||||||
|
node_id: str | None = Query(None, description="Optional focal node"),
|
||||||
|
max_hosts: int = Query(200, ge=1, le=5000),
|
||||||
|
max_edges: int = Query(1500, ge=1, le=20000),
|
||||||
|
):
|
||||||
|
"""Return a bounded subgraph for scale-safe rendering."""
|
||||||
|
cached = inventory_cache.get(hunt_id)
|
||||||
|
if cached is None:
|
||||||
|
if not inventory_cache.is_building(hunt_id):
|
||||||
|
if job_queue.is_backlogged():
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=202,
|
||||||
|
content={"status": "deferred", "message": "Queue busy, retry shortly"},
|
||||||
|
)
|
||||||
|
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||||
|
return JSONResponse(status_code=202, content={"status": "building"})
|
||||||
|
return _build_subgraph(cached, node_id=node_id, max_hosts=max_hosts, max_edges=max_edges)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/inventory-status")
|
||||||
|
async def get_inventory_status(
|
||||||
|
hunt_id: str = Query(..., description="Hunt ID to check"),
|
||||||
|
):
|
||||||
|
"""Check whether pre-computed host inventory is ready for a hunt.
|
||||||
|
|
||||||
|
Returns: { status: "ready" | "building" | "none" }
|
||||||
|
"""
|
||||||
|
return {"hunt_id": hunt_id, "status": inventory_cache.status(hunt_id)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/rebuild-inventory")
|
||||||
|
async def trigger_rebuild(
|
||||||
|
hunt_id: str = Query(..., description="Hunt to rebuild inventory for"),
|
||||||
|
):
|
||||||
|
"""Trigger a background rebuild of the host inventory cache."""
|
||||||
|
inventory_cache.invalidate(hunt_id)
|
||||||
|
job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||||
|
return {"job_id": job.id, "status": "queued"}
|
||||||
|
|||||||
217
backend/app/api/routes/playbooks.py
Normal file
217
backend/app/api/routes/playbooks.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
"""API routes for investigation playbooks."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.db.models import Playbook, PlaybookStep
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/playbooks", tags=["playbooks"])
|
||||||
|
|
||||||
|
|
||||||
|
# -- Request / Response schemas ---
|
||||||
|
|
||||||
|
class StepCreate(BaseModel):
|
||||||
|
title: str
|
||||||
|
description: str | None = None
|
||||||
|
step_type: str = "manual"
|
||||||
|
target_route: str | None = None
|
||||||
|
|
||||||
|
class PlaybookCreate(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
hunt_id: str | None = None
|
||||||
|
is_template: bool = False
|
||||||
|
steps: list[StepCreate] = []
|
||||||
|
|
||||||
|
class PlaybookUpdate(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
description: str | None = None
|
||||||
|
status: str | None = None
|
||||||
|
|
||||||
|
class StepUpdate(BaseModel):
|
||||||
|
is_completed: bool | None = None
|
||||||
|
notes: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# -- Default investigation templates ---
|
||||||
|
|
||||||
|
DEFAULT_TEMPLATES = [
|
||||||
|
{
|
||||||
|
"name": "Standard Threat Hunt",
|
||||||
|
"description": "Step-by-step investigation workflow for a typical threat hunting engagement.",
|
||||||
|
"steps": [
|
||||||
|
{"title": "Upload Artifacts", "description": "Import CSV exports from Velociraptor or other tools", "step_type": "upload", "target_route": "/upload"},
|
||||||
|
{"title": "Create Hunt", "description": "Create a new hunt and associate uploaded datasets", "step_type": "action", "target_route": "/hunts"},
|
||||||
|
{"title": "AUP Keyword Scan", "description": "Run AUP keyword scanner for policy violations", "step_type": "analysis", "target_route": "/aup"},
|
||||||
|
{"title": "Auto-Triage", "description": "Trigger AI triage on all datasets", "step_type": "analysis", "target_route": "/analysis"},
|
||||||
|
{"title": "Review Triage Results", "description": "Review flagged rows and risk scores", "step_type": "review", "target_route": "/analysis"},
|
||||||
|
{"title": "Enrich IOCs", "description": "Enrich flagged IPs, hashes, and domains via external sources", "step_type": "analysis", "target_route": "/enrichment"},
|
||||||
|
{"title": "Host Profiling", "description": "Generate deep host profiles for suspicious hosts", "step_type": "analysis", "target_route": "/analysis"},
|
||||||
|
{"title": "Cross-Hunt Correlation", "description": "Identify shared IOCs and patterns across hunts", "step_type": "analysis", "target_route": "/correlation"},
|
||||||
|
{"title": "Document Hypotheses", "description": "Record investigation hypotheses with MITRE mappings", "step_type": "action", "target_route": "/hypotheses"},
|
||||||
|
{"title": "Generate Report", "description": "Generate final AI-assisted hunt report", "step_type": "action", "target_route": "/analysis"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Incident Response Triage",
|
||||||
|
"description": "Fast-track workflow for active incident response.",
|
||||||
|
"steps": [
|
||||||
|
{"title": "Upload Artifacts", "description": "Import forensic data from affected hosts", "step_type": "upload", "target_route": "/upload"},
|
||||||
|
{"title": "Auto-Triage", "description": "Immediate AI triage for threat indicators", "step_type": "analysis", "target_route": "/analysis"},
|
||||||
|
{"title": "IOC Extraction", "description": "Extract all IOCs from flagged data", "step_type": "analysis", "target_route": "/analysis"},
|
||||||
|
{"title": "Enrich Critical IOCs", "description": "Priority enrichment of high-risk indicators", "step_type": "analysis", "target_route": "/enrichment"},
|
||||||
|
{"title": "Network Map", "description": "Visualize host connections and lateral movement", "step_type": "review", "target_route": "/network"},
|
||||||
|
{"title": "Generate Situation Report", "description": "Create executive summary for incident command", "step_type": "action", "target_route": "/analysis"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# -- Routes ---
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_playbooks(
|
||||||
|
include_templates: bool = True,
|
||||||
|
hunt_id: str | None = None,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
q = select(Playbook)
|
||||||
|
if hunt_id:
|
||||||
|
q = q.where(Playbook.hunt_id == hunt_id)
|
||||||
|
if not include_templates:
|
||||||
|
q = q.where(Playbook.is_template == False)
|
||||||
|
q = q.order_by(Playbook.created_at.desc())
|
||||||
|
result = await db.execute(q.limit(100))
|
||||||
|
playbooks = result.scalars().all()
|
||||||
|
|
||||||
|
return {"playbooks": [
|
||||||
|
{
|
||||||
|
"id": p.id, "name": p.name, "description": p.description,
|
||||||
|
"is_template": p.is_template, "hunt_id": p.hunt_id,
|
||||||
|
"status": p.status,
|
||||||
|
"total_steps": len(p.steps),
|
||||||
|
"completed_steps": sum(1 for s in p.steps if s.is_completed),
|
||||||
|
"created_at": p.created_at.isoformat() if p.created_at else None,
|
||||||
|
}
|
||||||
|
for p in playbooks
|
||||||
|
]}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/templates")
|
||||||
|
async def get_templates():
|
||||||
|
"""Return built-in investigation templates."""
|
||||||
|
return {"templates": DEFAULT_TEMPLATES}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_playbook(body: PlaybookCreate, db: AsyncSession = Depends(get_db)):
|
||||||
|
pb = Playbook(
|
||||||
|
name=body.name,
|
||||||
|
description=body.description,
|
||||||
|
hunt_id=body.hunt_id,
|
||||||
|
is_template=body.is_template,
|
||||||
|
)
|
||||||
|
db.add(pb)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
created_steps = []
|
||||||
|
for i, step in enumerate(body.steps):
|
||||||
|
s = PlaybookStep(
|
||||||
|
playbook_id=pb.id,
|
||||||
|
order_index=i,
|
||||||
|
title=step.title,
|
||||||
|
description=step.description,
|
||||||
|
step_type=step.step_type,
|
||||||
|
target_route=step.target_route,
|
||||||
|
)
|
||||||
|
db.add(s)
|
||||||
|
created_steps.append(s)
|
||||||
|
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": pb.id, "name": pb.name, "description": pb.description,
|
||||||
|
"hunt_id": pb.hunt_id, "is_template": pb.is_template,
|
||||||
|
"steps": [
|
||||||
|
{"id": s.id, "order_index": s.order_index, "title": s.title,
|
||||||
|
"description": s.description, "step_type": s.step_type,
|
||||||
|
"target_route": s.target_route, "is_completed": False}
|
||||||
|
for s in created_steps
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{playbook_id}")
|
||||||
|
async def get_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
|
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
|
||||||
|
pb = result.scalar_one_or_none()
|
||||||
|
if not pb:
|
||||||
|
raise HTTPException(status_code=404, detail="Playbook not found")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": pb.id, "name": pb.name, "description": pb.description,
|
||||||
|
"is_template": pb.is_template, "hunt_id": pb.hunt_id,
|
||||||
|
"status": pb.status,
|
||||||
|
"created_at": pb.created_at.isoformat() if pb.created_at else None,
|
||||||
|
"steps": [
|
||||||
|
{
|
||||||
|
"id": s.id, "order_index": s.order_index, "title": s.title,
|
||||||
|
"description": s.description, "step_type": s.step_type,
|
||||||
|
"target_route": s.target_route,
|
||||||
|
"is_completed": s.is_completed,
|
||||||
|
"completed_at": s.completed_at.isoformat() if s.completed_at else None,
|
||||||
|
"notes": s.notes,
|
||||||
|
}
|
||||||
|
for s in sorted(pb.steps, key=lambda x: x.order_index)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{playbook_id}")
|
||||||
|
async def update_playbook(playbook_id: str, body: PlaybookUpdate, db: AsyncSession = Depends(get_db)):
|
||||||
|
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
|
||||||
|
pb = result.scalar_one_or_none()
|
||||||
|
if not pb:
|
||||||
|
raise HTTPException(status_code=404, detail="Playbook not found")
|
||||||
|
|
||||||
|
if body.name is not None:
|
||||||
|
pb.name = body.name
|
||||||
|
if body.description is not None:
|
||||||
|
pb.description = body.description
|
||||||
|
if body.status is not None:
|
||||||
|
pb.status = body.status
|
||||||
|
return {"status": "updated"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{playbook_id}")
|
||||||
|
async def delete_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
|
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
|
||||||
|
pb = result.scalar_one_or_none()
|
||||||
|
if not pb:
|
||||||
|
raise HTTPException(status_code=404, detail="Playbook not found")
|
||||||
|
await db.delete(pb)
|
||||||
|
return {"status": "deleted"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/steps/{step_id}")
|
||||||
|
async def update_step(step_id: int, body: StepUpdate, db: AsyncSession = Depends(get_db)):
|
||||||
|
result = await db.execute(select(PlaybookStep).where(PlaybookStep.id == step_id))
|
||||||
|
step = result.scalar_one_or_none()
|
||||||
|
if not step:
|
||||||
|
raise HTTPException(status_code=404, detail="Step not found")
|
||||||
|
|
||||||
|
if body.is_completed is not None:
|
||||||
|
step.is_completed = body.is_completed
|
||||||
|
step.completed_at = datetime.now(timezone.utc) if body.is_completed else None
|
||||||
|
if body.notes is not None:
|
||||||
|
step.notes = body.notes
|
||||||
|
return {"status": "updated", "is_completed": step.is_completed}
|
||||||
|
|
||||||
164
backend/app/api/routes/saved_searches.py
Normal file
164
backend/app/api/routes/saved_searches.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""API routes for saved searches and bookmarked queries."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.db.models import SavedSearch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/searches", tags=["saved-searches"])
|
||||||
|
|
||||||
|
|
||||||
|
class SearchCreate(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
search_type: str # "nlp_query", "ioc_search", "keyword_scan", "correlation"
|
||||||
|
query_params: dict
|
||||||
|
threshold: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SearchUpdate(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
description: str | None = None
|
||||||
|
query_params: dict | None = None
|
||||||
|
threshold: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def list_searches(
|
||||||
|
search_type: str | None = None,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
q = select(SavedSearch).order_by(SavedSearch.created_at.desc())
|
||||||
|
if search_type:
|
||||||
|
q = q.where(SavedSearch.search_type == search_type)
|
||||||
|
result = await db.execute(q.limit(100))
|
||||||
|
searches = result.scalars().all()
|
||||||
|
return {"searches": [
|
||||||
|
{
|
||||||
|
"id": s.id, "name": s.name, "description": s.description,
|
||||||
|
"search_type": s.search_type, "query_params": s.query_params,
|
||||||
|
"threshold": s.threshold,
|
||||||
|
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
|
||||||
|
"last_result_count": s.last_result_count,
|
||||||
|
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||||
|
}
|
||||||
|
for s in searches
|
||||||
|
]}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_search(body: SearchCreate, db: AsyncSession = Depends(get_db)):
|
||||||
|
s = SavedSearch(
|
||||||
|
name=body.name,
|
||||||
|
description=body.description,
|
||||||
|
search_type=body.search_type,
|
||||||
|
query_params=body.query_params,
|
||||||
|
threshold=body.threshold,
|
||||||
|
)
|
||||||
|
db.add(s)
|
||||||
|
await db.flush()
|
||||||
|
return {
|
||||||
|
"id": s.id, "name": s.name, "search_type": s.search_type,
|
||||||
|
"query_params": s.query_params, "threshold": s.threshold,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{search_id}")
|
||||||
|
async def get_search(search_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
|
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
|
||||||
|
s = result.scalar_one_or_none()
|
||||||
|
if not s:
|
||||||
|
raise HTTPException(status_code=404, detail="Saved search not found")
|
||||||
|
return {
|
||||||
|
"id": s.id, "name": s.name, "description": s.description,
|
||||||
|
"search_type": s.search_type, "query_params": s.query_params,
|
||||||
|
"threshold": s.threshold,
|
||||||
|
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
|
||||||
|
"last_result_count": s.last_result_count,
|
||||||
|
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{search_id}")
async def update_search(search_id: str, body: SearchUpdate, db: AsyncSession = Depends(get_db)):
    """Apply a partial update to a saved search.

    A body field set to None means "not provided"; only non-None fields
    overwrite the stored values.
    """
    record = (
        await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
    ).scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="Saved search not found")
    # Copy over exactly the fields the client supplied.
    for field in ("name", "description", "query_params", "threshold"):
        value = getattr(body, field)
        if value is not None:
            setattr(record, field, value)
    return {"status": "updated"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{search_id}")
async def delete_search(search_id: str, db: AsyncSession = Depends(get_db)):
    """Permanently remove a saved search, or 404 if it does not exist."""
    record = (
        await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
    ).scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="Saved search not found")
    await db.delete(record)
    return {"status": "deleted"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{search_id}/run")
async def run_saved_search(search_id: str, db: AsyncSession = Depends(get_db)):
    """Execute a saved search and return results with delta from last run.

    Supported search types:
      * ``ioc_search``   - substring match against EnrichmentResult.ioc_value
      * ``keyword_scan`` - counts keywords across enabled KeywordTheme rows

    Side effect: updates ``last_run_at`` / ``last_result_count`` on the
    SavedSearch row so the next run can report a delta.
    """
    result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
    s = result.scalar_one_or_none()
    if not s:
        raise HTTPException(status_code=404, detail="Saved search not found")

    previous_count = s.last_result_count or 0
    results = []
    count = 0
    # Robustness: query_params may be NULL on legacy rows — .get() on None
    # would raise AttributeError.
    params = s.query_params or {}

    if s.search_type == "ioc_search":
        from app.db.models import EnrichmentResult
        ioc_value = params.get("ioc_value", "")
        if ioc_value:
            q = select(EnrichmentResult).where(
                EnrichmentResult.ioc_value.contains(ioc_value)
            )
            res = await db.execute(q.limit(100))
            for er in res.scalars().all():
                results.append({
                    "ioc_value": er.ioc_value, "ioc_type": er.ioc_type,
                    "source": er.source, "verdict": er.verdict,
                })
            count = len(results)

    elif s.search_type == "keyword_scan":
        from app.db.models import KeywordTheme
        # .is_(True) produces the same SQL as == True without tripping E712.
        res = await db.execute(select(KeywordTheme).where(KeywordTheme.enabled.is_(True)))
        themes = res.scalars().all()
        # Guard: a theme's keywords column may be NULL; treat as empty list.
        count = sum(len(t.keywords or []) for t in themes)
        results = [{"theme": t.name, "keyword_count": len(t.keywords or [])} for t in themes]

    # Update last run metadata
    s.last_run_at = datetime.now(timezone.utc)
    s.last_result_count = count

    delta = count - previous_count

    return {
        "search_id": s.id, "search_name": s.name,
        "search_type": s.search_type,
        "result_count": count,
        "previous_count": previous_count,
        "delta": delta,
        "results": results[:50],
    }
|
||||||
184
backend/app/api/routes/stix_export.py
Normal file
184
backend/app/api/routes/stix_export.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
"""STIX 2.1 export endpoint.
|
||||||
|
|
||||||
|
Aggregates hunt data (IOCs, techniques, host profiles, hypotheses) into a
|
||||||
|
STIX 2.1 Bundle JSON download. No external dependencies required we
|
||||||
|
build the JSON directly following the OASIS STIX 2.1 spec.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
|
from fastapi.responses import Response
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.db.models import (
|
||||||
|
Hunt, Dataset, Hypothesis, TriageResult, HostProfile,
|
||||||
|
EnrichmentResult, HuntReport,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/export", tags=["export"])
|
||||||
|
|
||||||
|
STIX_SPEC_VERSION = "2.1"
|
||||||
|
|
||||||
|
|
||||||
|
def _stix_id(stype: str) -> str:
|
||||||
|
return f"{stype}--{uuid.uuid4()}"
|
||||||
|
|
||||||
|
|
||||||
|
def _now_iso() -> str:
|
||||||
|
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_identity(hunt_name: str) -> dict:
    """Build the STIX identity object representing this ThreatHunt export.

    Fix: the original called ``_now_iso()`` separately for ``created`` and
    ``modified``, which could straddle a second boundary and produce two
    different timestamps on a freshly created object. A single call
    guarantees they match.
    """
    now = _now_iso()
    return {
        "type": "identity",
        "spec_version": STIX_SPEC_VERSION,
        "id": _stix_id("identity"),
        "created": now,
        "modified": now,
        "name": f"ThreatHunt - {hunt_name}",
        "identity_class": "system",
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _ioc_to_indicator(ioc_value: str, ioc_type: str, identity_id: str, verdict: Optional[str] = None) -> dict:
    """Build a STIX 2.1 indicator object for a single IOC.

    Fixes:
      * ``verdict`` was annotated ``str`` with a ``None`` default; it is
        correctly ``Optional[str]`` now (no behavior change).
      * IOC values containing ``'`` or ``\\`` previously broke — or could
        inject into — the generated pattern string. They are now escaped
        per the STIX 2.1 pattern grammar (backslash-escaped).

    Unknown IOC types fall back to an ``artifact:payload_bin`` pattern.
    """
    # STIX pattern string literals escape backslash and single quote.
    esc = ioc_value.replace("\\", "\\\\").replace("'", "\\'")
    pattern_map = {
        "ipv4": f"[ipv4-addr:value = '{esc}']",
        "ipv6": f"[ipv6-addr:value = '{esc}']",
        "domain": f"[domain-name:value = '{esc}']",
        "url": f"[url:value = '{esc}']",
        "hash_md5": f"[file:hashes.'MD5' = '{esc}']",
        "hash_sha1": f"[file:hashes.'SHA-1' = '{esc}']",
        "hash_sha256": f"[file:hashes.'SHA-256' = '{esc}']",
        "email": f"[email-addr:value = '{esc}']",
    }
    pattern = pattern_map.get(ioc_type, f"[artifact:payload_bin = '{esc}']")
    now = _now_iso()
    return {
        "type": "indicator",
        "spec_version": STIX_SPEC_VERSION,
        "id": _stix_id("indicator"),
        "created": now,
        "modified": now,
        "name": f"{ioc_type}: {ioc_value}",
        "pattern": pattern,
        "pattern_type": "stix",
        "valid_from": now,
        "created_by_ref": identity_id,
        # Fall back to a generic label when no enrichment verdict exists.
        "labels": [verdict or "suspicious"],
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _technique_to_attack_pattern(technique_id: str, identity_id: str) -> dict:
    """Wrap a MITRE ATT&CK technique ID in a STIX attack-pattern object."""
    timestamp = _now_iso()
    # Sub-techniques like T1059.001 map to attack.mitre.org/techniques/T1059/001/.
    mitre_url = f"https://attack.mitre.org/techniques/{technique_id.replace('.', '/')}/"
    reference = {
        "source_name": "mitre-attack",
        "external_id": technique_id,
        "url": mitre_url,
    }
    return {
        "type": "attack-pattern",
        "spec_version": STIX_SPEC_VERSION,
        "id": _stix_id("attack-pattern"),
        "created": timestamp,
        "modified": timestamp,
        "name": technique_id,
        "created_by_ref": identity_id,
        "external_references": [reference],
    }
|
||||||
|
|
||||||
|
|
||||||
|
def _hypothesis_to_report(hyp, identity_id: str) -> dict:
    """Convert a Hypothesis ORM row into a STIX report object.

    ``object_refs`` is intentionally empty — hypothesis/indicator links are
    not tracked in this export.
    """
    timestamp = _now_iso()
    return {
        "type": "report",
        "spec_version": STIX_SPEC_VERSION,
        "id": _stix_id("report"),
        "created": timestamp,
        "modified": timestamp,
        "name": hyp.title,
        "description": hyp.description or "",
        "published": timestamp,
        "created_by_ref": identity_id,
        "labels": ["threat-hunt-hypothesis"],
        "object_refs": [],
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/stix/{hunt_id}")
async def export_stix(hunt_id: str, db: AsyncSession = Depends(get_db)):
    """Export hunt data as a STIX 2.1 Bundle JSON file.

    Collects IOC indicators (enrichment results), ATT&CK attack-patterns
    (triage results, host profiles, hypotheses) and hypothesis reports into
    one bundle, de-duplicating IOCs and technique IDs.

    Fix: the Content-Disposition header previously contained a literal
    placeholder instead of the computed ``filename``; it now carries the
    real name. The triage query is also kept under the ``if ds_ids:`` guard
    so we never emit ``IN ()`` against an empty id list.
    """
    hunt = (await db.execute(select(Hunt).where(Hunt.id == hunt_id))).scalar_one_or_none()
    if not hunt:
        raise HTTPException(404, "Hunt not found")

    identity = _build_identity(hunt.name)
    objects: list[dict] = [identity]
    seen_techniques: set[str] = set()
    seen_iocs: set[str] = set()

    def _add_technique(tech) -> None:
        # Technique entries may be plain strings or {"technique_id": ...} dicts.
        tid = tech if isinstance(tech, str) else tech.get("technique_id", str(tech))
        if tid not in seen_techniques:
            seen_techniques.add(tid)
            objects.append(_technique_to_attack_pattern(tid, identity["id"]))

    # Gather IOCs and triage techniques from the hunt's datasets.
    datasets_q = await db.execute(select(Dataset.id).where(Dataset.hunt_id == hunt_id))
    ds_ids = [r[0] for r in datasets_q.all()]

    if ds_ids:
        enrichments = (await db.execute(
            select(EnrichmentResult).where(EnrichmentResult.dataset_id.in_(ds_ids))
        )).scalars().all()
        for e in enrichments:
            key = f"{e.ioc_type}:{e.ioc_value}"
            if key not in seen_iocs:
                seen_iocs.add(key)
                objects.append(_ioc_to_indicator(e.ioc_value, e.ioc_type, identity["id"], e.verdict))

        triages = (await db.execute(
            select(TriageResult).where(TriageResult.dataset_id.in_(ds_ids))
        )).scalars().all()
        for t in triages:
            for tech in (t.mitre_techniques or []):
                _add_technique(tech)

    # Gather techniques from host profiles (keyed by hunt, not dataset).
    profiles = (await db.execute(
        select(HostProfile).where(HostProfile.hunt_id == hunt_id)
    )).scalars().all()
    for p in profiles:
        for tech in (p.mitre_techniques or []):
            _add_technique(tech)

    # Gather hypotheses as STIX reports, plus any technique they reference.
    hypos = (await db.execute(
        select(Hypothesis).where(Hypothesis.hunt_id == hunt_id)
    )).scalars().all()
    for h in hypos:
        objects.append(_hypothesis_to_report(h, identity["id"]))
        if h.mitre_technique:
            _add_technique(h.mitre_technique)

    bundle = {
        "type": "bundle",
        "id": _stix_id("bundle"),
        "objects": objects,
    }

    filename = f"threathunt-{hunt.name.replace(' ', '_')}-stix.json"
    return Response(
        content=json.dumps(bundle, indent=2),
        media_type="application/json",
        # BUG FIX: the computed filename was previously dropped from this header.
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
|
||||||
128
backend/app/api/routes/timeline.py
Normal file
128
backend/app/api/routes/timeline.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""API routes for forensic timeline visualization."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.db.models import Dataset, DatasetRow, Hunt
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/timeline", tags=["timeline"])
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_timestamp(val: str | None) -> str | None:
|
||||||
|
"""Try to parse a timestamp string, return ISO format or None."""
|
||||||
|
if not val:
|
||||||
|
return None
|
||||||
|
val = str(val).strip()
|
||||||
|
if not val:
|
||||||
|
return None
|
||||||
|
# Try common formats
|
||||||
|
for fmt in [
|
||||||
|
"%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ",
|
||||||
|
"%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S",
|
||||||
|
"%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S",
|
||||||
|
"%Y/%m/%d %H:%M:%S", "%m/%d/%Y %H:%M:%S",
|
||||||
|
]:
|
||||||
|
try:
|
||||||
|
return datetime.strptime(val, fmt).isoformat() + "Z"
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Columns likely to contain timestamps.
# Lower-cased column names are membership-tested against this set when
# picking time columns for a dataset (see get_hunt_timeline below).
TIME_COLUMNS = {
    "timestamp", "time", "datetime", "date", "created", "modified",
    "eventtime", "event_time", "start_time", "end_time",
    "lastmodified", "last_modified", "created_at", "updated_at",
    "mtime", "atime", "ctime", "btime",
    "timecreated", "timegenerated", "sourcetime",
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/hunt/{hunt_id}")
async def get_hunt_timeline(
    hunt_id: str,
    limit: int = 2000,
    db: AsyncSession = Depends(get_db),
):
    """Build a timeline of events across all datasets in a hunt.

    For each dataset, picks candidate time columns (normalized column names
    first, then raw schema columns whose names suggest time/date), parses a
    timestamp per row, and emits one event per row that yields one. The
    per-dataset row fetch is budgeted at ``limit // len(datasets)`` so the
    total stays near ``limit``. Returns events sorted by timestamp, capped
    at ``limit``, along with per-dataset summary info.
    """
    result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    hunt = result.scalar_one_or_none()
    if not hunt:
        raise HTTPException(status_code=404, detail="Hunt not found")

    result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
    datasets = result.scalars().all()
    if not datasets:
        # Nothing to plot — return an empty (but well-formed) payload.
        return {"hunt_id": hunt_id, "events": [], "datasets": []}

    events = []
    dataset_info = []

    for ds in datasets:
        # artifact_type may be absent on older Dataset rows; default to "Unknown".
        artifact_type = getattr(ds, "artifact_type", None) or "Unknown"
        dataset_info.append({
            "id": ds.id, "name": ds.name, "artifact_type": artifact_type,
            "row_count": ds.row_count,
        })

        # Find time columns for this dataset: prefer exact matches among the
        # normalized column names, then fall back to fuzzy name matching
        # ("time"/"date" substrings) over the raw schema.
        schema = ds.column_schema or {}
        time_cols = []
        for col in (ds.normalized_columns or {}).values():
            if col.lower() in TIME_COLUMNS:
                time_cols.append(col)
        if not time_cols:
            for col in schema:
                if col.lower() in TIME_COLUMNS or "time" in col.lower() or "date" in col.lower():
                    time_cols.append(col)
        if not time_cols:
            # No usable time column — this dataset contributes no events.
            continue

        # Fetch rows, splitting the overall limit evenly across datasets.
        rows_result = await db.execute(
            select(DatasetRow)
            .where(DatasetRow.dataset_id == ds.id)
            .order_by(DatasetRow.row_index)
            .limit(limit // max(len(datasets), 1))
        )
        for r in rows_result.scalars().all():
            # Prefer normalized data when available.
            data = r.normalized_data or r.data
            ts = None
            # First time column that parses wins.
            for tc in time_cols:
                ts = _parse_timestamp(data.get(tc))
                if ts:
                    break
            if ts:
                # Pull display fields, trying common raw/normalized key spellings.
                hostname = data.get("hostname") or data.get("Hostname") or data.get("Fqdn") or ""
                process = data.get("process_name") or data.get("Name") or data.get("ProcessName") or ""
                summary = data.get("command_line") or data.get("CommandLine") or data.get("Details") or ""
                events.append({
                    "timestamp": ts,
                    "dataset_id": ds.id,
                    "dataset_name": ds.name,
                    "artifact_type": artifact_type,
                    "row_index": r.row_index,
                    # Truncate free-text fields to keep the payload bounded.
                    "hostname": str(hostname)[:128],
                    "process": str(process)[:128],
                    "summary": str(summary)[:256],
                    "data": {k: str(v)[:100] for k, v in list(data.items())[:8]},
                })

    # Sort by timestamp (ISO strings sort chronologically).
    events.sort(key=lambda e: e["timestamp"])

    return {
        "hunt_id": hunt_id,
        "hunt_name": hunt.name,
        "event_count": len(events),
        "datasets": dataset_info,
        "events": events[:limit],
    }
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Application configuration — single source of truth for all settings.
|
"""Application configuration - single source of truth for all settings.
|
||||||
|
|
||||||
Loads from environment variables with sensible defaults for local dev.
|
Loads from environment variables with sensible defaults for local dev.
|
||||||
"""
|
"""
|
||||||
@@ -13,12 +13,12 @@ from pydantic import Field
|
|||||||
class AppConfig(BaseSettings):
|
class AppConfig(BaseSettings):
|
||||||
"""Central configuration for the entire ThreatHunt application."""
|
"""Central configuration for the entire ThreatHunt application."""
|
||||||
|
|
||||||
# ── General ────────────────────────────────────────────────────────
|
# -- General --------------------------------------------------------
|
||||||
APP_NAME: str = "ThreatHunt"
|
APP_NAME: str = "ThreatHunt"
|
||||||
APP_VERSION: str = "0.3.0"
|
APP_VERSION: str = "0.3.0"
|
||||||
DEBUG: bool = Field(default=False, description="Enable debug mode")
|
DEBUG: bool = Field(default=False, description="Enable debug mode")
|
||||||
|
|
||||||
# ── Database ───────────────────────────────────────────────────────
|
# -- Database -------------------------------------------------------
|
||||||
DATABASE_URL: str = Field(
|
DATABASE_URL: str = Field(
|
||||||
default="sqlite+aiosqlite:///./threathunt.db",
|
default="sqlite+aiosqlite:///./threathunt.db",
|
||||||
description="Async SQLAlchemy database URL. "
|
description="Async SQLAlchemy database URL. "
|
||||||
@@ -26,17 +26,17 @@ class AppConfig(BaseSettings):
|
|||||||
"postgresql+asyncpg://user:pass@host/db for production.",
|
"postgresql+asyncpg://user:pass@host/db for production.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── CORS ───────────────────────────────────────────────────────────
|
# -- CORS -----------------------------------------------------------
|
||||||
ALLOWED_ORIGINS: str = Field(
|
ALLOWED_ORIGINS: str = Field(
|
||||||
default="http://localhost:3000,http://localhost:8000",
|
default="http://localhost:3000,http://localhost:8000",
|
||||||
description="Comma-separated list of allowed CORS origins",
|
description="Comma-separated list of allowed CORS origins",
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── File uploads ───────────────────────────────────────────────────
|
# -- File uploads ---------------------------------------------------
|
||||||
MAX_UPLOAD_SIZE_MB: int = Field(default=500, description="Max CSV upload in MB")
|
MAX_UPLOAD_SIZE_MB: int = Field(default=500, description="Max CSV upload in MB")
|
||||||
UPLOAD_DIR: str = Field(default="./uploads", description="Directory for uploaded files")
|
UPLOAD_DIR: str = Field(default="./uploads", description="Directory for uploaded files")
|
||||||
|
|
||||||
# ── LLM Cluster — Wile & Roadrunner ────────────────────────────────
|
# -- LLM Cluster - Wile & Roadrunner --------------------------------
|
||||||
OPENWEBUI_URL: str = Field(
|
OPENWEBUI_URL: str = Field(
|
||||||
default="https://ai.guapo613.beer",
|
default="https://ai.guapo613.beer",
|
||||||
description="Open WebUI cluster endpoint (OpenAI-compatible API)",
|
description="Open WebUI cluster endpoint (OpenAI-compatible API)",
|
||||||
@@ -58,7 +58,7 @@ class AppConfig(BaseSettings):
|
|||||||
default=11434, description="Ollama port on Roadrunner"
|
default=11434, description="Ollama port on Roadrunner"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── LLM Routing defaults ──────────────────────────────────────────
|
# -- LLM Routing defaults ------------------------------------------
|
||||||
DEFAULT_FAST_MODEL: str = Field(
|
DEFAULT_FAST_MODEL: str = Field(
|
||||||
default="llama3.1:latest",
|
default="llama3.1:latest",
|
||||||
description="Default model for quick chat / simple queries",
|
description="Default model for quick chat / simple queries",
|
||||||
@@ -80,18 +80,18 @@ class AppConfig(BaseSettings):
|
|||||||
description="Default embedding model",
|
description="Default embedding model",
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── Agent behaviour ───────────────────────────────────────────────
|
# -- Agent behaviour ------------------------------------------------
|
||||||
AGENT_MAX_TOKENS: int = Field(default=2048, description="Max tokens per agent response")
|
AGENT_MAX_TOKENS: int = Field(default=2048, description="Max tokens per agent response")
|
||||||
AGENT_TEMPERATURE: float = Field(default=0.3, description="LLM temperature for guidance")
|
AGENT_TEMPERATURE: float = Field(default=0.3, description="LLM temperature for guidance")
|
||||||
AGENT_HISTORY_LENGTH: int = Field(default=10, description="Messages to keep in context")
|
AGENT_HISTORY_LENGTH: int = Field(default=10, description="Messages to keep in context")
|
||||||
FILTER_SENSITIVE_DATA: bool = Field(default=True, description="Redact sensitive patterns")
|
FILTER_SENSITIVE_DATA: bool = Field(default=True, description="Redact sensitive patterns")
|
||||||
|
|
||||||
# ── Enrichment API keys ───────────────────────────────────────────
|
# -- Enrichment API keys --------------------------------------------
|
||||||
VIRUSTOTAL_API_KEY: str = Field(default="", description="VirusTotal API key")
|
VIRUSTOTAL_API_KEY: str = Field(default="", description="VirusTotal API key")
|
||||||
ABUSEIPDB_API_KEY: str = Field(default="", description="AbuseIPDB API key")
|
ABUSEIPDB_API_KEY: str = Field(default="", description="AbuseIPDB API key")
|
||||||
SHODAN_API_KEY: str = Field(default="", description="Shodan API key")
|
SHODAN_API_KEY: str = Field(default="", description="Shodan API key")
|
||||||
|
|
||||||
# ── Auth ──────────────────────────────────────────────────────────
|
# -- Auth -----------------------------------------------------------
|
||||||
JWT_SECRET: str = Field(
|
JWT_SECRET: str = Field(
|
||||||
default="CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET",
|
default="CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET",
|
||||||
description="Secret for JWT signing",
|
description="Secret for JWT signing",
|
||||||
@@ -99,6 +99,73 @@ class AppConfig(BaseSettings):
|
|||||||
JWT_ACCESS_TOKEN_MINUTES: int = Field(default=60, description="Access token lifetime")
|
JWT_ACCESS_TOKEN_MINUTES: int = Field(default=60, description="Access token lifetime")
|
||||||
JWT_REFRESH_TOKEN_DAYS: int = Field(default=7, description="Refresh token lifetime")
|
JWT_REFRESH_TOKEN_DAYS: int = Field(default=7, description="Refresh token lifetime")
|
||||||
|
|
||||||
|
# -- Triage settings ------------------------------------------------
|
||||||
|
TRIAGE_BATCH_SIZE: int = Field(default=25, description="Rows per triage LLM batch")
|
||||||
|
TRIAGE_MAX_SUSPICIOUS_ROWS: int = Field(
|
||||||
|
default=200, description="Stop triage after this many suspicious rows"
|
||||||
|
)
|
||||||
|
TRIAGE_ESCALATION_THRESHOLD: float = Field(
|
||||||
|
default=5.0, description="Risk score threshold for escalation counting"
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- Host profiler settings -----------------------------------------
|
||||||
|
HOST_PROFILE_CONCURRENCY: int = Field(
|
||||||
|
default=3, description="Max concurrent host profile LLM calls"
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- Scanner settings -----------------------------------------------
|
||||||
|
SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch")
|
||||||
|
SCANNER_MAX_ROWS_PER_SCAN: int = Field(
|
||||||
|
default=120000,
|
||||||
|
description="Global row budget for a single AUP scan request (0 = unlimited)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- Job queue settings ----------------------------------------------
|
||||||
|
JOB_QUEUE_MAX_BACKLOG: int = Field(
|
||||||
|
default=2000, description="Soft cap for queued background jobs"
|
||||||
|
)
|
||||||
|
JOB_QUEUE_RETAIN_COMPLETED: int = Field(
|
||||||
|
default=3000, description="Maximum completed/failed jobs to retain in memory"
|
||||||
|
)
|
||||||
|
JOB_QUEUE_CLEANUP_INTERVAL_SECONDS: int = Field(
|
||||||
|
default=60, description="How often to run in-memory job cleanup"
|
||||||
|
)
|
||||||
|
JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(
|
||||||
|
default=3600, description="Age threshold for in-memory completed job cleanup"
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- Startup throttling ------------------------------------------------
|
||||||
|
STARTUP_WARMUP_MAX_HUNTS: int = Field(
|
||||||
|
default=5, description="Max hunts to warm inventory cache for at startup"
|
||||||
|
)
|
||||||
|
STARTUP_REPROCESS_MAX_DATASETS: int = Field(
|
||||||
|
default=25, description="Max unprocessed datasets to enqueue at startup"
|
||||||
|
)
|
||||||
|
STARTUP_RECONCILE_STALE_TASKS: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Mark stale queued/running processing tasks as failed on startup",
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- Network API scale guards -----------------------------------------
|
||||||
|
NETWORK_SUBGRAPH_MAX_HOSTS: int = Field(
|
||||||
|
default=400, description="Hard cap for hosts returned by network subgraph endpoint"
|
||||||
|
)
|
||||||
|
NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
|
||||||
|
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
|
||||||
|
)
|
||||||
|
NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
|
||||||
|
default=5000,
|
||||||
|
description="Row budget per dataset when building host inventory (0 = unlimited)",
|
||||||
|
)
|
||||||
|
NETWORK_INVENTORY_MAX_TOTAL_ROWS: int = Field(
|
||||||
|
default=120000,
|
||||||
|
description="Global row budget across all datasets for host inventory build (0 = unlimited)",
|
||||||
|
)
|
||||||
|
NETWORK_INVENTORY_MAX_CONNECTIONS: int = Field(
|
||||||
|
default=120000,
|
||||||
|
description="Max unique connection tuples retained during host inventory build",
|
||||||
|
)
|
||||||
|
|
||||||
model_config = {"env_prefix": "TH_", "env_file": ".env", "extra": "ignore"}
|
model_config = {"env_prefix": "TH_", "env_file": ".env", "extra": "ignore"}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -119,3 +186,4 @@ class AppConfig(BaseSettings):
|
|||||||
|
|
||||||
|
|
||||||
settings = AppConfig()
|
settings = AppConfig()
|
||||||
|
|
||||||
|
|||||||
@@ -21,9 +21,14 @@ _engine_kwargs: dict = dict(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if _is_sqlite:
|
if _is_sqlite:
|
||||||
_engine_kwargs["connect_args"] = {"timeout": 30}
|
_engine_kwargs["connect_args"] = {"timeout": 60, "check_same_thread": False}
|
||||||
_engine_kwargs["pool_size"] = 1
|
# NullPool: each session gets its own connection.
|
||||||
_engine_kwargs["max_overflow"] = 0
|
# Combined with WAL mode, this allows concurrent reads while a write is in progress.
|
||||||
|
from sqlalchemy.pool import NullPool
|
||||||
|
_engine_kwargs["poolclass"] = NullPool
|
||||||
|
else:
|
||||||
|
_engine_kwargs["pool_size"] = 5
|
||||||
|
_engine_kwargs["max_overflow"] = 10
|
||||||
|
|
||||||
engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)
|
engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)
|
||||||
|
|
||||||
@@ -34,7 +39,7 @@ def _set_sqlite_pragmas(dbapi_conn, connection_record):
|
|||||||
if _is_sqlite:
|
if _is_sqlite:
|
||||||
cursor = dbapi_conn.cursor()
|
cursor = dbapi_conn.cursor()
|
||||||
cursor.execute("PRAGMA journal_mode=WAL")
|
cursor.execute("PRAGMA journal_mode=WAL")
|
||||||
cursor.execute("PRAGMA busy_timeout=5000")
|
cursor.execute("PRAGMA busy_timeout=30000")
|
||||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
@@ -46,6 +51,10 @@ async_session_factory = async_sessionmaker(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Alias expected by other modules
|
||||||
|
async_session = async_session_factory
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
"""Base class for all ORM models."""
|
"""Base class for all ORM models."""
|
||||||
pass
|
pass
|
||||||
@@ -71,5 +80,5 @@ async def init_db() -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def dispose_db() -> None:
|
async def dispose_db() -> None:
|
||||||
"""Dispose of the engine connection pool."""
|
"""Dispose of the engine on shutdown."""
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
"""SQLAlchemy ORM models for ThreatHunt.
|
"""SQLAlchemy ORM models for ThreatHunt.
|
||||||
|
|
||||||
All persistent entities: datasets, hunts, conversations, annotations,
|
All persistent entities: datasets, hunts, conversations, annotations,
|
||||||
hypotheses, enrichment results, users, and AI analysis tables.
|
hypotheses, enrichment results, users, and AI analysis tables.
|
||||||
@@ -43,6 +43,7 @@ class User(Base):
|
|||||||
hashed_password: Mapped[str] = mapped_column(String(256), nullable=False)
|
hashed_password: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||||
role: Mapped[str] = mapped_column(String(16), default="analyst")
|
role: Mapped[str] = mapped_column(String(16), default="analyst")
|
||||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
|
display_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||||
|
|
||||||
hunts: Mapped[list["Hunt"]] = relationship(back_populates="owner", lazy="selectin")
|
hunts: Mapped[list["Hunt"]] = relationship(back_populates="owner", lazy="selectin")
|
||||||
@@ -400,3 +401,107 @@ class AnomalyResult(Base):
|
|||||||
is_outlier: Mapped[bool] = mapped_column(Boolean, default=False)
|
is_outlier: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
explanation: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
explanation: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||||
|
|
||||||
|
|
||||||
|
# -- Persistent Processing Tasks (Phase 2) ---
|
||||||
|
|
||||||
|
class ProcessingTask(Base):
    """Persistent record of a background processing job.

    NOTE(review): the job_id field suggests this mirrors an in-memory job
    queue so task state survives restarts — confirm against the queue module.
    """

    __tablename__ = "processing_tasks"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    # Optional owning hunt; the row is removed when the hunt is deleted.
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True, index=True
    )
    # Optional source dataset; the row is removed when the dataset is deleted.
    dataset_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True, index=True
    )
    # Correlation id for the corresponding queue job — presumably the
    # in-memory queue's identifier; TODO confirm.
    job_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
    # Name of the processing stage this task performs.
    stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    # Lifecycle state; new tasks start as "queued".
    status: Mapped[str] = mapped_column(String(20), default="queued", index=True)
    # Numeric progress value reported by the worker (starts at 0.0).
    progress: Mapped[float] = mapped_column(Float, default=0.0)
    # Human-readable status message.
    message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # Error detail populated when the task fails.
    error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    # Touched automatically on every UPDATE via onupdate.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # Composite indexes for "tasks for hunt/dataset at stage" lookups.
    __table_args__ = (
        Index("ix_processing_tasks_hunt_stage", "hunt_id", "stage"),
        Index("ix_processing_tasks_dataset_stage", "dataset_id", "stage"),
    )
||||||
|
|
||||||
|
|
||||||
|
# -- Playbook / Investigation Templates (Feature 3) ---
|
||||||
|
|
||||||
|
class Playbook(Base):
|
||||||
|
__tablename__ = "playbooks"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||||
|
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
|
||||||
|
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||||
|
created_by: Mapped[Optional[str]] = mapped_column(
|
||||||
|
String(32), ForeignKey("users.id"), nullable=True
|
||||||
|
)
|
||||||
|
is_template: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
hunt_id: Mapped[Optional[str]] = mapped_column(
|
||||||
|
String(32), ForeignKey("hunts.id"), nullable=True
|
||||||
|
)
|
||||||
|
status: Mapped[str] = mapped_column(String(20), default="active")
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||||
|
)
|
||||||
|
|
||||||
|
steps: Mapped[list["PlaybookStep"]] = relationship(
|
||||||
|
back_populates="playbook", lazy="selectin", cascade="all, delete-orphan",
|
||||||
|
order_by="PlaybookStep.order_index",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PlaybookStep(Base):
|
||||||
|
__tablename__ = "playbook_steps"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
playbook_id: Mapped[str] = mapped_column(
|
||||||
|
String(32), ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False
|
||||||
|
)
|
||||||
|
order_index: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
title: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||||
|
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||||
|
step_type: Mapped[str] = mapped_column(String(32), default="manual")
|
||||||
|
target_route: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
|
||||||
|
is_completed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||||
|
|
||||||
|
playbook: Mapped["Playbook"] = relationship(back_populates="steps")
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_playbook_steps_playbook", "playbook_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -- Saved Searches (Feature 5) ---
|
||||||
|
|
||||||
|
class SavedSearch(Base):
|
||||||
|
__tablename__ = "saved_searches"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||||
|
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
|
||||||
|
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||||
|
search_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
|
query_params: Mapped[dict] = mapped_column(JSON, nullable=False)
|
||||||
|
threshold: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||||
|
created_by: Mapped[Optional[str]] = mapped_column(
|
||||||
|
String(32), ForeignKey("users.id"), nullable=True
|
||||||
|
)
|
||||||
|
last_run_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
last_result_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("ix_saved_searches_type", "search_type"),
|
||||||
|
)
|
||||||
|
|||||||
@@ -25,6 +25,11 @@ from app.api.routes.auth import router as auth_router
|
|||||||
from app.api.routes.keywords import router as keywords_router
|
from app.api.routes.keywords import router as keywords_router
|
||||||
from app.api.routes.analysis import router as analysis_router
|
from app.api.routes.analysis import router as analysis_router
|
||||||
from app.api.routes.network import router as network_router
|
from app.api.routes.network import router as network_router
|
||||||
|
from app.api.routes.mitre import router as mitre_router
|
||||||
|
from app.api.routes.timeline import router as timeline_router
|
||||||
|
from app.api.routes.playbooks import router as playbooks_router
|
||||||
|
from app.api.routes.saved_searches import router as searches_router
|
||||||
|
from app.api.routes.stix_export import router as stix_router
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -47,13 +52,80 @@ async def lifespan(app: FastAPI):
|
|||||||
await seed_defaults(seed_db)
|
await seed_defaults(seed_db)
|
||||||
logger.info("AUP keyword defaults checked")
|
logger.info("AUP keyword defaults checked")
|
||||||
|
|
||||||
# Start job queue (Phase 10)
|
# Start job queue
|
||||||
from app.services.job_queue import job_queue, register_all_handlers
|
from app.services.job_queue import (
|
||||||
|
job_queue,
|
||||||
|
register_all_handlers,
|
||||||
|
reconcile_stale_processing_tasks,
|
||||||
|
JobType,
|
||||||
|
)
|
||||||
|
|
||||||
|
if settings.STARTUP_RECONCILE_STALE_TASKS:
|
||||||
|
reconciled = await reconcile_stale_processing_tasks()
|
||||||
|
if reconciled:
|
||||||
|
logger.info("Startup reconciliation marked %d stale tasks", reconciled)
|
||||||
|
|
||||||
register_all_handlers()
|
register_all_handlers()
|
||||||
await job_queue.start()
|
await job_queue.start()
|
||||||
logger.info("Job queue started (%d workers)", job_queue._max_workers)
|
logger.info("Job queue started (%d workers)", job_queue._max_workers)
|
||||||
|
|
||||||
# Start load balancer health loop (Phase 10)
|
# Pre-warm host inventory cache for existing hunts
|
||||||
|
from app.services.host_inventory import inventory_cache
|
||||||
|
async with async_session_factory() as warm_db:
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from app.db.models import Hunt, Dataset
|
||||||
|
stmt = (
|
||||||
|
select(Hunt.id)
|
||||||
|
.join(Dataset, Dataset.hunt_id == Hunt.id)
|
||||||
|
.group_by(Hunt.id)
|
||||||
|
.having(func.count(Dataset.id) > 0)
|
||||||
|
)
|
||||||
|
result = await warm_db.execute(stmt)
|
||||||
|
hunt_ids = [row[0] for row in result.all()]
|
||||||
|
warm_hunts = hunt_ids[: settings.STARTUP_WARMUP_MAX_HUNTS]
|
||||||
|
for hid in warm_hunts:
|
||||||
|
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)
|
||||||
|
if warm_hunts:
|
||||||
|
logger.info(f"Queued host inventory warm-up for {len(warm_hunts)} hunts (total hunts with data: {len(hunt_ids)})")
|
||||||
|
|
||||||
|
# Check which datasets still need processing
|
||||||
|
# (no anomaly results = never fully processed)
|
||||||
|
async with async_session_factory() as reprocess_db:
|
||||||
|
from sqlalchemy import select, exists
|
||||||
|
from app.db.models import Dataset, AnomalyResult
|
||||||
|
# Find datasets that have zero anomaly results (pipeline never ran or failed)
|
||||||
|
has_anomaly = (
|
||||||
|
select(AnomalyResult.id)
|
||||||
|
.where(AnomalyResult.dataset_id == Dataset.id)
|
||||||
|
.limit(1)
|
||||||
|
.correlate(Dataset)
|
||||||
|
.exists()
|
||||||
|
)
|
||||||
|
stmt = select(Dataset.id).where(~has_anomaly)
|
||||||
|
result = await reprocess_db.execute(stmt)
|
||||||
|
unprocessed_ids = [row[0] for row in result.all()]
|
||||||
|
|
||||||
|
if unprocessed_ids:
|
||||||
|
to_reprocess = unprocessed_ids[: settings.STARTUP_REPROCESS_MAX_DATASETS]
|
||||||
|
for ds_id in to_reprocess:
|
||||||
|
job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)
|
||||||
|
job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)
|
||||||
|
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)
|
||||||
|
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)
|
||||||
|
logger.info(f"Queued processing pipeline for {len(to_reprocess)} datasets at startup (unprocessed total: {len(unprocessed_ids)})")
|
||||||
|
async with async_session_factory() as update_db:
|
||||||
|
from sqlalchemy import update
|
||||||
|
from app.db.models import Dataset
|
||||||
|
await update_db.execute(
|
||||||
|
update(Dataset)
|
||||||
|
.where(Dataset.id.in_(to_reprocess))
|
||||||
|
.values(processing_status="processing")
|
||||||
|
)
|
||||||
|
await update_db.commit()
|
||||||
|
else:
|
||||||
|
logger.info("All datasets already processed - skipping startup pipeline")
|
||||||
|
|
||||||
|
# Start load balancer health loop
|
||||||
from app.services.load_balancer import lb
|
from app.services.load_balancer import lb
|
||||||
await lb.start_health_loop(interval=30.0)
|
await lb.start_health_loop(interval=30.0)
|
||||||
logger.info("Load balancer health loop started")
|
logger.info("Load balancer health loop started")
|
||||||
@@ -61,12 +133,10 @@ async def lifespan(app: FastAPI):
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
logger.info("Shutting down ...")
|
logger.info("Shutting down ...")
|
||||||
# Stop job queue
|
|
||||||
from app.services.job_queue import job_queue as jq
|
from app.services.job_queue import job_queue as jq
|
||||||
await jq.stop()
|
await jq.stop()
|
||||||
logger.info("Job queue stopped")
|
logger.info("Job queue stopped")
|
||||||
|
|
||||||
# Stop load balancer
|
|
||||||
from app.services.load_balancer import lb as _lb
|
from app.services.load_balancer import lb as _lb
|
||||||
await _lb.stop_health_loop()
|
await _lb.stop_health_loop()
|
||||||
logger.info("Load balancer stopped")
|
logger.info("Load balancer stopped")
|
||||||
@@ -106,6 +176,11 @@ app.include_router(reports_router)
|
|||||||
app.include_router(keywords_router)
|
app.include_router(keywords_router)
|
||||||
app.include_router(analysis_router)
|
app.include_router(analysis_router)
|
||||||
app.include_router(network_router)
|
app.include_router(network_router)
|
||||||
|
app.include_router(mitre_router)
|
||||||
|
app.include_router(timeline_router)
|
||||||
|
app.include_router(playbooks_router)
|
||||||
|
app.include_router(searches_router)
|
||||||
|
app.include_router(stix_router)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/", tags=["health"])
|
@app.get("/", tags=["health"])
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from sqlalchemy import select, func
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db.models import Dataset, DatasetRow
|
from app.db.models import Dataset, DatasetRow
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -79,6 +80,55 @@ def _extract_username(raw: str) -> str:
|
|||||||
return name or ''
|
return name or ''
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# In-memory host inventory cache
|
||||||
|
# Pre-computed results stored per hunt_id, built in background after upload.
|
||||||
|
|
||||||
|
import time as _time
|
||||||
|
|
||||||
|
class _InventoryCache:
|
||||||
|
"""Simple in-memory cache for pre-computed host inventories."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._data: dict[str, dict] = {} # hunt_id -> result dict
|
||||||
|
self._timestamps: dict[str, float] = {} # hunt_id -> epoch
|
||||||
|
self._building: set[str] = set() # hunt_ids currently being built
|
||||||
|
|
||||||
|
def get(self, hunt_id: str) -> dict | None:
|
||||||
|
"""Return cached result if present. Never expires; only invalidated on new upload."""
|
||||||
|
return self._data.get(hunt_id)
|
||||||
|
|
||||||
|
def put(self, hunt_id: str, result: dict):
|
||||||
|
self._data[hunt_id] = result
|
||||||
|
self._timestamps[hunt_id] = _time.time()
|
||||||
|
self._building.discard(hunt_id)
|
||||||
|
logger.info(f"Cached host inventory for hunt {hunt_id} "
|
||||||
|
f"({result['stats']['total_hosts']} hosts)")
|
||||||
|
|
||||||
|
def invalidate(self, hunt_id: str):
|
||||||
|
self._data.pop(hunt_id, None)
|
||||||
|
self._timestamps.pop(hunt_id, None)
|
||||||
|
|
||||||
|
def is_building(self, hunt_id: str) -> bool:
|
||||||
|
return hunt_id in self._building
|
||||||
|
|
||||||
|
def set_building(self, hunt_id: str):
|
||||||
|
self._building.add(hunt_id)
|
||||||
|
|
||||||
|
def clear_building(self, hunt_id: str):
|
||||||
|
self._building.discard(hunt_id)
|
||||||
|
|
||||||
|
def status(self, hunt_id: str) -> str:
|
||||||
|
if hunt_id in self._building:
|
||||||
|
return "building"
|
||||||
|
if hunt_id in self._data:
|
||||||
|
return "ready"
|
||||||
|
return "none"
|
||||||
|
|
||||||
|
|
||||||
|
inventory_cache = _InventoryCache()
|
||||||
|
|
||||||
def _infer_os(fqdn: str) -> str:
|
def _infer_os(fqdn: str) -> str:
|
||||||
u = fqdn.upper()
|
u = fqdn.upper()
|
||||||
if 'W10-' in u or 'WIN10' in u:
|
if 'W10-' in u or 'WIN10' in u:
|
||||||
@@ -155,29 +205,57 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
|||||||
connections: dict[tuple, int] = defaultdict(int)
|
connections: dict[tuple, int] = defaultdict(int)
|
||||||
total_rows = 0
|
total_rows = 0
|
||||||
ds_with_hosts = 0
|
ds_with_hosts = 0
|
||||||
|
sampled_dataset_count = 0
|
||||||
|
total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
|
||||||
|
max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
|
||||||
|
global_budget_reached = False
|
||||||
|
dropped_connections = 0
|
||||||
|
|
||||||
for ds in all_datasets:
|
for ds in all_datasets:
|
||||||
|
if total_row_budget and total_rows >= total_row_budget:
|
||||||
|
global_budget_reached = True
|
||||||
|
break
|
||||||
|
|
||||||
cols = _identify_columns(ds)
|
cols = _identify_columns(ds)
|
||||||
if not cols['fqdn'] and not cols['host_id']:
|
if not cols['fqdn'] and not cols['host_id']:
|
||||||
continue
|
continue
|
||||||
ds_with_hosts += 1
|
ds_with_hosts += 1
|
||||||
|
|
||||||
batch_size = 5000
|
batch_size = 5000
|
||||||
offset = 0
|
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
|
||||||
|
rows_scanned_this_dataset = 0
|
||||||
|
sampled_dataset = False
|
||||||
|
last_row_index = -1
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
if total_row_budget and total_rows >= total_row_budget:
|
||||||
|
sampled_dataset = True
|
||||||
|
global_budget_reached = True
|
||||||
|
break
|
||||||
|
|
||||||
rr = await db.execute(
|
rr = await db.execute(
|
||||||
select(DatasetRow)
|
select(DatasetRow)
|
||||||
.where(DatasetRow.dataset_id == ds.id)
|
.where(DatasetRow.dataset_id == ds.id)
|
||||||
|
.where(DatasetRow.row_index > last_row_index)
|
||||||
.order_by(DatasetRow.row_index)
|
.order_by(DatasetRow.row_index)
|
||||||
.offset(offset).limit(batch_size)
|
.limit(batch_size)
|
||||||
)
|
)
|
||||||
rows = rr.scalars().all()
|
rows = rr.scalars().all()
|
||||||
if not rows:
|
if not rows:
|
||||||
break
|
break
|
||||||
|
|
||||||
for ro in rows:
|
for ro in rows:
|
||||||
|
if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
|
||||||
|
sampled_dataset = True
|
||||||
|
break
|
||||||
|
if total_row_budget and total_rows >= total_row_budget:
|
||||||
|
sampled_dataset = True
|
||||||
|
global_budget_reached = True
|
||||||
|
break
|
||||||
|
|
||||||
data = ro.data or {}
|
data = ro.data or {}
|
||||||
total_rows += 1
|
total_rows += 1
|
||||||
|
rows_scanned_this_dataset += 1
|
||||||
|
|
||||||
fqdn = ''
|
fqdn = ''
|
||||||
for c in cols['fqdn']:
|
for c in cols['fqdn']:
|
||||||
@@ -239,12 +317,33 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
|||||||
rport = _clean(data.get(pc))
|
rport = _clean(data.get(pc))
|
||||||
if rport:
|
if rport:
|
||||||
break
|
break
|
||||||
connections[(host_key, rip, rport)] += 1
|
conn_key = (host_key, rip, rport)
|
||||||
|
if max_connections and len(connections) >= max_connections and conn_key not in connections:
|
||||||
|
dropped_connections += 1
|
||||||
|
continue
|
||||||
|
connections[conn_key] += 1
|
||||||
|
|
||||||
offset += batch_size
|
if sampled_dataset:
|
||||||
|
sampled_dataset_count += 1
|
||||||
|
logger.info(
|
||||||
|
"Host inventory sampling for dataset %s (%d rows scanned)",
|
||||||
|
ds.id,
|
||||||
|
rows_scanned_this_dataset,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
last_row_index = rows[-1].row_index
|
||||||
if len(rows) < batch_size:
|
if len(rows) < batch_size:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if global_budget_reached:
|
||||||
|
logger.info(
|
||||||
|
"Host inventory global row budget reached for hunt %s at %d rows",
|
||||||
|
hunt_id,
|
||||||
|
total_rows,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
# Post-process hosts
|
# Post-process hosts
|
||||||
for h in hosts.values():
|
for h in hosts.values():
|
||||||
if not h['os'] and h['fqdn']:
|
if not h['os'] and h['fqdn']:
|
||||||
@@ -286,5 +385,12 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
|||||||
"total_rows_scanned": total_rows,
|
"total_rows_scanned": total_rows,
|
||||||
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||||
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||||
|
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||||
|
"row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
|
||||||
|
"connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
|
||||||
|
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
|
||||||
|
"sampled_datasets": sampled_dataset_count,
|
||||||
|
"global_budget_reached": global_budget_reached,
|
||||||
|
"dropped_connections": dropped_connections,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import re
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -18,6 +19,9 @@ logger = logging.getLogger(__name__)
|
|||||||
HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
|
HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
|
||||||
WILE_URL = f"{settings.wile_url}/api/generate"
|
WILE_URL = f"{settings.wile_url}/api/generate"
|
||||||
|
|
||||||
|
# Velociraptor client IDs (C.hex) are not real hostnames
|
||||||
|
CLIENTID_RE = re.compile(r"^C\.[0-9a-fA-F]{8,}$")
|
||||||
|
|
||||||
|
|
||||||
async def _get_triage_summary(db, dataset_id: str) -> str:
|
async def _get_triage_summary(db, dataset_id: str) -> str:
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
@@ -154,7 +158,7 @@ async def profile_host(
|
|||||||
logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level)
|
logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to profile host %s: %s", hostname, e)
|
logger.error("Failed to profile host %s: %r", hostname, e)
|
||||||
profile = HostProfile(
|
profile = HostProfile(
|
||||||
hunt_id=hunt_id, hostname=hostname, fqdn=fqdn,
|
hunt_id=hunt_id, hostname=hostname, fqdn=fqdn,
|
||||||
risk_score=0.0, risk_level="unknown",
|
risk_score=0.0, risk_level="unknown",
|
||||||
@@ -185,6 +189,13 @@ async def profile_all_hosts(hunt_id: str) -> None:
|
|||||||
if h not in hostnames:
|
if h not in hostnames:
|
||||||
hostnames[h] = data.get("fqdn") or data.get("Fqdn")
|
hostnames[h] = data.get("fqdn") or data.get("Fqdn")
|
||||||
|
|
||||||
|
# Filter out Velociraptor client IDs - not real hostnames
|
||||||
|
real_hosts = {h: f for h, f in hostnames.items() if not CLIENTID_RE.match(h)}
|
||||||
|
skipped = len(hostnames) - len(real_hosts)
|
||||||
|
if skipped:
|
||||||
|
logger.info("Skipped %d Velociraptor client IDs", skipped)
|
||||||
|
hostnames = real_hosts
|
||||||
|
|
||||||
logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id)
|
logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id)
|
||||||
|
|
||||||
semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY)
|
semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY)
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Async job queue for background AI tasks.
|
"""Async job queue for background AI tasks.
|
||||||
|
|
||||||
Manages triage, profiling, report generation, anomaly detection,
|
Manages triage, profiling, report generation, anomaly detection,
|
||||||
and data queries as trackable jobs with status, progress, and
|
keyword scanning, IOC extraction, and data queries as trackable
|
||||||
cancellation support.
|
jobs with status, progress, and cancellation support.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -15,6 +15,8 @@ from dataclasses import dataclass, field
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Coroutine, Optional
|
from typing import Any, Callable, Coroutine, Optional
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -32,6 +34,18 @@ class JobType(str, Enum):
|
|||||||
REPORT = "report"
|
REPORT = "report"
|
||||||
ANOMALY = "anomaly"
|
ANOMALY = "anomaly"
|
||||||
QUERY = "query"
|
QUERY = "query"
|
||||||
|
HOST_INVENTORY = "host_inventory"
|
||||||
|
KEYWORD_SCAN = "keyword_scan"
|
||||||
|
IOC_EXTRACT = "ioc_extract"
|
||||||
|
|
||||||
|
|
||||||
|
# Job types that form the automatic upload pipeline
|
||||||
|
PIPELINE_JOB_TYPES = frozenset({
|
||||||
|
JobType.TRIAGE,
|
||||||
|
JobType.ANOMALY,
|
||||||
|
JobType.KEYWORD_SCAN,
|
||||||
|
JobType.IOC_EXTRACT,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -82,11 +96,7 @@ class Job:
|
|||||||
|
|
||||||
|
|
||||||
class JobQueue:
|
class JobQueue:
|
||||||
"""In-memory async job queue with concurrency control.
|
"""In-memory async job queue with concurrency control."""
|
||||||
|
|
||||||
Jobs are tracked by ID and can be listed, polled, or cancelled.
|
|
||||||
A configurable number of workers process jobs from the queue.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, max_workers: int = 3):
|
def __init__(self, max_workers: int = 3):
|
||||||
self._jobs: dict[str, Job] = {}
|
self._jobs: dict[str, Job] = {}
|
||||||
@@ -95,47 +105,56 @@ class JobQueue:
|
|||||||
self._workers: list[asyncio.Task] = []
|
self._workers: list[asyncio.Task] = []
|
||||||
self._handlers: dict[JobType, Callable] = {}
|
self._handlers: dict[JobType, Callable] = {}
|
||||||
self._started = False
|
self._started = False
|
||||||
|
self._completion_callbacks: list[Callable[[Job], Coroutine]] = []
|
||||||
|
self._cleanup_task: asyncio.Task | None = None
|
||||||
|
|
||||||
def register_handler(
|
def register_handler(self, job_type: JobType, handler: Callable[[Job], Coroutine]):
|
||||||
self,
|
|
||||||
job_type: JobType,
|
|
||||||
handler: Callable[[Job], Coroutine],
|
|
||||||
):
|
|
||||||
"""Register an async handler for a job type.
|
|
||||||
|
|
||||||
Handler signature: async def handler(job: Job) -> Any
|
|
||||||
The handler can update job.progress and job.message during execution.
|
|
||||||
It should check job.is_cancelled periodically and return early.
|
|
||||||
"""
|
|
||||||
self._handlers[job_type] = handler
|
self._handlers[job_type] = handler
|
||||||
logger.info(f"Registered handler for {job_type.value}")
|
logger.info(f"Registered handler for {job_type.value}")
|
||||||
|
|
||||||
|
def on_completion(self, callback: Callable[[Job], Coroutine]):
|
||||||
|
"""Register a callback invoked after any job completes or fails."""
|
||||||
|
self._completion_callbacks.append(callback)
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""Start worker tasks."""
|
|
||||||
if self._started:
|
if self._started:
|
||||||
return
|
return
|
||||||
self._started = True
|
self._started = True
|
||||||
for i in range(self._max_workers):
|
for i in range(self._max_workers):
|
||||||
task = asyncio.create_task(self._worker(i))
|
task = asyncio.create_task(self._worker(i))
|
||||||
self._workers.append(task)
|
self._workers.append(task)
|
||||||
|
if not self._cleanup_task or self._cleanup_task.done():
|
||||||
|
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||||
logger.info(f"Job queue started with {self._max_workers} workers")
|
logger.info(f"Job queue started with {self._max_workers} workers")
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
"""Stop all workers."""
|
|
||||||
self._started = False
|
self._started = False
|
||||||
for w in self._workers:
|
for w in self._workers:
|
||||||
w.cancel()
|
w.cancel()
|
||||||
await asyncio.gather(*self._workers, return_exceptions=True)
|
await asyncio.gather(*self._workers, return_exceptions=True)
|
||||||
self._workers.clear()
|
self._workers.clear()
|
||||||
|
if self._cleanup_task:
|
||||||
|
self._cleanup_task.cancel()
|
||||||
|
await asyncio.gather(self._cleanup_task, return_exceptions=True)
|
||||||
|
self._cleanup_task = None
|
||||||
logger.info("Job queue stopped")
|
logger.info("Job queue stopped")
|
||||||
|
|
||||||
def submit(self, job_type: JobType, **params) -> Job:
|
def submit(self, job_type: JobType, **params) -> Job:
|
||||||
"""Submit a new job. Returns the Job object immediately."""
|
# Soft backpressure: prefer dedupe over queue amplification
|
||||||
job = Job(
|
dedupe_job = self._find_active_duplicate(job_type, params)
|
||||||
id=str(uuid.uuid4()),
|
if dedupe_job is not None:
|
||||||
job_type=job_type,
|
logger.info(
|
||||||
params=params,
|
f"Job deduped: reusing {dedupe_job.id} ({job_type.value}) params={params}"
|
||||||
)
|
)
|
||||||
|
return dedupe_job
|
||||||
|
|
||||||
|
if self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG:
|
||||||
|
logger.warning(
|
||||||
|
"Job queue backlog high (%d >= %d). Accepting job but system may be degraded.",
|
||||||
|
self._queue.qsize(), settings.JOB_QUEUE_MAX_BACKLOG,
|
||||||
|
)
|
||||||
|
|
||||||
|
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
|
||||||
self._jobs[job.id] = job
|
self._jobs[job.id] = job
|
||||||
self._queue.put_nowait(job.id)
|
self._queue.put_nowait(job.id)
|
||||||
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
|
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
|
||||||
@@ -144,6 +163,22 @@ class JobQueue:
|
|||||||
def get_job(self, job_id: str) -> Job | None:
|
def get_job(self, job_id: str) -> Job | None:
|
||||||
return self._jobs.get(job_id)
|
return self._jobs.get(job_id)
|
||||||
|
|
||||||
|
def _find_active_duplicate(self, job_type: JobType, params: dict) -> Job | None:
|
||||||
|
"""Return queued/running job with same key workload to prevent duplicate storms."""
|
||||||
|
key_fields = ["dataset_id", "hunt_id", "hostname", "question", "mode"]
|
||||||
|
sig = tuple((k, params.get(k)) for k in key_fields if params.get(k) is not None)
|
||||||
|
if not sig:
|
||||||
|
return None
|
||||||
|
for j in self._jobs.values():
|
||||||
|
if j.job_type != job_type:
|
||||||
|
continue
|
||||||
|
if j.status not in (JobStatus.QUEUED, JobStatus.RUNNING):
|
||||||
|
continue
|
||||||
|
other_sig = tuple((k, j.params.get(k)) for k in key_fields if j.params.get(k) is not None)
|
||||||
|
if sig == other_sig:
|
||||||
|
return j
|
||||||
|
return None
|
||||||
|
|
||||||
def cancel_job(self, job_id: str) -> bool:
|
def cancel_job(self, job_id: str) -> bool:
|
||||||
job = self._jobs.get(job_id)
|
job = self._jobs.get(job_id)
|
||||||
if not job:
|
if not job:
|
||||||
@@ -153,13 +188,7 @@ class JobQueue:
|
|||||||
job.cancel()
|
job.cancel()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def list_jobs(
|
def list_jobs(self, status=None, job_type=None, limit=50) -> list[dict]:
|
||||||
self,
|
|
||||||
status: JobStatus | None = None,
|
|
||||||
job_type: JobType | None = None,
|
|
||||||
limit: int = 50,
|
|
||||||
) -> list[dict]:
|
|
||||||
"""List jobs, newest first."""
|
|
||||||
jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
|
jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
|
||||||
if status:
|
if status:
|
||||||
jobs = [j for j in jobs if j.status == status]
|
jobs = [j for j in jobs if j.status == status]
|
||||||
@@ -168,7 +197,6 @@ class JobQueue:
|
|||||||
return [j.to_dict() for j in jobs[:limit]]
|
return [j.to_dict() for j in jobs[:limit]]
|
||||||
|
|
||||||
def get_stats(self) -> dict:
|
def get_stats(self) -> dict:
|
||||||
"""Get queue statistics."""
|
|
||||||
by_status = {}
|
by_status = {}
|
||||||
for j in self._jobs.values():
|
for j in self._jobs.values():
|
||||||
by_status[j.status.value] = by_status.get(j.status.value, 0) + 1
|
by_status[j.status.value] = by_status.get(j.status.value, 0) + 1
|
||||||
@@ -177,26 +205,58 @@ class JobQueue:
|
|||||||
"queued": self._queue.qsize(),
|
"queued": self._queue.qsize(),
|
||||||
"by_status": by_status,
|
"by_status": by_status,
|
||||||
"workers": self._max_workers,
|
"workers": self._max_workers,
|
||||||
"active_workers": sum(
|
"active_workers": sum(1 for j in self._jobs.values() if j.status == JobStatus.RUNNING),
|
||||||
1 for j in self._jobs.values() if j.status == JobStatus.RUNNING
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def is_backlogged(self) -> bool:
|
||||||
|
return self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG
|
||||||
|
|
||||||
|
def can_accept(self, reserve: int = 0) -> bool:
|
||||||
|
return (self._queue.qsize() + max(0, reserve)) < settings.JOB_QUEUE_MAX_BACKLOG
|
||||||
|
|
||||||
def cleanup(self, max_age_seconds: float = 3600):
|
def cleanup(self, max_age_seconds: float = 3600):
|
||||||
"""Remove old completed/failed/cancelled jobs."""
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
terminal_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||||
to_remove = [
|
to_remove = [
|
||||||
jid for jid, j in self._jobs.items()
|
jid for jid, j in self._jobs.items()
|
||||||
if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
if j.status in terminal_states and (now - j.created_at) > max_age_seconds
|
||||||
and (now - j.created_at) > max_age_seconds
|
|
||||||
]
|
]
|
||||||
for jid in to_remove:
|
|
||||||
|
# Also cap retained terminal jobs to avoid unbounded memory growth
|
||||||
|
terminal_jobs = sorted(
|
||||||
|
[j for j in self._jobs.values() if j.status in terminal_states],
|
||||||
|
key=lambda j: j.created_at,
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
overflow = terminal_jobs[settings.JOB_QUEUE_RETAIN_COMPLETED :]
|
||||||
|
to_remove.extend([j.id for j in overflow])
|
||||||
|
|
||||||
|
removed = 0
|
||||||
|
for jid in set(to_remove):
|
||||||
|
if jid in self._jobs:
|
||||||
del self._jobs[jid]
|
del self._jobs[jid]
|
||||||
if to_remove:
|
removed += 1
|
||||||
logger.info(f"Cleaned up {len(to_remove)} old jobs")
|
if removed:
|
||||||
|
logger.info(f"Cleaned up {removed} old jobs")
|
||||||
|
|
||||||
|
async def _cleanup_loop(self):
|
||||||
|
interval = max(10, settings.JOB_QUEUE_CLEANUP_INTERVAL_SECONDS)
|
||||||
|
while self._started:
|
||||||
|
try:
|
||||||
|
self.cleanup(max_age_seconds=settings.JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Job queue cleanup loop error: {e}")
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
|
||||||
|
def find_pipeline_jobs(self, dataset_id: str) -> list[Job]:
|
||||||
|
"""Find all pipeline jobs for a given dataset_id."""
|
||||||
|
return [
|
||||||
|
j for j in self._jobs.values()
|
||||||
|
if j.job_type in PIPELINE_JOB_TYPES
|
||||||
|
and j.params.get("dataset_id") == dataset_id
|
||||||
|
]
|
||||||
|
|
||||||
async def _worker(self, worker_id: int):
|
async def _worker(self, worker_id: int):
|
||||||
"""Worker loop: pull jobs from queue and execute handlers."""
|
|
||||||
logger.info(f"Worker {worker_id} started")
|
logger.info(f"Worker {worker_id} started")
|
||||||
while self._started:
|
while self._started:
|
||||||
try:
|
try:
|
||||||
@@ -220,7 +280,10 @@ class JobQueue:
|
|||||||
|
|
||||||
job.status = JobStatus.RUNNING
|
job.status = JobStatus.RUNNING
|
||||||
job.started_at = time.time()
|
job.started_at = time.time()
|
||||||
|
if job.progress <= 0:
|
||||||
|
job.progress = 5.0
|
||||||
job.message = "Running..."
|
job.message = "Running..."
|
||||||
|
await _sync_processing_task(job)
|
||||||
logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
|
logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -231,38 +294,111 @@ class JobQueue:
|
|||||||
job.result = result
|
job.result = result
|
||||||
job.message = "Completed"
|
job.message = "Completed"
|
||||||
job.completed_at = time.time()
|
job.completed_at = time.time()
|
||||||
logger.info(
|
logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
|
||||||
f"Worker {worker_id}: completed {job.id} "
|
|
||||||
f"in {job.elapsed_ms}ms"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if not job.is_cancelled:
|
if not job.is_cancelled:
|
||||||
job.status = JobStatus.FAILED
|
job.status = JobStatus.FAILED
|
||||||
job.error = str(e)
|
job.error = str(e)
|
||||||
job.message = f"Failed: {e}"
|
job.message = f"Failed: {e}"
|
||||||
job.completed_at = time.time()
|
job.completed_at = time.time()
|
||||||
logger.error(
|
logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
|
||||||
f"Worker {worker_id}: failed {job.id}: {e}",
|
|
||||||
exc_info=True,
|
if job.is_cancelled and not job.completed_at:
|
||||||
|
job.completed_at = time.time()
|
||||||
|
|
||||||
|
await _sync_processing_task(job)
|
||||||
|
|
||||||
|
# Fire completion callbacks
|
||||||
|
for cb in self._completion_callbacks:
|
||||||
|
try:
|
||||||
|
await cb(job)
|
||||||
|
except Exception as cb_err:
|
||||||
|
logger.error(f"Completion callback error: {cb_err}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def _sync_processing_task(job: Job):
|
||||||
|
"""Persist latest job state into processing_tasks (if linked by job_id)."""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.db import async_session_factory
|
||||||
|
from app.db.models import ProcessingTask
|
||||||
|
|
||||||
|
values = {
|
||||||
|
"status": job.status.value,
|
||||||
|
"progress": float(job.progress),
|
||||||
|
"message": job.message,
|
||||||
|
"error": job.error,
|
||||||
|
}
|
||||||
|
if job.started_at:
|
||||||
|
values["started_at"] = datetime.fromtimestamp(job.started_at, tz=timezone.utc)
|
||||||
|
if job.completed_at:
|
||||||
|
values["completed_at"] = datetime.fromtimestamp(job.completed_at, tz=timezone.utc)
|
||||||
|
|
||||||
|
async with async_session_factory() as db:
|
||||||
|
await db.execute(
|
||||||
|
update(ProcessingTask)
|
||||||
|
.where(ProcessingTask.job_id == job.id)
|
||||||
|
.values(**values)
|
||||||
)
|
)
|
||||||
|
await db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to sync processing task for job {job.id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
# Singleton + job handlers
|
# -- Singleton + job handlers --
|
||||||
|
|
||||||
job_queue = JobQueue(max_workers=3)
|
job_queue = JobQueue(max_workers=5)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_triage(job: Job):
|
async def _handle_triage(job: Job):
|
||||||
"""Triage handler."""
|
"""Triage handler - chains HOST_PROFILE after completion."""
|
||||||
from app.services.triage import triage_dataset
|
from app.services.triage import triage_dataset
|
||||||
dataset_id = job.params.get("dataset_id")
|
dataset_id = job.params.get("dataset_id")
|
||||||
job.message = f"Triaging dataset {dataset_id}"
|
job.message = f"Triaging dataset {dataset_id}"
|
||||||
results = await triage_dataset(dataset_id)
|
await triage_dataset(dataset_id)
|
||||||
return {"count": len(results) if results else 0}
|
|
||||||
|
# Chain: trigger host profiling now that triage results exist
|
||||||
|
from app.db import async_session_factory
|
||||||
|
from app.db.models import Dataset
|
||||||
|
from sqlalchemy import select
|
||||||
|
try:
|
||||||
|
async with async_session_factory() as db:
|
||||||
|
ds = await db.execute(select(Dataset.hunt_id).where(Dataset.id == dataset_id))
|
||||||
|
row = ds.first()
|
||||||
|
hunt_id = row[0] if row else None
|
||||||
|
if hunt_id:
|
||||||
|
hp_job = job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
|
||||||
|
try:
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.db.models import ProcessingTask
|
||||||
|
async with async_session_factory() as db:
|
||||||
|
existing = await db.execute(
|
||||||
|
select(ProcessingTask.id).where(ProcessingTask.job_id == hp_job.id)
|
||||||
|
)
|
||||||
|
if existing.first() is None:
|
||||||
|
db.add(ProcessingTask(
|
||||||
|
hunt_id=hunt_id,
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
job_id=hp_job.id,
|
||||||
|
stage="host_profile",
|
||||||
|
status="queued",
|
||||||
|
progress=0.0,
|
||||||
|
message="Queued",
|
||||||
|
))
|
||||||
|
await db.commit()
|
||||||
|
except Exception as persist_err:
|
||||||
|
logger.warning(f"Failed to persist chained HOST_PROFILE task: {persist_err}")
|
||||||
|
|
||||||
|
logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to chain host profile after triage: {e}")
|
||||||
|
|
||||||
|
return {"dataset_id": dataset_id}
|
||||||
|
|
||||||
|
|
||||||
async def _handle_host_profile(job: Job):
|
async def _handle_host_profile(job: Job):
|
||||||
"""Host profiling handler."""
|
|
||||||
from app.services.host_profiler import profile_all_hosts, profile_host
|
from app.services.host_profiler import profile_all_hosts, profile_host
|
||||||
hunt_id = job.params.get("hunt_id")
|
hunt_id = job.params.get("hunt_id")
|
||||||
hostname = job.params.get("hostname")
|
hostname = job.params.get("hostname")
|
||||||
@@ -277,7 +413,6 @@ async def _handle_host_profile(job: Job):
|
|||||||
|
|
||||||
|
|
||||||
async def _handle_report(job: Job):
|
async def _handle_report(job: Job):
|
||||||
"""Report generation handler."""
|
|
||||||
from app.services.report_generator import generate_report
|
from app.services.report_generator import generate_report
|
||||||
hunt_id = job.params.get("hunt_id")
|
hunt_id = job.params.get("hunt_id")
|
||||||
job.message = f"Generating report for hunt {hunt_id}"
|
job.message = f"Generating report for hunt {hunt_id}"
|
||||||
@@ -286,7 +421,6 @@ async def _handle_report(job: Job):
|
|||||||
|
|
||||||
|
|
||||||
async def _handle_anomaly(job: Job):
|
async def _handle_anomaly(job: Job):
|
||||||
"""Anomaly detection handler."""
|
|
||||||
from app.services.anomaly_detector import detect_anomalies
|
from app.services.anomaly_detector import detect_anomalies
|
||||||
dataset_id = job.params.get("dataset_id")
|
dataset_id = job.params.get("dataset_id")
|
||||||
k = job.params.get("k", 3)
|
k = job.params.get("k", 3)
|
||||||
@@ -297,7 +431,6 @@ async def _handle_anomaly(job: Job):
|
|||||||
|
|
||||||
|
|
||||||
async def _handle_query(job: Job):
|
async def _handle_query(job: Job):
|
||||||
"""Data query handler (non-streaming)."""
|
|
||||||
from app.services.data_query import query_dataset
|
from app.services.data_query import query_dataset
|
||||||
dataset_id = job.params.get("dataset_id")
|
dataset_id = job.params.get("dataset_id")
|
||||||
question = job.params.get("question", "")
|
question = job.params.get("question", "")
|
||||||
@@ -307,10 +440,152 @@ async def _handle_query(job: Job):
|
|||||||
return {"answer": answer}
|
return {"answer": answer}
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_host_inventory(job: Job):
|
||||||
|
from app.db import async_session_factory
|
||||||
|
from app.services.host_inventory import build_host_inventory, inventory_cache
|
||||||
|
|
||||||
|
hunt_id = job.params.get("hunt_id")
|
||||||
|
if not hunt_id:
|
||||||
|
raise ValueError("hunt_id required")
|
||||||
|
|
||||||
|
inventory_cache.set_building(hunt_id)
|
||||||
|
job.message = f"Building host inventory for hunt {hunt_id}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with async_session_factory() as db:
|
||||||
|
result = await build_host_inventory(hunt_id, db)
|
||||||
|
inventory_cache.put(hunt_id, result)
|
||||||
|
job.message = f"Built inventory: {result['stats']['total_hosts']} hosts"
|
||||||
|
return {"hunt_id": hunt_id, "total_hosts": result["stats"]["total_hosts"]}
|
||||||
|
except Exception:
|
||||||
|
inventory_cache.clear_building(hunt_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_keyword_scan(job: Job):
|
||||||
|
"""AUP keyword scan handler."""
|
||||||
|
from app.db import async_session_factory
|
||||||
|
from app.services.scanner import KeywordScanner, keyword_scan_cache
|
||||||
|
|
||||||
|
dataset_id = job.params.get("dataset_id")
|
||||||
|
job.message = f"Running AUP keyword scan on dataset {dataset_id}"
|
||||||
|
|
||||||
|
async with async_session_factory() as db:
|
||||||
|
scanner = KeywordScanner(db)
|
||||||
|
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||||
|
|
||||||
|
# Cache dataset-only result for fast API reuse
|
||||||
|
if dataset_id:
|
||||||
|
keyword_scan_cache.put(dataset_id, result)
|
||||||
|
|
||||||
|
hits = result.get("total_hits", 0)
|
||||||
|
job.message = f"Keyword scan complete: {hits} hits"
|
||||||
|
logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
|
||||||
|
return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_ioc_extract(job: Job):
|
||||||
|
"""IOC extraction handler."""
|
||||||
|
from app.db import async_session_factory
|
||||||
|
from app.services.ioc_extractor import extract_iocs_from_dataset
|
||||||
|
|
||||||
|
dataset_id = job.params.get("dataset_id")
|
||||||
|
job.message = f"Extracting IOCs from dataset {dataset_id}"
|
||||||
|
|
||||||
|
async with async_session_factory() as db:
|
||||||
|
iocs = await extract_iocs_from_dataset(dataset_id, db)
|
||||||
|
|
||||||
|
total = sum(len(v) for v in iocs.values())
|
||||||
|
job.message = f"IOC extraction complete: {total} IOCs found"
|
||||||
|
logger.info(f"IOC extract for {dataset_id}: {total} IOCs")
|
||||||
|
return {"dataset_id": dataset_id, "total_iocs": total, "breakdown": {k: len(v) for k, v in iocs.items()}}
|
||||||
|
|
||||||
|
|
||||||
|
async def _on_pipeline_job_complete(job: Job):
|
||||||
|
"""Update Dataset.processing_status when all pipeline jobs finish."""
|
||||||
|
if job.job_type not in PIPELINE_JOB_TYPES:
|
||||||
|
return
|
||||||
|
|
||||||
|
dataset_id = job.params.get("dataset_id")
|
||||||
|
if not dataset_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
pipeline_jobs = job_queue.find_pipeline_jobs(dataset_id)
|
||||||
|
if not pipeline_jobs:
|
||||||
|
return
|
||||||
|
|
||||||
|
all_done = all(
|
||||||
|
j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||||
|
for j in pipeline_jobs
|
||||||
|
)
|
||||||
|
if not all_done:
|
||||||
|
return
|
||||||
|
|
||||||
|
any_failed = any(j.status == JobStatus.FAILED for j in pipeline_jobs)
|
||||||
|
new_status = "completed_with_errors" if any_failed else "completed"
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.db import async_session_factory
|
||||||
|
from app.db.models import Dataset
|
||||||
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
async with async_session_factory() as db:
|
||||||
|
await db.execute(
|
||||||
|
update(Dataset)
|
||||||
|
.where(Dataset.id == dataset_id)
|
||||||
|
.values(processing_status=new_status)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
logger.info(f"Dataset {dataset_id} processing_status -> {new_status}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to update processing_status for {dataset_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def reconcile_stale_processing_tasks() -> int:
|
||||||
|
"""Mark queued/running processing tasks from prior runs as failed."""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.db import async_session_factory
|
||||||
|
from app.db.models import ProcessingTask
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
async with async_session_factory() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
update(ProcessingTask)
|
||||||
|
.where(ProcessingTask.status.in_(["queued", "running"]))
|
||||||
|
.values(
|
||||||
|
status="failed",
|
||||||
|
error="Recovered after service restart before task completion",
|
||||||
|
message="Recovered stale task after restart",
|
||||||
|
completed_at=now,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
updated = int(result.rowcount or 0)
|
||||||
|
|
||||||
|
if updated:
|
||||||
|
logger.warning(
|
||||||
|
"Reconciled %d stale processing tasks (queued/running -> failed) during startup",
|
||||||
|
updated,
|
||||||
|
)
|
||||||
|
return updated
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to reconcile stale processing tasks: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def register_all_handlers():
|
def register_all_handlers():
|
||||||
"""Register all job handlers."""
|
"""Register all job handlers and completion callbacks."""
|
||||||
job_queue.register_handler(JobType.TRIAGE, _handle_triage)
|
job_queue.register_handler(JobType.TRIAGE, _handle_triage)
|
||||||
job_queue.register_handler(JobType.HOST_PROFILE, _handle_host_profile)
|
job_queue.register_handler(JobType.HOST_PROFILE, _handle_host_profile)
|
||||||
job_queue.register_handler(JobType.REPORT, _handle_report)
|
job_queue.register_handler(JobType.REPORT, _handle_report)
|
||||||
job_queue.register_handler(JobType.ANOMALY, _handle_anomaly)
|
job_queue.register_handler(JobType.ANOMALY, _handle_anomaly)
|
||||||
job_queue.register_handler(JobType.QUERY, _handle_query)
|
job_queue.register_handler(JobType.QUERY, _handle_query)
|
||||||
|
job_queue.register_handler(JobType.HOST_INVENTORY, _handle_host_inventory)
|
||||||
|
job_queue.register_handler(JobType.KEYWORD_SCAN, _handle_keyword_scan)
|
||||||
|
job_queue.register_handler(JobType.IOC_EXTRACT, _handle_ioc_extract)
|
||||||
|
job_queue.on_completion(_on_pipeline_job_complete)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""AUP Keyword Scanner — searches dataset rows, hunts, annotations, and
|
"""AUP Keyword Scanner searches dataset rows, hunts, annotations, and
|
||||||
messages for keyword matches.
|
messages for keyword matches.
|
||||||
|
|
||||||
Scanning is done in Python (not SQL LIKE on JSON columns) for portability
|
Scanning is done in Python (not SQL LIKE on JSON columns) for portability
|
||||||
@@ -8,24 +8,49 @@ across SQLite / PostgreSQL and to provide per-cell match context.
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from sqlalchemy import select, func
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
from app.db.models import (
|
from app.db.models import (
|
||||||
KeywordTheme,
|
KeywordTheme,
|
||||||
Keyword,
|
|
||||||
DatasetRow,
|
DatasetRow,
|
||||||
Dataset,
|
Dataset,
|
||||||
Hunt,
|
Hunt,
|
||||||
Annotation,
|
Annotation,
|
||||||
Message,
|
Message,
|
||||||
Conversation,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
BATCH_SIZE = 500
|
BATCH_SIZE = 200
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_hostname_and_user(data: dict) -> tuple[str | None, str | None]:
|
||||||
|
"""Best-effort extraction of hostname and user from a dataset row."""
|
||||||
|
if not data:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
host_keys = (
|
||||||
|
'hostname', 'host_name', 'host', 'computer_name', 'computer',
|
||||||
|
'fqdn', 'client_id', 'agent_id', 'endpoint_id',
|
||||||
|
)
|
||||||
|
user_keys = (
|
||||||
|
'username', 'user_name', 'user', 'account_name',
|
||||||
|
'logged_in_user', 'samaccountname', 'sam_account_name',
|
||||||
|
)
|
||||||
|
|
||||||
|
def pick(keys):
|
||||||
|
for k in keys:
|
||||||
|
for actual_key, v in data.items():
|
||||||
|
if actual_key.lower() == k and v not in (None, ''):
|
||||||
|
return str(v)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return pick(host_keys), pick(user_keys)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -39,6 +64,8 @@ class ScanHit:
|
|||||||
matched_value: str
|
matched_value: str
|
||||||
row_index: int | None = None
|
row_index: int | None = None
|
||||||
dataset_name: str | None = None
|
dataset_name: str | None = None
|
||||||
|
hostname: str | None = None
|
||||||
|
username: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -50,21 +77,54 @@ class ScanResult:
|
|||||||
rows_scanned: int = 0
|
rows_scanned: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KeywordScanCacheEntry:
|
||||||
|
dataset_id: str
|
||||||
|
result: dict
|
||||||
|
built_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordScanCache:
|
||||||
|
"""In-memory per-dataset cache for dataset-only keyword scans.
|
||||||
|
|
||||||
|
This enables fast-path reads when users run AUP scans against datasets that
|
||||||
|
were already scanned during upload pipeline processing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._entries: dict[str, KeywordScanCacheEntry] = {}
|
||||||
|
|
||||||
|
def put(self, dataset_id: str, result: dict):
|
||||||
|
self._entries[dataset_id] = KeywordScanCacheEntry(dataset_id=dataset_id, result=result)
|
||||||
|
|
||||||
|
def get(self, dataset_id: str) -> KeywordScanCacheEntry | None:
|
||||||
|
return self._entries.get(dataset_id)
|
||||||
|
|
||||||
|
def invalidate_dataset(self, dataset_id: str):
|
||||||
|
self._entries.pop(dataset_id, None)
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self._entries.clear()
|
||||||
|
|
||||||
|
|
||||||
|
keyword_scan_cache = KeywordScanCache()
|
||||||
|
|
||||||
|
|
||||||
class KeywordScanner:
|
class KeywordScanner:
|
||||||
"""Scans multiple data sources for keyword/regex matches."""
|
"""Scans multiple data sources for keyword/regex matches."""
|
||||||
|
|
||||||
def __init__(self, db: AsyncSession):
|
def __init__(self, db: AsyncSession):
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────
|
# Public API
|
||||||
|
|
||||||
async def scan(
|
async def scan(
|
||||||
self,
|
self,
|
||||||
dataset_ids: list[str] | None = None,
|
dataset_ids: list[str] | None = None,
|
||||||
theme_ids: list[str] | None = None,
|
theme_ids: list[str] | None = None,
|
||||||
scan_hunts: bool = True,
|
scan_hunts: bool = False,
|
||||||
scan_annotations: bool = True,
|
scan_annotations: bool = False,
|
||||||
scan_messages: bool = True,
|
scan_messages: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Run a full AUP scan and return dict matching ScanResponse."""
|
"""Run a full AUP scan and return dict matching ScanResponse."""
|
||||||
# Load themes + keywords
|
# Load themes + keywords
|
||||||
@@ -103,7 +163,7 @@ class KeywordScanner:
|
|||||||
"rows_scanned": result.rows_scanned,
|
"rows_scanned": result.rows_scanned,
|
||||||
}
|
}
|
||||||
|
|
||||||
# ── Internal ──────────────────────────────────────────────────────
|
# Internal
|
||||||
|
|
||||||
async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
|
async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
|
||||||
q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712
|
q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712
|
||||||
@@ -143,6 +203,8 @@ class KeywordScanner:
|
|||||||
hits: list[ScanHit],
|
hits: list[ScanHit],
|
||||||
row_index: int | None = None,
|
row_index: int | None = None,
|
||||||
dataset_name: str | None = None,
|
dataset_name: str | None = None,
|
||||||
|
hostname: str | None = None,
|
||||||
|
username: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Check text against all compiled patterns, append hits."""
|
"""Check text against all compiled patterns, append hits."""
|
||||||
if not text:
|
if not text:
|
||||||
@@ -150,8 +212,7 @@ class KeywordScanner:
|
|||||||
for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
|
for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
|
||||||
for kw_value, pat in keyword_patterns:
|
for kw_value, pat in keyword_patterns:
|
||||||
if pat.search(text):
|
if pat.search(text):
|
||||||
# Truncate matched_value for display
|
matched_preview = text[:200] + ("" if len(text) > 200 else "")
|
||||||
matched_preview = text[:200] + ("…" if len(text) > 200 else "")
|
|
||||||
hits.append(ScanHit(
|
hits.append(ScanHit(
|
||||||
theme_name=theme_name,
|
theme_name=theme_name,
|
||||||
theme_color=theme_color,
|
theme_color=theme_color,
|
||||||
@@ -162,13 +223,14 @@ class KeywordScanner:
|
|||||||
matched_value=matched_preview,
|
matched_value=matched_preview,
|
||||||
row_index=row_index,
|
row_index=row_index,
|
||||||
dataset_name=dataset_name,
|
dataset_name=dataset_name,
|
||||||
|
hostname=hostname,
|
||||||
|
username=username,
|
||||||
))
|
))
|
||||||
|
|
||||||
async def _scan_datasets(
|
async def _scan_datasets(
|
||||||
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
|
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Scan dataset rows in batches."""
|
"""Scan dataset rows in batches using keyset pagination (no OFFSET)."""
|
||||||
# Build dataset name lookup
|
|
||||||
ds_q = select(Dataset.id, Dataset.name)
|
ds_q = select(Dataset.id, Dataset.name)
|
||||||
if dataset_ids:
|
if dataset_ids:
|
||||||
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
|
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
|
||||||
@@ -178,15 +240,27 @@ class KeywordScanner:
|
|||||||
if not ds_map:
|
if not ds_map:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Iterate rows in batches
|
import asyncio
|
||||||
offset = 0
|
|
||||||
row_q_base = select(DatasetRow).where(
|
|
||||||
DatasetRow.dataset_id.in_(list(ds_map.keys()))
|
|
||||||
).order_by(DatasetRow.id)
|
|
||||||
|
|
||||||
|
max_rows = max(0, int(settings.SCANNER_MAX_ROWS_PER_SCAN))
|
||||||
|
budget_reached = False
|
||||||
|
|
||||||
|
for ds_id, ds_name in ds_map.items():
|
||||||
|
if max_rows and result.rows_scanned >= max_rows:
|
||||||
|
budget_reached = True
|
||||||
|
break
|
||||||
|
|
||||||
|
last_id = 0
|
||||||
while True:
|
while True:
|
||||||
|
if max_rows and result.rows_scanned >= max_rows:
|
||||||
|
budget_reached = True
|
||||||
|
break
|
||||||
rows_result = await self.db.execute(
|
rows_result = await self.db.execute(
|
||||||
row_q_base.offset(offset).limit(BATCH_SIZE)
|
select(DatasetRow)
|
||||||
|
.where(DatasetRow.dataset_id == ds_id)
|
||||||
|
.where(DatasetRow.id > last_id)
|
||||||
|
.order_by(DatasetRow.id)
|
||||||
|
.limit(BATCH_SIZE)
|
||||||
)
|
)
|
||||||
rows = rows_result.scalars().all()
|
rows = rows_result.scalars().all()
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -195,21 +269,38 @@ class KeywordScanner:
|
|||||||
for row in rows:
|
for row in rows:
|
||||||
result.rows_scanned += 1
|
result.rows_scanned += 1
|
||||||
data = row.data or {}
|
data = row.data or {}
|
||||||
|
hostname, username = _infer_hostname_and_user(data)
|
||||||
for col_name, cell_value in data.items():
|
for col_name, cell_value in data.items():
|
||||||
if cell_value is None:
|
if cell_value is None:
|
||||||
continue
|
continue
|
||||||
text = str(cell_value)
|
text = str(cell_value)
|
||||||
self._match_text(
|
self._match_text(
|
||||||
text, patterns, "dataset_row", row.id,
|
text,
|
||||||
col_name, result.hits,
|
patterns,
|
||||||
|
"dataset_row",
|
||||||
|
row.id,
|
||||||
|
col_name,
|
||||||
|
result.hits,
|
||||||
row_index=row.row_index,
|
row_index=row.row_index,
|
||||||
dataset_name=ds_map.get(row.dataset_id),
|
dataset_name=ds_name,
|
||||||
|
hostname=hostname,
|
||||||
|
username=username,
|
||||||
)
|
)
|
||||||
|
|
||||||
offset += BATCH_SIZE
|
last_id = rows[-1].id
|
||||||
|
await asyncio.sleep(0)
|
||||||
if len(rows) < BATCH_SIZE:
|
if len(rows) < BATCH_SIZE:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if budget_reached:
|
||||||
|
break
|
||||||
|
|
||||||
|
if budget_reached:
|
||||||
|
logger.warning(
|
||||||
|
"AUP scan row budget reached (%d rows). Returning partial results.",
|
||||||
|
result.rows_scanned,
|
||||||
|
)
|
||||||
|
|
||||||
async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
|
async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
|
||||||
"""Scan hunt names and descriptions."""
|
"""Scan hunt names and descriptions."""
|
||||||
hunts_result = await self.db.execute(select(Hunt))
|
hunts_result = await self.db.execute(select(Hunt))
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""
|
"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -15,7 +15,7 @@ from app.db.models import Dataset, DatasetRow, TriageResult
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M"
|
DEFAULT_FAST_MODEL = settings.DEFAULT_FAST_MODEL
|
||||||
ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate"
|
ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate"
|
||||||
|
|
||||||
ARTIFACT_FOCUS = {
|
ARTIFACT_FOCUS = {
|
||||||
@@ -80,7 +80,7 @@ async def triage_dataset(dataset_id: str) -> None:
|
|||||||
rows_result = await db.execute(
|
rows_result = await db.execute(
|
||||||
select(DatasetRow)
|
select(DatasetRow)
|
||||||
.where(DatasetRow.dataset_id == dataset_id)
|
.where(DatasetRow.dataset_id == dataset_id)
|
||||||
.order_by(DatasetRow.row_number)
|
.order_by(DatasetRow.row_index)
|
||||||
.offset(offset)
|
.offset(offset)
|
||||||
.limit(batch_size)
|
.limit(batch_size)
|
||||||
)
|
)
|
||||||
|
|||||||
124
backend/tests/test_agent_policy_execution.py
Normal file
124
backend/tests/test_agent_policy_execution.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""Tests for execution-mode behavior in /api/agent/assist."""
|
||||||
|
|
||||||
|
import io
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_assist_policy_query_executes_scan(client):
|
||||||
|
# 1) Create hunt
|
||||||
|
h = await client.post("/api/hunts", json={"name": "Policy Hunt"})
|
||||||
|
assert h.status_code == 200
|
||||||
|
hunt_id = h.json()["id"]
|
||||||
|
|
||||||
|
# 2) Upload browser-history-like CSV
|
||||||
|
csv_bytes = (
|
||||||
|
b"User,visited_url,title,ClientId,Fqdn\n"
|
||||||
|
b"Alice,https://www.pornhub.com/view_video.php,site,HOST-A,host-a.local\n"
|
||||||
|
b"Bob,https://news.example.org/article,news,HOST-B,host-b.local\n"
|
||||||
|
)
|
||||||
|
files = {"file": ("web_history.csv", io.BytesIO(csv_bytes), "text/csv")}
|
||||||
|
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
|
||||||
|
assert up.status_code == 200
|
||||||
|
|
||||||
|
# 3) Ensure policy theme/keyword exists
|
||||||
|
t = await client.post(
|
||||||
|
"/api/keywords/themes",
|
||||||
|
json={
|
||||||
|
"name": "Adult Content",
|
||||||
|
"color": "#e91e63",
|
||||||
|
"enabled": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert t.status_code in (201, 409)
|
||||||
|
|
||||||
|
themes = await client.get("/api/keywords/themes")
|
||||||
|
assert themes.status_code == 200
|
||||||
|
adult = next(x for x in themes.json()["themes"] if x["name"] == "Adult Content")
|
||||||
|
|
||||||
|
k = await client.post(
|
||||||
|
f"/api/keywords/themes/{adult['id']}/keywords",
|
||||||
|
json={"value": "pornhub", "is_regex": False},
|
||||||
|
)
|
||||||
|
assert k.status_code in (201, 409)
|
||||||
|
|
||||||
|
# 4) Execution-mode query
|
||||||
|
q = await client.post(
|
||||||
|
"/api/agent/assist",
|
||||||
|
json={
|
||||||
|
"query": "Analyze browser history for policy-violating domains and summarize by user and host.",
|
||||||
|
"hunt_id": hunt_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert q.status_code == 200
|
||||||
|
body = q.json()
|
||||||
|
|
||||||
|
assert body["model_used"] == "execution:keyword_scanner"
|
||||||
|
assert body["execution"] is not None
|
||||||
|
assert body["execution"]["policy_hits"] >= 1
|
||||||
|
assert len(body["execution"]["top_user_hosts"]) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_assist_execution_preference_off_stays_advisory(client):
|
||||||
|
h = await client.post("/api/hunts", json={"name": "No Exec Hunt"})
|
||||||
|
assert h.status_code == 200
|
||||||
|
hunt_id = h.json()["id"]
|
||||||
|
|
||||||
|
q = await client.post(
|
||||||
|
"/api/agent/assist",
|
||||||
|
json={
|
||||||
|
"query": "Analyze browser history for policy-violating domains and summarize by user and host.",
|
||||||
|
"hunt_id": hunt_id,
|
||||||
|
"execution_preference": "off",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert q.status_code == 200
|
||||||
|
body = q.json()
|
||||||
|
assert body["model_used"] != "execution:keyword_scanner"
|
||||||
|
assert body["execution"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_assist_execution_preference_force_executes(client):
|
||||||
|
# Create hunt + dataset even when the query text is not policy-specific
|
||||||
|
h = await client.post("/api/hunts", json={"name": "Force Exec Hunt"})
|
||||||
|
assert h.status_code == 200
|
||||||
|
hunt_id = h.json()["id"]
|
||||||
|
|
||||||
|
csv_bytes = (
|
||||||
|
b"User,visited_url,title,ClientId,Fqdn\n"
|
||||||
|
b"Alice,https://www.pornhub.com/view_video.php,site,HOST-A,host-a.local\n"
|
||||||
|
)
|
||||||
|
files = {"file": ("web_history.csv", io.BytesIO(csv_bytes), "text/csv")}
|
||||||
|
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
|
||||||
|
assert up.status_code == 200
|
||||||
|
|
||||||
|
t = await client.post(
|
||||||
|
"/api/keywords/themes",
|
||||||
|
json={"name": "Adult Content", "color": "#e91e63", "enabled": True},
|
||||||
|
)
|
||||||
|
assert t.status_code in (201, 409)
|
||||||
|
|
||||||
|
themes = await client.get("/api/keywords/themes")
|
||||||
|
assert themes.status_code == 200
|
||||||
|
adult = next(x for x in themes.json()["themes"] if x["name"] == "Adult Content")
|
||||||
|
k = await client.post(
|
||||||
|
f"/api/keywords/themes/{adult['id']}/keywords",
|
||||||
|
json={"value": "pornhub", "is_regex": False},
|
||||||
|
)
|
||||||
|
assert k.status_code in (201, 409)
|
||||||
|
|
||||||
|
q = await client.post(
|
||||||
|
"/api/agent/assist",
|
||||||
|
json={
|
||||||
|
"query": "Summarize notable activity in this hunt.",
|
||||||
|
"hunt_id": hunt_id,
|
||||||
|
"execution_preference": "force",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert q.status_code == 200
|
||||||
|
body = q.json()
|
||||||
|
assert body["model_used"] == "execution:keyword_scanner"
|
||||||
|
assert body["execution"] is not None
|
||||||
@@ -77,6 +77,26 @@ class TestHuntEndpoints:
|
|||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
async def test_hunt_progress(self, client):
|
||||||
|
create = await client.post("/api/hunts", json={"name": "Progress Hunt"})
|
||||||
|
hunt_id = create.json()["id"]
|
||||||
|
|
||||||
|
# attach one dataset so progress has scope
|
||||||
|
from tests.conftest import SAMPLE_CSV
|
||||||
|
import io
|
||||||
|
files = {"file": ("progress.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
|
||||||
|
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
|
||||||
|
assert up.status_code == 200
|
||||||
|
|
||||||
|
res = await client.get(f"/api/hunts/{hunt_id}/progress")
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json()
|
||||||
|
assert body["hunt_id"] == hunt_id
|
||||||
|
assert "progress_percent" in body
|
||||||
|
assert "dataset_total" in body
|
||||||
|
assert "network_status" in body
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestDatasetEndpoints:
|
class TestDatasetEndpoints:
|
||||||
"""Test dataset upload and retrieval."""
|
"""Test dataset upload and retrieval."""
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Tests for CSV parser and normalizer services."""
|
"""Tests for CSV parser and normalizer services."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from app.services.csv_parser import parse_csv_bytes, detect_encoding, detect_delimiter, infer_column_types
|
from app.services.csv_parser import parse_csv_bytes, detect_encoding, detect_delimiter, infer_column_types
|
||||||
@@ -43,8 +43,9 @@ class TestCSVParser:
|
|||||||
assert len(rows) == 2
|
assert len(rows) == 2
|
||||||
|
|
||||||
def test_parse_empty_file(self):
|
def test_parse_empty_file(self):
|
||||||
with pytest.raises(Exception):
|
rows, meta = parse_csv_bytes(b"")
|
||||||
parse_csv_bytes(b"")
|
assert len(rows) == 0
|
||||||
|
assert meta["row_count"] == 0
|
||||||
|
|
||||||
def test_detect_encoding_utf8(self):
|
def test_detect_encoding_utf8(self):
|
||||||
enc = detect_encoding(SAMPLE_CSV)
|
enc = detect_encoding(SAMPLE_CSV)
|
||||||
@@ -53,17 +54,15 @@ class TestCSVParser:
|
|||||||
|
|
||||||
def test_infer_column_types(self):
|
def test_infer_column_types(self):
|
||||||
types = infer_column_types(
|
types = infer_column_types(
|
||||||
["192.168.1.1", "10.0.0.1", "8.8.8.8"],
|
[{"src_ip": "192.168.1.1"}, {"src_ip": "10.0.0.1"}, {"src_ip": "8.8.8.8"}],
|
||||||
"src_ip",
|
|
||||||
)
|
)
|
||||||
assert types == "ip"
|
assert types["src_ip"] == "ip"
|
||||||
|
|
||||||
def test_infer_column_types_hash(self):
|
def test_infer_column_types_hash(self):
|
||||||
types = infer_column_types(
|
types = infer_column_types(
|
||||||
["d41d8cd98f00b204e9800998ecf8427e"],
|
[{"hash": "d41d8cd98f00b204e9800998ecf8427e"}],
|
||||||
"hash",
|
|
||||||
)
|
)
|
||||||
assert types == "hash_md5"
|
assert types["hash"] == "hash_md5"
|
||||||
|
|
||||||
|
|
||||||
class TestNormalizer:
|
class TestNormalizer:
|
||||||
@@ -94,7 +93,7 @@ class TestNormalizer:
|
|||||||
start, end = detect_time_range(rows, column_mapping)
|
start, end = detect_time_range(rows, column_mapping)
|
||||||
# Should detect time range from timestamp column
|
# Should detect time range from timestamp column
|
||||||
if start:
|
if start:
|
||||||
assert "2025" in start
|
assert "2025" in str(start)
|
||||||
|
|
||||||
def test_normalize_rows(self):
|
def test_normalize_rows(self):
|
||||||
rows = [{"SourceAddr": "10.0.0.1", "ProcessName": "cmd.exe"}]
|
rows = [{"SourceAddr": "10.0.0.1", "ProcessName": "cmd.exe"}]
|
||||||
@@ -102,3 +101,6 @@ class TestNormalizer:
|
|||||||
normalized = normalize_rows(rows, mapping)
|
normalized = normalize_rows(rows, mapping)
|
||||||
assert len(normalized) == 1
|
assert len(normalized) == 1
|
||||||
assert normalized[0].get("src_ip") == "10.0.0.1"
|
assert normalized[0].get("src_ip") == "10.0.0.1"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -197,3 +197,27 @@ async def test_quick_scan(client: AsyncClient):
|
|||||||
assert "total_hits" in data
|
assert "total_hits" in data
|
||||||
# powershell should match at least one row
|
# powershell should match at least one row
|
||||||
assert data["total_hits"] > 0
|
assert data["total_hits"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_quick_scan_cache_hit(client: AsyncClient):
|
||||||
|
"""Second quick scan should return cache hit metadata."""
|
||||||
|
theme_res = await client.post("/api/keywords/themes", json={"name": "Quick Cache Theme", "color": "#00aa00"})
|
||||||
|
tid = theme_res.json()["id"]
|
||||||
|
await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "chrome.exe"})
|
||||||
|
|
||||||
|
from tests.conftest import SAMPLE_CSV
|
||||||
|
import io
|
||||||
|
files = {"file": ("cache_quick.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
|
||||||
|
upload = await client.post("/api/datasets/upload", files=files)
|
||||||
|
ds_id = upload.json()["id"]
|
||||||
|
|
||||||
|
first = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
|
||||||
|
assert first.status_code == 200
|
||||||
|
assert first.json().get("cache_status") in ("miss", "hit")
|
||||||
|
|
||||||
|
second = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
|
||||||
|
assert second.status_code == 200
|
||||||
|
body = second.json()
|
||||||
|
assert body.get("cache_used") is True
|
||||||
|
assert body.get("cache_status") == "hit"
|
||||||
|
|||||||
84
backend/tests/test_network.py
Normal file
84
backend/tests/test_network.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Tests for network inventory endpoints and cache/polling behavior."""
|
||||||
|
|
||||||
|
import io
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.host_inventory import inventory_cache
|
||||||
|
from tests.conftest import SAMPLE_CSV
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inventory_status_none_for_unknown_hunt(client):
|
||||||
|
hunt_id = "hunt-does-not-exist"
|
||||||
|
inventory_cache.invalidate(hunt_id)
|
||||||
|
inventory_cache.clear_building(hunt_id)
|
||||||
|
|
||||||
|
res = await client.get(f"/api/network/inventory-status?hunt_id={hunt_id}")
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json()
|
||||||
|
assert body["hunt_id"] == hunt_id
|
||||||
|
assert body["status"] == "none"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_host_inventory_cold_cache_returns_202(client):
|
||||||
|
# Create hunt and upload dataset linked to that hunt
|
||||||
|
hunt = await client.post("/api/hunts", json={"name": "Net Hunt"})
|
||||||
|
hunt_id = hunt.json()["id"]
|
||||||
|
|
||||||
|
files = {"file": ("network.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
|
||||||
|
up = await client.post("/api/datasets/upload", files=files, params={"hunt_id": hunt_id})
|
||||||
|
assert up.status_code == 200
|
||||||
|
|
||||||
|
# Ensure cache is cold for this hunt
|
||||||
|
inventory_cache.invalidate(hunt_id)
|
||||||
|
inventory_cache.clear_building(hunt_id)
|
||||||
|
|
||||||
|
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
|
||||||
|
assert res.status_code == 202
|
||||||
|
body = res.json()
|
||||||
|
assert body["status"] == "building"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_host_inventory_ready_cache_returns_200(client):
|
||||||
|
hunt = await client.post("/api/hunts", json={"name": "Ready Hunt"})
|
||||||
|
hunt_id = hunt.json()["id"]
|
||||||
|
|
||||||
|
mock_inventory = {
|
||||||
|
"hosts": [
|
||||||
|
{
|
||||||
|
"id": "host-1",
|
||||||
|
"hostname": "HOST-1",
|
||||||
|
"fqdn": "HOST-1.local",
|
||||||
|
"client_id": "C.1234abcd",
|
||||||
|
"ips": ["10.0.0.10"],
|
||||||
|
"os": "Windows 10",
|
||||||
|
"users": ["alice"],
|
||||||
|
"datasets": ["test"],
|
||||||
|
"row_count": 5,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"connections": [],
|
||||||
|
"stats": {
|
||||||
|
"total_hosts": 1,
|
||||||
|
"hosts_with_ips": 1,
|
||||||
|
"hosts_with_users": 1,
|
||||||
|
"total_datasets_scanned": 1,
|
||||||
|
"total_rows_scanned": 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
inventory_cache.put(hunt_id, mock_inventory)
|
||||||
|
|
||||||
|
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json()
|
||||||
|
assert body["stats"]["total_hosts"] == 1
|
||||||
|
assert len(body["hosts"]) == 1
|
||||||
|
assert body["hosts"][0]["hostname"] == "HOST-1"
|
||||||
|
|
||||||
|
status_res = await client.get(f"/api/network/inventory-status?hunt_id={hunt_id}")
|
||||||
|
assert status_res.status_code == 200
|
||||||
|
assert status_res.json()["status"] == "ready"
|
||||||
82
backend/tests/test_network_scale.py
Normal file
82
backend/tests/test_network_scale.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""Scale-oriented network endpoint tests (summary/subgraph/backpressure)."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
from app.services.host_inventory import inventory_cache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_network_summary_from_cache(client):
|
||||||
|
hunt_id = "scale-hunt-summary"
|
||||||
|
inv = {
|
||||||
|
"hosts": [
|
||||||
|
{"id": "h1", "hostname": "H1", "ips": ["10.0.0.1"], "users": ["a"], "row_count": 50},
|
||||||
|
{"id": "h2", "hostname": "H2", "ips": [], "users": [], "row_count": 10},
|
||||||
|
],
|
||||||
|
"connections": [
|
||||||
|
{"source": "h1", "target": "8.8.8.8", "count": 7},
|
||||||
|
{"source": "h1", "target": "h2", "count": 3},
|
||||||
|
],
|
||||||
|
"stats": {"total_hosts": 2, "total_rows_scanned": 60},
|
||||||
|
}
|
||||||
|
inventory_cache.put(hunt_id, inv)
|
||||||
|
|
||||||
|
res = await client.get(f"/api/network/summary?hunt_id={hunt_id}&top_n=1")
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json()
|
||||||
|
assert body["stats"]["total_hosts"] == 2
|
||||||
|
assert len(body["top_hosts"]) == 1
|
||||||
|
assert body["top_hosts"][0]["id"] == "h1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_network_subgraph_truncates(client):
|
||||||
|
hunt_id = "scale-hunt-subgraph"
|
||||||
|
inv = {
|
||||||
|
"hosts": [
|
||||||
|
{"id": f"h{i}", "hostname": f"H{i}", "ips": [], "users": [], "row_count": 100 - i}
|
||||||
|
for i in range(1, 8)
|
||||||
|
],
|
||||||
|
"connections": [
|
||||||
|
{"source": "h1", "target": "h2", "count": 20},
|
||||||
|
{"source": "h1", "target": "h3", "count": 15},
|
||||||
|
{"source": "h2", "target": "h4", "count": 5},
|
||||||
|
{"source": "h3", "target": "h5", "count": 4},
|
||||||
|
],
|
||||||
|
"stats": {"total_hosts": 7, "total_rows_scanned": 999},
|
||||||
|
}
|
||||||
|
inventory_cache.put(hunt_id, inv)
|
||||||
|
|
||||||
|
res = await client.get(f"/api/network/subgraph?hunt_id={hunt_id}&max_hosts=3&max_edges=2")
|
||||||
|
assert res.status_code == 200
|
||||||
|
body = res.json()
|
||||||
|
assert len(body["hosts"]) <= 3
|
||||||
|
assert len(body["connections"]) <= 2
|
||||||
|
assert body["stats"]["truncated"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manual_job_submit_backpressure_returns_429(client):
|
||||||
|
old = settings.JOB_QUEUE_MAX_BACKLOG
|
||||||
|
settings.JOB_QUEUE_MAX_BACKLOG = 0
|
||||||
|
try:
|
||||||
|
res = await client.post("/api/analysis/jobs/submit/triage", json={"params": {"dataset_id": "abc"}})
|
||||||
|
assert res.status_code == 429
|
||||||
|
finally:
|
||||||
|
settings.JOB_QUEUE_MAX_BACKLOG = old
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_network_host_inventory_deferred_when_queue_backlogged(client):
|
||||||
|
hunt_id = "deferred-hunt"
|
||||||
|
inventory_cache.invalidate(hunt_id)
|
||||||
|
inventory_cache.clear_building(hunt_id)
|
||||||
|
|
||||||
|
old = settings.JOB_QUEUE_MAX_BACKLOG
|
||||||
|
settings.JOB_QUEUE_MAX_BACKLOG = 0
|
||||||
|
try:
|
||||||
|
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
|
||||||
|
assert res.status_code == 202
|
||||||
|
body = res.json()
|
||||||
|
assert body["status"] == "deferred"
|
||||||
|
finally:
|
||||||
|
settings.JOB_QUEUE_MAX_BACKLOG = old
|
||||||
203
backend/tests/test_new_features.py
Normal file
203
backend/tests/test_new_features.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""Tests for new feature API routes: MITRE, Timeline, Playbooks, Saved Searches."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class TestMitreRoutes:
|
||||||
|
"""Tests for /api/mitre endpoints."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mitre_coverage_empty(self, client):
|
||||||
|
resp = await client.get("/api/mitre/coverage")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "tactics" in data
|
||||||
|
assert "technique_count" in data
|
||||||
|
assert data["technique_count"] == 0
|
||||||
|
assert len(data["tactics"]) == 14 # 14 MITRE tactics
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mitre_coverage_with_hunt_filter(self, client):
|
||||||
|
resp = await client.get("/api/mitre/coverage?hunt_id=nonexistent")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["technique_count"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestTimelineRoutes:
|
||||||
|
"""Tests for /api/timeline endpoints."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeline_hunt_not_found(self, client):
|
||||||
|
resp = await client.get("/api/timeline/hunt/nonexistent")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeline_with_hunt(self, client):
|
||||||
|
# Create a hunt first
|
||||||
|
hunt_resp = await client.post("/api/hunts", json={"name": "Timeline Test"})
|
||||||
|
assert hunt_resp.status_code in (200, 201)
|
||||||
|
hunt_id = hunt_resp.json()["id"]
|
||||||
|
|
||||||
|
resp = await client.get(f"/api/timeline/hunt/{hunt_id}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["hunt_id"] == hunt_id
|
||||||
|
assert "events" in data
|
||||||
|
assert "datasets" in data
|
||||||
|
|
||||||
|
|
||||||
|
class TestPlaybookRoutes:
|
||||||
|
"""Tests for /api/playbooks endpoints."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_playbooks_empty(self, client):
|
||||||
|
resp = await client.get("/api/playbooks")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["playbooks"] == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_templates(self, client):
|
||||||
|
resp = await client.get("/api/playbooks/templates")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
templates = resp.json()["templates"]
|
||||||
|
assert len(templates) >= 2
|
||||||
|
assert templates[0]["name"] == "Standard Threat Hunt"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_playbook(self, client):
|
||||||
|
resp = await client.post("/api/playbooks", json={
|
||||||
|
"name": "My Investigation",
|
||||||
|
"description": "Test playbook",
|
||||||
|
"steps": [
|
||||||
|
{"title": "Step 1", "description": "Upload data", "step_type": "upload", "target_route": "/upload"},
|
||||||
|
{"title": "Step 2", "description": "Triage", "step_type": "analysis", "target_route": "/analysis"},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert data["name"] == "My Investigation"
|
||||||
|
assert len(data["steps"]) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_playbook_crud(self, client):
|
||||||
|
# Create
|
||||||
|
resp = await client.post("/api/playbooks", json={
|
||||||
|
"name": "CRUD Test",
|
||||||
|
"steps": [{"title": "Do something"}],
|
||||||
|
})
|
||||||
|
assert resp.status_code == 201
|
||||||
|
pb_id = resp.json()["id"]
|
||||||
|
|
||||||
|
# Get
|
||||||
|
resp = await client.get(f"/api/playbooks/{pb_id}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["name"] == "CRUD Test"
|
||||||
|
assert len(resp.json()["steps"]) == 1
|
||||||
|
|
||||||
|
# Update
|
||||||
|
resp = await client.put(f"/api/playbooks/{pb_id}", json={"name": "Updated"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
resp = await client.delete(f"/api/playbooks/{pb_id}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_playbook_step_completion(self, client):
|
||||||
|
# Create with step
|
||||||
|
resp = await client.post("/api/playbooks", json={
|
||||||
|
"name": "Step Test",
|
||||||
|
"steps": [{"title": "Task 1"}],
|
||||||
|
})
|
||||||
|
pb_id = resp.json()["id"]
|
||||||
|
|
||||||
|
# Get to find step ID
|
||||||
|
resp = await client.get(f"/api/playbooks/{pb_id}")
|
||||||
|
steps = resp.json()["steps"]
|
||||||
|
step_id = steps[0]["id"]
|
||||||
|
assert steps[0]["is_completed"] is False
|
||||||
|
|
||||||
|
# Mark complete
|
||||||
|
resp = await client.put(f"/api/playbooks/steps/{step_id}", json={"is_completed": True, "notes": "Done!"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["is_completed"] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestSavedSearchRoutes:
|
||||||
|
"""Tests for /api/searches endpoints."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_empty(self, client):
|
||||||
|
resp = await client.get("/api/searches")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["searches"] == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_saved_search(self, client):
|
||||||
|
resp = await client.post("/api/searches", json={
|
||||||
|
"name": "Suspicious IPs",
|
||||||
|
"search_type": "ioc_search",
|
||||||
|
"query_params": {"ioc_value": "203.0.113"},
|
||||||
|
})
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert data["name"] == "Suspicious IPs"
|
||||||
|
assert data["search_type"] == "ioc_search"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_crud(self, client):
|
||||||
|
# Create
|
||||||
|
resp = await client.post("/api/searches", json={
|
||||||
|
"name": "Test Query",
|
||||||
|
"search_type": "keyword_scan",
|
||||||
|
"query_params": {"theme": "malware"},
|
||||||
|
})
|
||||||
|
s_id = resp.json()["id"]
|
||||||
|
|
||||||
|
# Get
|
||||||
|
resp = await client.get(f"/api/searches/{s_id}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# Update
|
||||||
|
resp = await client.put(f"/api/searches/{s_id}", json={"name": "Updated Query"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# Run
|
||||||
|
resp = await client.post(f"/api/searches/{s_id}/run")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "result_count" in data
|
||||||
|
assert "delta" in data
|
||||||
|
|
||||||
|
# Delete
|
||||||
|
resp = await client.delete(f"/api/searches/{s_id}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TestStixExport:
|
||||||
|
"""Tests for /api/export/stix endpoints."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stix_export_hunt_not_found(self, client):
|
||||||
|
resp = await client.get("/api/export/stix/nonexistent-id")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stix_export_empty_hunt(self, client):
|
||||||
|
"""Export from a real hunt with no data returns valid but minimal bundle."""
|
||||||
|
hunt_resp = await client.post("/api/hunts", json={"name": "STIX Test Hunt"})
|
||||||
|
assert hunt_resp.status_code in (200, 201)
|
||||||
|
hunt_id = hunt_resp.json()["id"]
|
||||||
|
|
||||||
|
resp = await client.get(f"/api/export/stix/{hunt_id}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["type"] == "bundle"
|
||||||
|
assert data["objects"][0]["spec_version"] == "2.1" # spec_version is on objects, not bundle
|
||||||
|
assert "objects" in data
|
||||||
|
# At minimum should have the identity object
|
||||||
|
types = [o["type"] for o in data["objects"]]
|
||||||
|
assert "identity" in types
|
||||||
|
|
||||||
Binary file not shown.
@@ -7,24 +7,24 @@ services:
|
|||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
environment:
|
environment:
|
||||||
# ── LLM Cluster (Wile / Roadrunner via Tailscale) ──
|
# ── LLM Cluster (Wile / Roadrunner via Tailscale) ──
|
||||||
TH_WILE_HOST: "100.110.190.12"
|
TH_WILE_HOST: "100.110.190.12"
|
||||||
TH_ROADRUNNER_HOST: "100.110.190.11"
|
TH_ROADRUNNER_HOST: "100.110.190.11"
|
||||||
TH_OLLAMA_PORT: "11434"
|
TH_OLLAMA_PORT: "11434"
|
||||||
TH_OPEN_WEBUI_URL: "https://ai.guapo613.beer"
|
TH_OPEN_WEBUI_URL: "https://ai.guapo613.beer"
|
||||||
|
|
||||||
# ── Database ──
|
# ── Database ──
|
||||||
TH_DATABASE_URL: "sqlite+aiosqlite:///./threathunt.db"
|
TH_DATABASE_URL: "sqlite+aiosqlite:///./threathunt.db"
|
||||||
|
|
||||||
# ── Auth ──
|
# ── Auth ──
|
||||||
TH_JWT_SECRET: "change-me-in-production"
|
TH_JWT_SECRET: "change-me-in-production"
|
||||||
|
|
||||||
# ── Enrichment API keys (set your own) ──
|
# ── Enrichment API keys (set your own) ──
|
||||||
# TH_VIRUSTOTAL_API_KEY: ""
|
# TH_VIRUSTOTAL_API_KEY: ""
|
||||||
# TH_ABUSEIPDB_API_KEY: ""
|
# TH_ABUSEIPDB_API_KEY: ""
|
||||||
# TH_SHODAN_API_KEY: ""
|
# TH_SHODAN_API_KEY: ""
|
||||||
|
|
||||||
# ── Agent behaviour ──
|
# ── Agent behaviour ──
|
||||||
TH_AGENT_MAX_TOKENS: "4096"
|
TH_AGENT_MAX_TOKENS: "4096"
|
||||||
TH_AGENT_TEMPERATURE: "0.3"
|
TH_AGENT_TEMPERATURE: "0.3"
|
||||||
volumes:
|
volumes:
|
||||||
@@ -51,7 +51,7 @@ services:
|
|||||||
networks:
|
networks:
|
||||||
- threathunt
|
- threathunt
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "wget", "--quiet", "--tries=1", "--spider", "http://localhost:3000/"]
|
test: ["CMD", "curl", "-f", "http://127.0.0.1:3000/"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
|
|||||||
350
fix_all.py
Normal file
350
fix_all.py
Normal file
@@ -0,0 +1,350 @@
|
|||||||
|
"""Fix all critical issues: DB locking, keyword scan, network map."""
|
||||||
|
import os, re
|
||||||
|
|
||||||
|
ROOT = r"D:\Projects\Dev\ThreatHunt"
|
||||||
|
|
||||||
|
def fix_file(filepath, replacements):
|
||||||
|
"""Apply text replacements to a file."""
|
||||||
|
path = os.path.join(ROOT, filepath)
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
for old, new, desc in replacements:
|
||||||
|
if old in content:
|
||||||
|
content = content.replace(old, new, 1)
|
||||||
|
print(f" OK: {desc}")
|
||||||
|
else:
|
||||||
|
print(f" SKIP: {desc} (pattern not found)")
|
||||||
|
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(content)
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
# ================================================================
|
||||||
|
# FIX 1: Database engine - NullPool instead of StaticPool
|
||||||
|
# ================================================================
|
||||||
|
print("\n=== FIX 1: Database engine (NullPool + higher timeouts) ===")
|
||||||
|
|
||||||
|
engine_path = os.path.join(ROOT, "backend", "app", "db", "engine.py")
|
||||||
|
with open(engine_path, "r", encoding="utf-8") as f:
|
||||||
|
engine_content = f.read()
|
||||||
|
|
||||||
|
new_engine = '''"""Database engine, session factory, and base model.
|
||||||
|
|
||||||
|
Uses async SQLAlchemy with aiosqlite for local dev and asyncpg for production PostgreSQL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import event
|
||||||
|
from sqlalchemy.ext.asyncio import (
|
||||||
|
AsyncSession,
|
||||||
|
async_sessionmaker,
|
||||||
|
create_async_engine,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
_is_sqlite = settings.DATABASE_URL.startswith("sqlite")
|
||||||
|
|
||||||
|
_engine_kwargs: dict = dict(
|
||||||
|
echo=settings.DEBUG,
|
||||||
|
future=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if _is_sqlite:
|
||||||
|
_engine_kwargs["connect_args"] = {"timeout": 60, "check_same_thread": False}
|
||||||
|
# NullPool: each session gets its own connection.
|
||||||
|
# Combined with WAL mode, this allows concurrent reads while a write is in progress.
|
||||||
|
from sqlalchemy.pool import NullPool
|
||||||
|
_engine_kwargs["poolclass"] = NullPool
|
||||||
|
else:
|
||||||
|
_engine_kwargs["pool_size"] = 5
|
||||||
|
_engine_kwargs["max_overflow"] = 10
|
||||||
|
|
||||||
|
engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@event.listens_for(engine.sync_engine, "connect")
|
||||||
|
def _set_sqlite_pragmas(dbapi_conn, connection_record):
|
||||||
|
"""Enable WAL mode and tune busy-timeout for SQLite connections."""
|
||||||
|
if _is_sqlite:
|
||||||
|
cursor = dbapi_conn.cursor()
|
||||||
|
cursor.execute("PRAGMA journal_mode=WAL")
|
||||||
|
cursor.execute("PRAGMA busy_timeout=30000")
|
||||||
|
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
|
||||||
|
async_session_factory = async_sessionmaker(
|
||||||
|
engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Alias expected by other modules
|
||||||
|
async_session = async_session_factory
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
"""Base class for all ORM models."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db() -> AsyncSession: # type: ignore[misc]
|
||||||
|
"""FastAPI dependency that yields an async DB session."""
|
||||||
|
async with async_session_factory() as session:
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def init_db() -> None:
|
||||||
|
"""Create all tables (for dev / first-run). In production use Alembic."""
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
|
||||||
|
async def dispose_db() -> None:
|
||||||
|
"""Dispose of the engine on shutdown."""
|
||||||
|
await engine.dispose()
|
||||||
|
'''
|
||||||
|
|
||||||
|
with open(engine_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(new_engine)
|
||||||
|
print(" OK: Replaced StaticPool with NullPool")
|
||||||
|
print(" OK: Increased busy_timeout 5000 -> 30000ms")
|
||||||
|
print(" OK: Added check_same_thread=False")
|
||||||
|
print(" OK: Connection timeout 30 -> 60s")
|
||||||
|
|
||||||
|
|
||||||
|
# ================================================================
|
||||||
|
# FIX 2: Keyword scan endpoint - make POST non-blocking (background job)
|
||||||
|
# ================================================================
|
||||||
|
print("\n=== FIX 2: Keyword scan endpoint -> background job ===")
|
||||||
|
|
||||||
|
kw_path = os.path.join(ROOT, "backend", "app", "api", "routes", "keywords.py")
|
||||||
|
with open(kw_path, "r", encoding="utf-8") as f:
|
||||||
|
kw_content = f.read()
|
||||||
|
|
||||||
|
# Replace the scan endpoint to be non-blocking
|
||||||
|
old_scan = '''@router.post("/scan", response_model=ScanResponse)
|
||||||
|
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
||||||
|
"""Run AUP keyword scan across selected data sources."""
|
||||||
|
scanner = KeywordScanner(db)
|
||||||
|
result = await scanner.scan(
|
||||||
|
dataset_ids=body.dataset_ids,
|
||||||
|
theme_ids=body.theme_ids,
|
||||||
|
scan_hunts=body.scan_hunts,
|
||||||
|
scan_annotations=body.scan_annotations,
|
||||||
|
scan_messages=body.scan_messages,
|
||||||
|
)
|
||||||
|
return result'''
|
||||||
|
|
||||||
|
new_scan = '''@router.post("/scan", response_model=ScanResponse)
|
||||||
|
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
||||||
|
"""Run AUP keyword scan across selected data sources.
|
||||||
|
|
||||||
|
Uses a dedicated DB session separate from the request session
|
||||||
|
to avoid blocking other API requests on SQLite.
|
||||||
|
"""
|
||||||
|
from app.db import async_session_factory
|
||||||
|
async with async_session_factory() as scan_db:
|
||||||
|
scanner = KeywordScanner(scan_db)
|
||||||
|
result = await scanner.scan(
|
||||||
|
dataset_ids=body.dataset_ids,
|
||||||
|
theme_ids=body.theme_ids,
|
||||||
|
scan_hunts=body.scan_hunts,
|
||||||
|
scan_annotations=body.scan_annotations,
|
||||||
|
scan_messages=body.scan_messages,
|
||||||
|
)
|
||||||
|
return result'''
|
||||||
|
|
||||||
|
if old_scan in kw_content:
|
||||||
|
kw_content = kw_content.replace(old_scan, new_scan, 1)
|
||||||
|
print(" OK: Scan endpoint uses dedicated DB session")
|
||||||
|
else:
|
||||||
|
print(" SKIP: Scan endpoint pattern not found")
|
||||||
|
|
||||||
|
# Also fix quick_scan
|
||||||
|
old_quick = '''@router.get("/scan/quick", response_model=ScanResponse)
|
||||||
|
async def quick_scan(
|
||||||
|
dataset_id: str = Query(..., description="Dataset to scan"),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Quick scan a single dataset with all enabled themes."""
|
||||||
|
scanner = KeywordScanner(db)
|
||||||
|
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||||
|
return result'''
|
||||||
|
|
||||||
|
new_quick = '''@router.get("/scan/quick", response_model=ScanResponse)
|
||||||
|
async def quick_scan(
|
||||||
|
dataset_id: str = Query(..., description="Dataset to scan"),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""Quick scan a single dataset with all enabled themes."""
|
||||||
|
from app.db import async_session_factory
|
||||||
|
async with async_session_factory() as scan_db:
|
||||||
|
scanner = KeywordScanner(scan_db)
|
||||||
|
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||||
|
return result'''
|
||||||
|
|
||||||
|
if old_quick in kw_content:
|
||||||
|
kw_content = kw_content.replace(old_quick, new_quick, 1)
|
||||||
|
print(" OK: Quick scan uses dedicated DB session")
|
||||||
|
else:
|
||||||
|
print(" SKIP: Quick scan pattern not found")
|
||||||
|
|
||||||
|
with open(kw_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(kw_content)
|
||||||
|
|
||||||
|
|
||||||
|
# ================================================================
|
||||||
|
# FIX 3: Scanner service - smaller batches, yield between batches
|
||||||
|
# ================================================================
|
||||||
|
print("\n=== FIX 3: Scanner service - smaller batches + async yield ===")
|
||||||
|
|
||||||
|
scanner_path = os.path.join(ROOT, "backend", "app", "services", "scanner.py")
|
||||||
|
with open(scanner_path, "r", encoding="utf-8") as f:
|
||||||
|
scanner_content = f.read()
|
||||||
|
|
||||||
|
# Change batch size and add yield between batches
|
||||||
|
old_batch = "BATCH_SIZE = 500"
|
||||||
|
new_batch = "BATCH_SIZE = 200"
|
||||||
|
|
||||||
|
if old_batch in scanner_content:
|
||||||
|
scanner_content = scanner_content.replace(old_batch, new_batch, 1)
|
||||||
|
print(" OK: Reduced batch size 500 -> 200")
|
||||||
|
|
||||||
|
# Add asyncio.sleep(0) between batches to yield to other tasks
|
||||||
|
old_batch_loop = ''' offset += BATCH_SIZE
|
||||||
|
if len(rows) < BATCH_SIZE:
|
||||||
|
break'''
|
||||||
|
|
||||||
|
new_batch_loop = ''' offset += BATCH_SIZE
|
||||||
|
# Yield to event loop between batches so other requests aren't starved
|
||||||
|
import asyncio
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
if len(rows) < BATCH_SIZE:
|
||||||
|
break'''
|
||||||
|
|
||||||
|
if old_batch_loop in scanner_content:
|
||||||
|
scanner_content = scanner_content.replace(old_batch_loop, new_batch_loop, 1)
|
||||||
|
print(" OK: Added async yield between scan batches")
|
||||||
|
else:
|
||||||
|
print(" SKIP: Batch loop pattern not found")
|
||||||
|
|
||||||
|
with open(scanner_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(scanner_content)
|
||||||
|
|
||||||
|
|
||||||
|
# ================================================================
|
||||||
|
# FIX 4: Job queue workers - increase from 3 to 5
|
||||||
|
# ================================================================
|
||||||
|
print("\n=== FIX 4: Job queue - more workers ===")
|
||||||
|
|
||||||
|
jq_path = os.path.join(ROOT, "backend", "app", "services", "job_queue.py")
|
||||||
|
with open(jq_path, "r", encoding="utf-8") as f:
|
||||||
|
jq_content = f.read()
|
||||||
|
|
||||||
|
old_workers = "job_queue = JobQueue(max_workers=3)"
|
||||||
|
new_workers = "job_queue = JobQueue(max_workers=5)"
|
||||||
|
|
||||||
|
if old_workers in jq_content:
|
||||||
|
jq_content = jq_content.replace(old_workers, new_workers, 1)
|
||||||
|
print(" OK: Workers 3 -> 5")
|
||||||
|
|
||||||
|
with open(jq_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(jq_content)
|
||||||
|
|
||||||
|
|
||||||
|
# ================================================================
|
||||||
|
# FIX 5: main.py - always re-run pipeline on startup for ALL datasets
|
||||||
|
# ================================================================
|
||||||
|
print("\n=== FIX 5: Startup reprocessing - all datasets, not just 'ready' ===")
|
||||||
|
|
||||||
|
main_path = os.path.join(ROOT, "backend", "app", "main.py")
|
||||||
|
with open(main_path, "r", encoding="utf-8") as f:
|
||||||
|
main_content = f.read()
|
||||||
|
|
||||||
|
# The current startup only reprocesses datasets with status="ready"
|
||||||
|
# But after previous runs, they're all "completed" - so nothing happens
|
||||||
|
# Fix: reprocess datasets that have NO triage/anomaly results in DB
|
||||||
|
old_reprocess = ''' # Reprocess datasets that were never fully processed (status still "ready")
|
||||||
|
async with async_session_factory() as reprocess_db:
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.db.models import Dataset
|
||||||
|
stmt = select(Dataset.id).where(Dataset.processing_status == "ready")
|
||||||
|
result = await reprocess_db.execute(stmt)
|
||||||
|
unprocessed_ids = [row[0] for row in result.all()]
|
||||||
|
for ds_id in unprocessed_ids:
|
||||||
|
job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)
|
||||||
|
job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)
|
||||||
|
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)
|
||||||
|
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)
|
||||||
|
if unprocessed_ids:
|
||||||
|
logger.info(f"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets")
|
||||||
|
# Mark them as processing
|
||||||
|
async with async_session_factory() as update_db:
|
||||||
|
from sqlalchemy import update
|
||||||
|
from app.db.models import Dataset
|
||||||
|
await update_db.execute(
|
||||||
|
update(Dataset)
|
||||||
|
.where(Dataset.id.in_(unprocessed_ids))
|
||||||
|
.values(processing_status="processing")
|
||||||
|
)
|
||||||
|
await update_db.commit()'''
|
||||||
|
|
||||||
|
new_reprocess = ''' # Check which datasets still need processing
|
||||||
|
# (no anomaly results = never fully processed)
|
||||||
|
async with async_session_factory() as reprocess_db:
|
||||||
|
from sqlalchemy import select, exists
|
||||||
|
from app.db.models import Dataset, AnomalyResult
|
||||||
|
# Find datasets that have zero anomaly results (pipeline never ran or failed)
|
||||||
|
has_anomaly = (
|
||||||
|
select(AnomalyResult.id)
|
||||||
|
.where(AnomalyResult.dataset_id == Dataset.id)
|
||||||
|
.limit(1)
|
||||||
|
.correlate(Dataset)
|
||||||
|
.exists()
|
||||||
|
)
|
||||||
|
stmt = select(Dataset.id).where(~has_anomaly)
|
||||||
|
result = await reprocess_db.execute(stmt)
|
||||||
|
unprocessed_ids = [row[0] for row in result.all()]
|
||||||
|
|
||||||
|
if unprocessed_ids:
|
||||||
|
for ds_id in unprocessed_ids:
|
||||||
|
job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)
|
||||||
|
job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)
|
||||||
|
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)
|
||||||
|
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)
|
||||||
|
logger.info(f"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets")
|
||||||
|
async with async_session_factory() as update_db:
|
||||||
|
from sqlalchemy import update
|
||||||
|
from app.db.models import Dataset
|
||||||
|
await update_db.execute(
|
||||||
|
update(Dataset)
|
||||||
|
.where(Dataset.id.in_(unprocessed_ids))
|
||||||
|
.values(processing_status="processing")
|
||||||
|
)
|
||||||
|
await update_db.commit()
|
||||||
|
else:
|
||||||
|
logger.info("All datasets already processed - skipping startup pipeline")'''
|
||||||
|
|
||||||
|
if old_reprocess in main_content:
|
||||||
|
main_content = main_content.replace(old_reprocess, new_reprocess, 1)
|
||||||
|
print(" OK: Startup checks for actual results, not just status field")
|
||||||
|
else:
|
||||||
|
print(" SKIP: Reprocess block not found")
|
||||||
|
|
||||||
|
with open(main_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(main_content)
|
||||||
|
|
||||||
|
|
||||||
|
print("\n=== ALL FIXES APPLIED ===")
|
||||||
30
fix_keywords.py
Normal file
30
fix_keywords.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
path = "backend/app/api/routes/keywords.py"
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
c = f.read()
|
||||||
|
|
||||||
|
# Fix POST /scan - remove dedicated session, use injected db
|
||||||
|
old1 = ' """Run AUP keyword scan across selected data sources.\n \n Uses a dedicated DB session separate from the request session\n to avoid blocking other API requests on SQLite.\n """\n from app.db import async_session_factory\n async with async_session_factory() as scan_db:\n scanner = KeywordScanner(scan_db)\n result = await scanner.scan(\n dataset_ids=body.dataset_ids,\n theme_ids=body.theme_ids,\n scan_hunts=body.scan_hunts,\n scan_annotations=body.scan_annotations,\n scan_messages=body.scan_messages,\n )\n return result'
|
||||||
|
|
||||||
|
new1 = ' """Run AUP keyword scan across selected data sources."""\n scanner = KeywordScanner(db)\n result = await scanner.scan(\n dataset_ids=body.dataset_ids,\n theme_ids=body.theme_ids,\n scan_hunts=body.scan_hunts,\n scan_annotations=body.scan_annotations,\n scan_messages=body.scan_messages,\n )\n return result'
|
||||||
|
|
||||||
|
if old1 in c:
|
||||||
|
c = c.replace(old1, new1, 1)
|
||||||
|
print("OK: reverted POST /scan")
|
||||||
|
else:
|
||||||
|
print("SKIP: POST /scan not found")
|
||||||
|
|
||||||
|
# Fix GET /scan/quick
|
||||||
|
old2 = ' """Quick scan a single dataset with all enabled themes."""\n from app.db import async_session_factory\n async with async_session_factory() as scan_db:\n scanner = KeywordScanner(scan_db)\n result = await scanner.scan(dataset_ids=[dataset_id])\n return result'
|
||||||
|
|
||||||
|
new2 = ' """Quick scan a single dataset with all enabled themes."""\n scanner = KeywordScanner(db)\n result = await scanner.scan(dataset_ids=[dataset_id])\n return result'
|
||||||
|
|
||||||
|
if old2 in c:
|
||||||
|
c = c.replace(old2, new2, 1)
|
||||||
|
print("OK: reverted GET /scan/quick")
|
||||||
|
else:
|
||||||
|
print("SKIP: GET /scan/quick not found")
|
||||||
|
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(c)
|
||||||
@@ -16,6 +16,12 @@ server {
|
|||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
proxy_set_header X-Forwarded-Proto $scheme;
|
||||||
proxy_read_timeout 300s;
|
proxy_read_timeout 300s;
|
||||||
|
|
||||||
|
# SSE streaming support for agent assist
|
||||||
|
proxy_buffering off;
|
||||||
|
proxy_cache off;
|
||||||
|
proxy_set_header Connection '';
|
||||||
|
chunked_transfer_encoding off;
|
||||||
}
|
}
|
||||||
|
|
||||||
# SPA fallback serve index.html for all non-file routes
|
# SPA fallback serve index.html for all non-file routes
|
||||||
|
|||||||
378
frontend/package-lock.json
generated
378
frontend/package-lock.json
generated
@@ -18,7 +18,8 @@
|
|||||||
"react": "^18.2.0",
|
"react": "^18.2.0",
|
||||||
"react-dom": "^18.2.0",
|
"react-dom": "^18.2.0",
|
||||||
"react-router-dom": "^7.13.0",
|
"react-router-dom": "^7.13.0",
|
||||||
"react-scripts": "5.0.1"
|
"react-scripts": "5.0.1",
|
||||||
|
"recharts": "^3.7.0"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@types/react": "^18.2.0",
|
"@types/react": "^18.2.0",
|
||||||
@@ -3476,6 +3477,42 @@
|
|||||||
"url": "https://opencollective.com/popperjs"
|
"url": "https://opencollective.com/popperjs"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@reduxjs/toolkit": {
|
||||||
|
"version": "2.11.2",
|
||||||
|
"resolved": "https://registry.npmjs.org/@reduxjs/toolkit/-/toolkit-2.11.2.tgz",
|
||||||
|
"integrity": "sha512-Kd6kAHTA6/nUpp8mySPqj3en3dm0tdMIgbttnQ1xFMVpufoj+ADi8pXLBsd4xzTRHQa7t/Jv8W5UnCuW4kuWMQ==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"@standard-schema/spec": "^1.0.0",
|
||||||
|
"@standard-schema/utils": "^0.3.0",
|
||||||
|
"immer": "^11.0.0",
|
||||||
|
"redux": "^5.0.1",
|
||||||
|
"redux-thunk": "^3.1.0",
|
||||||
|
"reselect": "^5.1.0"
|
||||||
|
},
|
||||||
|
"peerDependencies": {
|
||||||
|
"react": "^16.9.0 || ^17.0.0 || ^18 || ^19",
|
||||||
|
"react-redux": "^7.2.1 || ^8.1.3 || ^9.0.0"
|
||||||
|
},
|
||||||
|
"peerDependenciesMeta": {
|
||||||
|
"react": {
|
||||||
|
"optional": true
|
||||||
|
},
|
||||||
|
"react-redux": {
|
||||||
|
"optional": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@reduxjs/toolkit/node_modules/immer": {
|
||||||
|
"version": "11.1.4",
|
||||||
|
"resolved": "https://registry.npmjs.org/immer/-/immer-11.1.4.tgz",
|
||||||
|
"integrity": "sha512-XREFCPo6ksxVzP4E0ekD5aMdf8WMwmdNaz6vuvxgI40UaEiu6q3p8X52aU6GdyvLY3XXX/8R7JOTXStz/nBbRw==",
|
||||||
|
"license": "MIT",
|
||||||
|
"funding": {
|
||||||
|
"type": "opencollective",
|
||||||
|
"url": "https://opencollective.com/immer"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@rollup/plugin-babel": {
|
"node_modules/@rollup/plugin-babel": {
|
||||||
"version": "5.3.1",
|
"version": "5.3.1",
|
||||||
"resolved": "https://registry.npmjs.org/@rollup/plugin-babel/-/plugin-babel-5.3.1.tgz",
|
"resolved": "https://registry.npmjs.org/@rollup/plugin-babel/-/plugin-babel-5.3.1.tgz",
|
||||||
@@ -3591,6 +3628,18 @@
|
|||||||
"@sinonjs/commons": "^1.7.0"
|
"@sinonjs/commons": "^1.7.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@standard-schema/spec": {
|
||||||
|
"version": "1.1.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz",
|
||||||
|
"integrity": "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
|
"node_modules/@standard-schema/utils": {
|
||||||
|
"version": "0.3.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/@standard-schema/utils/-/utils-0.3.0.tgz",
|
||||||
|
"integrity": "sha512-e7Mew686owMaPJVNNLs55PUvgz371nKgwsc4vxE49zsODpJEnxgxRo2y/OKrqueavXgZNMDVj3DdHFlaSAeU8g==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
"node_modules/@surma/rollup-plugin-off-main-thread": {
|
"node_modules/@surma/rollup-plugin-off-main-thread": {
|
||||||
"version": "2.2.3",
|
"version": "2.2.3",
|
||||||
"resolved": "https://registry.npmjs.org/@surma/rollup-plugin-off-main-thread/-/rollup-plugin-off-main-thread-2.2.3.tgz",
|
"resolved": "https://registry.npmjs.org/@surma/rollup-plugin-off-main-thread/-/rollup-plugin-off-main-thread-2.2.3.tgz",
|
||||||
@@ -3921,6 +3970,69 @@
|
|||||||
"@types/node": "*"
|
"@types/node": "*"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@types/d3-array": {
|
||||||
|
"version": "3.2.2",
|
||||||
|
"resolved": "https://registry.npmjs.org/@types/d3-array/-/d3-array-3.2.2.tgz",
|
||||||
|
"integrity": "sha512-hOLWVbm7uRza0BYXpIIW5pxfrKe0W+D5lrFiAEYR+pb6w3N2SwSMaJbXdUfSEv+dT4MfHBLtn5js0LAWaO6otw==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
|
"node_modules/@types/d3-color": {
|
||||||
|
"version": "3.1.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/@types/d3-color/-/d3-color-3.1.3.tgz",
|
||||||
|
"integrity": "sha512-iO90scth9WAbmgv7ogoq57O9YpKmFBbmoEoCHDB2xMBY0+/KVrqAaCDyCE16dUspeOvIxFFRI+0sEtqDqy2b4A==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
|
"node_modules/@types/d3-ease": {
|
||||||
|
"version": "3.0.2",
|
||||||
|
"resolved": "https://registry.npmjs.org/@types/d3-ease/-/d3-ease-3.0.2.tgz",
|
||||||
|
"integrity": "sha512-NcV1JjO5oDzoK26oMzbILE6HW7uVXOHLQvHshBUW4UMdZGfiY6v5BeQwh9a9tCzv+CeefZQHJt5SRgK154RtiA==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
|
"node_modules/@types/d3-interpolate": {
|
||||||
|
"version": "3.0.4",
|
||||||
|
"resolved": "https://registry.npmjs.org/@types/d3-interpolate/-/d3-interpolate-3.0.4.tgz",
|
||||||
|
"integrity": "sha512-mgLPETlrpVV1YRJIglr4Ez47g7Yxjl1lj7YKsiMCb27VJH9W8NVM6Bb9d8kkpG/uAQS5AmbA48q2IAolKKo1MA==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"@types/d3-color": "*"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@types/d3-path": {
|
||||||
|
"version": "3.1.1",
|
||||||
|
"resolved": "https://registry.npmjs.org/@types/d3-path/-/d3-path-3.1.1.tgz",
|
||||||
|
"integrity": "sha512-VMZBYyQvbGmWyWVea0EHs/BwLgxc+MKi1zLDCONksozI4YJMcTt8ZEuIR4Sb1MMTE8MMW49v0IwI5+b7RmfWlg==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
|
"node_modules/@types/d3-scale": {
|
||||||
|
"version": "4.0.9",
|
||||||
|
"resolved": "https://registry.npmjs.org/@types/d3-scale/-/d3-scale-4.0.9.tgz",
|
||||||
|
"integrity": "sha512-dLmtwB8zkAeO/juAMfnV+sItKjlsw2lKdZVVy6LRr0cBmegxSABiLEpGVmSJJ8O08i4+sGR6qQtb6WtuwJdvVw==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"@types/d3-time": "*"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@types/d3-shape": {
|
||||||
|
"version": "3.1.8",
|
||||||
|
"resolved": "https://registry.npmjs.org/@types/d3-shape/-/d3-shape-3.1.8.tgz",
|
||||||
|
"integrity": "sha512-lae0iWfcDeR7qt7rA88BNiqdvPS5pFVPpo5OfjElwNaT2yyekbM0C9vK+yqBqEmHr6lDkRnYNoTBYlAgJa7a4w==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"@types/d3-path": "*"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@types/d3-time": {
|
||||||
|
"version": "3.0.4",
|
||||||
|
"resolved": "https://registry.npmjs.org/@types/d3-time/-/d3-time-3.0.4.tgz",
|
||||||
|
"integrity": "sha512-yuzZug1nkAAaBlBBikKZTgzCeA+k1uy4ZFwWANOfKw5z5LRhV0gNA7gNkKm7HoK+HRN0wX3EkxGk0fpbWhmB7g==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
|
"node_modules/@types/d3-timer": {
|
||||||
|
"version": "3.0.2",
|
||||||
|
"resolved": "https://registry.npmjs.org/@types/d3-timer/-/d3-timer-3.0.2.tgz",
|
||||||
|
"integrity": "sha512-Ps3T8E8dZDam6fUyNiMkekK3XUsaUEik+idO9/YjPtfj2qruF8tFBXS7XhtE4iIXBLxhmLjP3SXpLhVf21I9Lw==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
"node_modules/@types/eslint": {
|
"node_modules/@types/eslint": {
|
||||||
"version": "8.56.12",
|
"version": "8.56.12",
|
||||||
"resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.56.12.tgz",
|
"resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.56.12.tgz",
|
||||||
@@ -4246,6 +4358,12 @@
|
|||||||
"integrity": "sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==",
|
"integrity": "sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==",
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
|
"node_modules/@types/use-sync-external-store": {
|
||||||
|
"version": "0.0.6",
|
||||||
|
"resolved": "https://registry.npmjs.org/@types/use-sync-external-store/-/use-sync-external-store-0.0.6.tgz",
|
||||||
|
"integrity": "sha512-zFDAD+tlpf2r4asuHEj0XH6pY6i0g5NeAHPn+15wk3BV6JA69eERFXC1gyGThDkVa1zCyKr5jox1+2LbV/AMLg==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
"node_modules/@types/ws": {
|
"node_modules/@types/ws": {
|
||||||
"version": "8.18.1",
|
"version": "8.18.1",
|
||||||
"resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz",
|
"resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz",
|
||||||
@@ -6757,6 +6875,127 @@
|
|||||||
"integrity": "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==",
|
"integrity": "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==",
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
|
"node_modules/d3-array": {
|
||||||
|
"version": "3.2.4",
|
||||||
|
"resolved": "https://registry.npmjs.org/d3-array/-/d3-array-3.2.4.tgz",
|
||||||
|
"integrity": "sha512-tdQAmyA18i4J7wprpYq8ClcxZy3SC31QMeByyCFyRt7BVHdREQZ5lpzoe5mFEYZUWe+oq8HBvk9JjpibyEV4Jg==",
|
||||||
|
"license": "ISC",
|
||||||
|
"dependencies": {
|
||||||
|
"internmap": "1 - 2"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/d3-color": {
|
||||||
|
"version": "3.1.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/d3-color/-/d3-color-3.1.0.tgz",
|
||||||
|
"integrity": "sha512-zg/chbXyeBtMQ1LbD/WSoW2DpC3I0mpmPdW+ynRTj/x2DAWYrIY7qeZIHidozwV24m4iavr15lNwIwLxRmOxhA==",
|
||||||
|
"license": "ISC",
|
||||||
|
"engines": {
|
||||||
|
"node": ">=12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/d3-ease": {
|
||||||
|
"version": "3.0.1",
|
||||||
|
"resolved": "https://registry.npmjs.org/d3-ease/-/d3-ease-3.0.1.tgz",
|
||||||
|
"integrity": "sha512-wR/XK3D3XcLIZwpbvQwQ5fK+8Ykds1ip7A2Txe0yxncXSdq1L9skcG7blcedkOX+ZcgxGAmLX1FrRGbADwzi0w==",
|
||||||
|
"license": "BSD-3-Clause",
|
||||||
|
"engines": {
|
||||||
|
"node": ">=12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/d3-format": {
|
||||||
|
"version": "3.1.2",
|
||||||
|
"resolved": "https://registry.npmjs.org/d3-format/-/d3-format-3.1.2.tgz",
|
||||||
|
"integrity": "sha512-AJDdYOdnyRDV5b6ArilzCPPwc1ejkHcoyFarqlPqT7zRYjhavcT3uSrqcMvsgh2CgoPbK3RCwyHaVyxYcP2Arg==",
|
||||||
|
"license": "ISC",
|
||||||
|
"engines": {
|
||||||
|
"node": ">=12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/d3-interpolate": {
|
||||||
|
"version": "3.0.1",
|
||||||
|
"resolved": "https://registry.npmjs.org/d3-interpolate/-/d3-interpolate-3.0.1.tgz",
|
||||||
|
"integrity": "sha512-3bYs1rOD33uo8aqJfKP3JWPAibgw8Zm2+L9vBKEHJ2Rg+viTR7o5Mmv5mZcieN+FRYaAOWX5SJATX6k1PWz72g==",
|
||||||
|
"license": "ISC",
|
||||||
|
"dependencies": {
|
||||||
|
"d3-color": "1 - 3"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/d3-path": {
|
||||||
|
"version": "3.1.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/d3-path/-/d3-path-3.1.0.tgz",
|
||||||
|
"integrity": "sha512-p3KP5HCf/bvjBSSKuXid6Zqijx7wIfNW+J/maPs+iwR35at5JCbLUT0LzF1cnjbCHWhqzQTIN2Jpe8pRebIEFQ==",
|
||||||
|
"license": "ISC",
|
||||||
|
"engines": {
|
||||||
|
"node": ">=12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/d3-scale": {
|
||||||
|
"version": "4.0.2",
|
||||||
|
"resolved": "https://registry.npmjs.org/d3-scale/-/d3-scale-4.0.2.tgz",
|
||||||
|
"integrity": "sha512-GZW464g1SH7ag3Y7hXjf8RoUuAFIqklOAq3MRl4OaWabTFJY9PN/E1YklhXLh+OQ3fM9yS2nOkCoS+WLZ6kvxQ==",
|
||||||
|
"license": "ISC",
|
||||||
|
"dependencies": {
|
||||||
|
"d3-array": "2.10.0 - 3",
|
||||||
|
"d3-format": "1 - 3",
|
||||||
|
"d3-interpolate": "1.2.0 - 3",
|
||||||
|
"d3-time": "2.1.1 - 3",
|
||||||
|
"d3-time-format": "2 - 4"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/d3-shape": {
|
||||||
|
"version": "3.2.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/d3-shape/-/d3-shape-3.2.0.tgz",
|
||||||
|
"integrity": "sha512-SaLBuwGm3MOViRq2ABk3eLoxwZELpH6zhl3FbAoJ7Vm1gofKx6El1Ib5z23NUEhF9AsGl7y+dzLe5Cw2AArGTA==",
|
||||||
|
"license": "ISC",
|
||||||
|
"dependencies": {
|
||||||
|
"d3-path": "^3.1.0"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/d3-time": {
|
||||||
|
"version": "3.1.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/d3-time/-/d3-time-3.1.0.tgz",
|
||||||
|
"integrity": "sha512-VqKjzBLejbSMT4IgbmVgDjpkYrNWUYJnbCGo874u7MMKIWsILRX+OpX/gTk8MqjpT1A/c6HY2dCA77ZN0lkQ2Q==",
|
||||||
|
"license": "ISC",
|
||||||
|
"dependencies": {
|
||||||
|
"d3-array": "2 - 3"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/d3-time-format": {
|
||||||
|
"version": "4.1.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/d3-time-format/-/d3-time-format-4.1.0.tgz",
|
||||||
|
"integrity": "sha512-dJxPBlzC7NugB2PDLwo9Q8JiTR3M3e4/XANkreKSUxF8vvXKqm1Yfq4Q5dl8budlunRVlUUaDUgFt7eA8D6NLg==",
|
||||||
|
"license": "ISC",
|
||||||
|
"dependencies": {
|
||||||
|
"d3-time": "1 - 3"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/d3-timer": {
|
||||||
|
"version": "3.0.1",
|
||||||
|
"resolved": "https://registry.npmjs.org/d3-timer/-/d3-timer-3.0.1.tgz",
|
||||||
|
"integrity": "sha512-ndfJ/JxxMd3nw31uyKoY2naivF+r29V+Lc0svZxe1JvvIRmi8hUsrMvdOwgS1o6uBHmiz91geQ0ylPP0aj1VUA==",
|
||||||
|
"license": "ISC",
|
||||||
|
"engines": {
|
||||||
|
"node": ">=12"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/damerau-levenshtein": {
|
"node_modules/damerau-levenshtein": {
|
||||||
"version": "1.0.8",
|
"version": "1.0.8",
|
||||||
"resolved": "https://registry.npmjs.org/damerau-levenshtein/-/damerau-levenshtein-1.0.8.tgz",
|
"resolved": "https://registry.npmjs.org/damerau-levenshtein/-/damerau-levenshtein-1.0.8.tgz",
|
||||||
@@ -6851,6 +7090,12 @@
|
|||||||
"integrity": "sha512-YpgQiITW3JXGntzdUmyUR1V812Hn8T1YVXhCu+wO3OpS4eU9l4YdD3qjyiKdV6mvV29zapkMeD390UVEf2lkUg==",
|
"integrity": "sha512-YpgQiITW3JXGntzdUmyUR1V812Hn8T1YVXhCu+wO3OpS4eU9l4YdD3qjyiKdV6mvV29zapkMeD390UVEf2lkUg==",
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
|
"node_modules/decimal.js-light": {
|
||||||
|
"version": "2.5.1",
|
||||||
|
"resolved": "https://registry.npmjs.org/decimal.js-light/-/decimal.js-light-2.5.1.tgz",
|
||||||
|
"integrity": "sha512-qIMFpTMZmny+MMIitAB6D7iVPEorVw6YQRWkvarTkT4tBeSLLiHzcwj6q0MmYSFCiVpiqPJTJEYIrpcPzVEIvg==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
"node_modules/dedent": {
|
"node_modules/dedent": {
|
||||||
"version": "0.7.0",
|
"version": "0.7.0",
|
||||||
"resolved": "https://registry.npmjs.org/dedent/-/dedent-0.7.0.tgz",
|
"resolved": "https://registry.npmjs.org/dedent/-/dedent-0.7.0.tgz",
|
||||||
@@ -7484,6 +7729,16 @@
|
|||||||
"url": "https://github.com/sponsors/ljharb"
|
"url": "https://github.com/sponsors/ljharb"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/es-toolkit": {
|
||||||
|
"version": "1.44.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/es-toolkit/-/es-toolkit-1.44.0.tgz",
|
||||||
|
"integrity": "sha512-6penXeZalaV88MM3cGkFZZfOoLGWshWWfdy0tWw/RlVVyhvMaWSBTOvXNeiW3e5FwdS5ePW0LGEu17zT139ktg==",
|
||||||
|
"license": "MIT",
|
||||||
|
"workspaces": [
|
||||||
|
"docs",
|
||||||
|
"benchmarks"
|
||||||
|
]
|
||||||
|
},
|
||||||
"node_modules/escalade": {
|
"node_modules/escalade": {
|
||||||
"version": "3.2.0",
|
"version": "3.2.0",
|
||||||
"resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz",
|
"resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz",
|
||||||
@@ -9645,6 +9900,15 @@
|
|||||||
"node": ">= 0.4"
|
"node": ">= 0.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/internmap": {
|
||||||
|
"version": "2.0.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/internmap/-/internmap-2.0.3.tgz",
|
||||||
|
"integrity": "sha512-5Hh7Y1wQbvY5ooGgPbDaL5iYLAPzMTUrjMulskHLH6wnv/A+1q5rgEaiuqEjB+oxGXIVZs1FF+R/KPN3ZSQYYg==",
|
||||||
|
"license": "ISC",
|
||||||
|
"engines": {
|
||||||
|
"node": ">=12"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/ipaddr.js": {
|
"node_modules/ipaddr.js": {
|
||||||
"version": "2.3.0",
|
"version": "2.3.0",
|
||||||
"resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-2.3.0.tgz",
|
"resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-2.3.0.tgz",
|
||||||
@@ -14239,6 +14503,29 @@
|
|||||||
"integrity": "sha512-W+EWGn2v0ApPKgKKCy/7s7WHXkboGcsrXE+2joLyVxkbyVQfO3MUEaUQDHoSmb8TFFrSKYa9mw64WZHNHSDzYA==",
|
"integrity": "sha512-W+EWGn2v0ApPKgKKCy/7s7WHXkboGcsrXE+2joLyVxkbyVQfO3MUEaUQDHoSmb8TFFrSKYa9mw64WZHNHSDzYA==",
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
|
"node_modules/react-redux": {
|
||||||
|
"version": "9.2.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/react-redux/-/react-redux-9.2.0.tgz",
|
||||||
|
"integrity": "sha512-ROY9fvHhwOD9ySfrF0wmvu//bKCQ6AeZZq1nJNtbDC+kk5DuSuNX/n6YWYF/SYy7bSba4D4FSz8DJeKY/S/r+g==",
|
||||||
|
"license": "MIT",
|
||||||
|
"dependencies": {
|
||||||
|
"@types/use-sync-external-store": "^0.0.6",
|
||||||
|
"use-sync-external-store": "^1.4.0"
|
||||||
|
},
|
||||||
|
"peerDependencies": {
|
||||||
|
"@types/react": "^18.2.25 || ^19",
|
||||||
|
"react": "^18.0 || ^19",
|
||||||
|
"redux": "^5.0.0"
|
||||||
|
},
|
||||||
|
"peerDependenciesMeta": {
|
||||||
|
"@types/react": {
|
||||||
|
"optional": true
|
||||||
|
},
|
||||||
|
"redux": {
|
||||||
|
"optional": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/react-refresh": {
|
"node_modules/react-refresh": {
|
||||||
"version": "0.11.0",
|
"version": "0.11.0",
|
||||||
"resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.11.0.tgz",
|
"resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.11.0.tgz",
|
||||||
@@ -14410,6 +14697,52 @@
|
|||||||
"node": ">=8.10.0"
|
"node": ">=8.10.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/recharts": {
|
||||||
|
"version": "3.7.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/recharts/-/recharts-3.7.0.tgz",
|
||||||
|
"integrity": "sha512-l2VCsy3XXeraxIID9fx23eCb6iCBsxUQDnE8tWm6DFdszVAO7WVY/ChAD9wVit01y6B2PMupYiMmQwhgPHc9Ew==",
|
||||||
|
"license": "MIT",
|
||||||
|
"workspaces": [
|
||||||
|
"www"
|
||||||
|
],
|
||||||
|
"dependencies": {
|
||||||
|
"@reduxjs/toolkit": "1.x.x || 2.x.x",
|
||||||
|
"clsx": "^2.1.1",
|
||||||
|
"decimal.js-light": "^2.5.1",
|
||||||
|
"es-toolkit": "^1.39.3",
|
||||||
|
"eventemitter3": "^5.0.1",
|
||||||
|
"immer": "^10.1.1",
|
||||||
|
"react-redux": "8.x.x || 9.x.x",
|
||||||
|
"reselect": "5.1.1",
|
||||||
|
"tiny-invariant": "^1.3.3",
|
||||||
|
"use-sync-external-store": "^1.2.2",
|
||||||
|
"victory-vendor": "^37.0.2"
|
||||||
|
},
|
||||||
|
"engines": {
|
||||||
|
"node": ">=18"
|
||||||
|
},
|
||||||
|
"peerDependencies": {
|
||||||
|
"react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0",
|
||||||
|
"react-dom": "^16.0.0 || ^17.0.0 || ^18.0.0 || ^19.0.0",
|
||||||
|
"react-is": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/recharts/node_modules/eventemitter3": {
|
||||||
|
"version": "5.0.4",
|
||||||
|
"resolved": "https://registry.npmjs.org/eventemitter3/-/eventemitter3-5.0.4.tgz",
|
||||||
|
"integrity": "sha512-mlsTRyGaPBjPedk6Bvw+aqbsXDtoAyAzm5MO7JgU+yVRyMQ5O8bD4Kcci7BS85f93veegeCPkL8R4GLClnjLFw==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
|
"node_modules/recharts/node_modules/immer": {
|
||||||
|
"version": "10.2.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/immer/-/immer-10.2.0.tgz",
|
||||||
|
"integrity": "sha512-d/+XTN3zfODyjr89gM3mPq1WNX2B8pYsu7eORitdwyA2sBubnTl3laYlBk4sXY5FUa5qTZGBDPJICVbvqzjlbw==",
|
||||||
|
"license": "MIT",
|
||||||
|
"funding": {
|
||||||
|
"type": "opencollective",
|
||||||
|
"url": "https://opencollective.com/immer"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/recursive-readdir": {
|
"node_modules/recursive-readdir": {
|
||||||
"version": "2.2.3",
|
"version": "2.2.3",
|
||||||
"resolved": "https://registry.npmjs.org/recursive-readdir/-/recursive-readdir-2.2.3.tgz",
|
"resolved": "https://registry.npmjs.org/recursive-readdir/-/recursive-readdir-2.2.3.tgz",
|
||||||
@@ -14422,6 +14755,21 @@
|
|||||||
"node": ">=6.0.0"
|
"node": ">=6.0.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/redux": {
|
||||||
|
"version": "5.0.1",
|
||||||
|
"resolved": "https://registry.npmjs.org/redux/-/redux-5.0.1.tgz",
|
||||||
|
"integrity": "sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
|
"node_modules/redux-thunk": {
|
||||||
|
"version": "3.1.0",
|
||||||
|
"resolved": "https://registry.npmjs.org/redux-thunk/-/redux-thunk-3.1.0.tgz",
|
||||||
|
"integrity": "sha512-NW2r5T6ksUKXCabzhL9z+h206HQw/NJkcLm1GPImRQ8IzfXwRGqjVhKJGauHirT0DAuyy6hjdnMZaRoAcy0Klw==",
|
||||||
|
"license": "MIT",
|
||||||
|
"peerDependencies": {
|
||||||
|
"redux": "^5.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/reflect.getprototypeof": {
|
"node_modules/reflect.getprototypeof": {
|
||||||
"version": "1.0.10",
|
"version": "1.0.10",
|
||||||
"resolved": "https://registry.npmjs.org/reflect.getprototypeof/-/reflect.getprototypeof-1.0.10.tgz",
|
"resolved": "https://registry.npmjs.org/reflect.getprototypeof/-/reflect.getprototypeof-1.0.10.tgz",
|
||||||
@@ -16329,6 +16677,12 @@
|
|||||||
"integrity": "sha512-eHY7nBftgThBqOyHGVN+l8gF0BucP09fMo0oO/Lb0w1OF80dJv+lDVpXG60WMQvkcxAkNybKsrEIE3ZtKGmPrA==",
|
"integrity": "sha512-eHY7nBftgThBqOyHGVN+l8gF0BucP09fMo0oO/Lb0w1OF80dJv+lDVpXG60WMQvkcxAkNybKsrEIE3ZtKGmPrA==",
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
|
"node_modules/tiny-invariant": {
|
||||||
|
"version": "1.3.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.3.tgz",
|
||||||
|
"integrity": "sha512-+FbBPE1o9QAYvviau/qC5SE3caw21q3xkvWKBtja5vgqOWIHHJ3ioaq1VPfn/Szqctz2bU/oYeKd9/z5BL+PVg==",
|
||||||
|
"license": "MIT"
|
||||||
|
},
|
||||||
"node_modules/tinyglobby": {
|
"node_modules/tinyglobby": {
|
||||||
"version": "0.2.15",
|
"version": "0.2.15",
|
||||||
"resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz",
|
"resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz",
|
||||||
@@ -16902,6 +17256,28 @@
|
|||||||
"node": ">= 0.8"
|
"node": ">= 0.8"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/victory-vendor": {
|
||||||
|
"version": "37.3.6",
|
||||||
|
"resolved": "https://registry.npmjs.org/victory-vendor/-/victory-vendor-37.3.6.tgz",
|
||||||
|
"integrity": "sha512-SbPDPdDBYp+5MJHhBCAyI7wKM3d5ivekigc2Dk2s7pgbZ9wIgIBYGVw4zGHBml/qTFbexrofXW6Gu4noGxrOwQ==",
|
||||||
|
"license": "MIT AND ISC",
|
||||||
|
"dependencies": {
|
||||||
|
"@types/d3-array": "^3.0.3",
|
||||||
|
"@types/d3-ease": "^3.0.0",
|
||||||
|
"@types/d3-interpolate": "^3.0.1",
|
||||||
|
"@types/d3-scale": "^4.0.2",
|
||||||
|
"@types/d3-shape": "^3.1.0",
|
||||||
|
"@types/d3-time": "^3.0.0",
|
||||||
|
"@types/d3-timer": "^3.0.0",
|
||||||
|
"d3-array": "^3.1.6",
|
||||||
|
"d3-ease": "^3.0.1",
|
||||||
|
"d3-interpolate": "^3.0.1",
|
||||||
|
"d3-scale": "^4.0.2",
|
||||||
|
"d3-shape": "^3.1.0",
|
||||||
|
"d3-time": "^3.0.0",
|
||||||
|
"d3-timer": "^3.0.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/w3c-hr-time": {
|
"node_modules/w3c-hr-time": {
|
||||||
"version": "1.0.2",
|
"version": "1.0.2",
|
||||||
"resolved": "https://registry.npmjs.org/w3c-hr-time/-/w3c-hr-time-1.0.2.tgz",
|
"resolved": "https://registry.npmjs.org/w3c-hr-time/-/w3c-hr-time-1.0.2.tgz",
|
||||||
|
|||||||
@@ -13,7 +13,8 @@
|
|||||||
"react": "^18.2.0",
|
"react": "^18.2.0",
|
||||||
"react-dom": "^18.2.0",
|
"react-dom": "^18.2.0",
|
||||||
"react-router-dom": "^7.13.0",
|
"react-router-dom": "^7.13.0",
|
||||||
"react-scripts": "5.0.1"
|
"react-scripts": "5.0.1",
|
||||||
|
"recharts": "^3.7.0"
|
||||||
},
|
},
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"start": "react-scripts start",
|
"start": "react-scripts start",
|
||||||
|
|||||||
@@ -2,10 +2,11 @@
|
|||||||
* ThreatHunt MUI-powered analyst-assist platform.
|
* ThreatHunt MUI-powered analyst-assist platform.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import React, { useState, useCallback } from 'react';
|
import React, { useState, useCallback, Suspense } from 'react';
|
||||||
import { BrowserRouter, Routes, Route, useNavigate, useLocation } from 'react-router-dom';
|
import { BrowserRouter, Routes, Route, useNavigate, useLocation } from 'react-router-dom';
|
||||||
import { ThemeProvider, CssBaseline, Box, AppBar, Toolbar, Typography, IconButton,
|
import { ThemeProvider, CssBaseline, Box, AppBar, Toolbar, Typography, IconButton,
|
||||||
Drawer, List, ListItemButton, ListItemIcon, ListItemText, Divider, Chip } from '@mui/material';
|
Drawer, List, ListItemButton, ListItemIcon, ListItemText, Divider, Chip,
|
||||||
|
CircularProgress } from '@mui/material';
|
||||||
import MenuIcon from '@mui/icons-material/Menu';
|
import MenuIcon from '@mui/icons-material/Menu';
|
||||||
import DashboardIcon from '@mui/icons-material/Dashboard';
|
import DashboardIcon from '@mui/icons-material/Dashboard';
|
||||||
import SearchIcon from '@mui/icons-material/Search';
|
import SearchIcon from '@mui/icons-material/Search';
|
||||||
@@ -19,9 +20,14 @@ import CompareArrowsIcon from '@mui/icons-material/CompareArrows';
|
|||||||
import GppMaybeIcon from '@mui/icons-material/GppMaybe';
|
import GppMaybeIcon from '@mui/icons-material/GppMaybe';
|
||||||
import HubIcon from '@mui/icons-material/Hub';
|
import HubIcon from '@mui/icons-material/Hub';
|
||||||
import AssessmentIcon from '@mui/icons-material/Assessment';
|
import AssessmentIcon from '@mui/icons-material/Assessment';
|
||||||
|
import TimelineIcon from '@mui/icons-material/Timeline';
|
||||||
|
import PlaylistAddCheckIcon from '@mui/icons-material/PlaylistAddCheck';
|
||||||
|
import BookmarksIcon from '@mui/icons-material/Bookmarks';
|
||||||
|
import ShieldIcon from '@mui/icons-material/Shield';
|
||||||
import { SnackbarProvider } from 'notistack';
|
import { SnackbarProvider } from 'notistack';
|
||||||
import theme from './theme';
|
import theme from './theme';
|
||||||
|
|
||||||
|
/* -- Eager imports (lightweight, always needed) -- */
|
||||||
import Dashboard from './components/Dashboard';
|
import Dashboard from './components/Dashboard';
|
||||||
import HuntManager from './components/HuntManager';
|
import HuntManager from './components/HuntManager';
|
||||||
import DatasetViewer from './components/DatasetViewer';
|
import DatasetViewer from './components/DatasetViewer';
|
||||||
@@ -32,8 +38,14 @@ import AnnotationPanel from './components/AnnotationPanel';
|
|||||||
import HypothesisTracker from './components/HypothesisTracker';
|
import HypothesisTracker from './components/HypothesisTracker';
|
||||||
import CorrelationView from './components/CorrelationView';
|
import CorrelationView from './components/CorrelationView';
|
||||||
import AUPScanner from './components/AUPScanner';
|
import AUPScanner from './components/AUPScanner';
|
||||||
import NetworkMap from './components/NetworkMap';
|
|
||||||
import AnalysisDashboard from './components/AnalysisDashboard';
|
/* -- Lazy imports (heavy: charts, network graph, new feature pages) -- */
|
||||||
|
const NetworkMap = React.lazy(() => import('./components/NetworkMap'));
|
||||||
|
const AnalysisDashboard = React.lazy(() => import('./components/AnalysisDashboard'));
|
||||||
|
const MitreMatrix = React.lazy(() => import('./components/MitreMatrix'));
|
||||||
|
const TimelineView = React.lazy(() => import('./components/TimelineView'));
|
||||||
|
const PlaybookManager = React.lazy(() => import('./components/PlaybookManager'));
|
||||||
|
const SavedSearches = React.lazy(() => import('./components/SavedSearches'));
|
||||||
|
|
||||||
const DRAWER_WIDTH = 240;
|
const DRAWER_WIDTH = 240;
|
||||||
|
|
||||||
@@ -52,8 +64,20 @@ const NAV: NavItem[] = [
|
|||||||
{ label: 'Correlation', path: '/correlation', icon: <CompareArrowsIcon /> },
|
{ label: 'Correlation', path: '/correlation', icon: <CompareArrowsIcon /> },
|
||||||
{ label: 'Network Map', path: '/network', icon: <HubIcon /> },
|
{ label: 'Network Map', path: '/network', icon: <HubIcon /> },
|
||||||
{ label: 'AUP Scanner', path: '/aup', icon: <GppMaybeIcon /> },
|
{ label: 'AUP Scanner', path: '/aup', icon: <GppMaybeIcon /> },
|
||||||
|
{ label: 'MITRE Matrix', path: '/mitre', icon: <ShieldIcon /> },
|
||||||
|
{ label: 'Timeline', path: '/timeline', icon: <TimelineIcon /> },
|
||||||
|
{ label: 'Playbooks', path: '/playbooks', icon: <PlaylistAddCheckIcon /> },
|
||||||
|
{ label: 'Saved Searches', path: '/saved-searches', icon: <BookmarksIcon /> },
|
||||||
];
|
];
|
||||||
|
|
||||||
|
function LazyFallback() {
|
||||||
|
return (
|
||||||
|
<Box sx={{ display: 'flex', justifyContent: 'center', alignItems: 'center', minHeight: 200 }}>
|
||||||
|
<CircularProgress />
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
function Shell() {
|
function Shell() {
|
||||||
const [open, setOpen] = useState(true);
|
const [open, setOpen] = useState(true);
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
@@ -72,7 +96,7 @@ function Shell() {
|
|||||||
<Typography variant="h6" noWrap sx={{ flexGrow: 1 }}>
|
<Typography variant="h6" noWrap sx={{ flexGrow: 1 }}>
|
||||||
ThreatHunt
|
ThreatHunt
|
||||||
</Typography>
|
</Typography>
|
||||||
<Chip label="v0.3.0" size="small" color="primary" variant="outlined" />
|
<Chip label="v0.4.0" size="small" color="primary" variant="outlined" />
|
||||||
</Toolbar>
|
</Toolbar>
|
||||||
</AppBar>
|
</AppBar>
|
||||||
|
|
||||||
@@ -107,6 +131,7 @@ function Shell() {
|
|||||||
ml: open ? 0 : `-${DRAWER_WIDTH}px`,
|
ml: open ? 0 : `-${DRAWER_WIDTH}px`,
|
||||||
transition: 'margin 225ms cubic-bezier(0,0,0.2,1)',
|
transition: 'margin 225ms cubic-bezier(0,0,0.2,1)',
|
||||||
}}>
|
}}>
|
||||||
|
<Suspense fallback={<LazyFallback />}>
|
||||||
<Routes>
|
<Routes>
|
||||||
<Route path="/" element={<Dashboard />} />
|
<Route path="/" element={<Dashboard />} />
|
||||||
<Route path="/hunts" element={<HuntManager />} />
|
<Route path="/hunts" element={<HuntManager />} />
|
||||||
@@ -120,7 +145,12 @@ function Shell() {
|
|||||||
<Route path="/correlation" element={<CorrelationView />} />
|
<Route path="/correlation" element={<CorrelationView />} />
|
||||||
<Route path="/network" element={<NetworkMap />} />
|
<Route path="/network" element={<NetworkMap />} />
|
||||||
<Route path="/aup" element={<AUPScanner />} />
|
<Route path="/aup" element={<AUPScanner />} />
|
||||||
|
<Route path="/mitre" element={<MitreMatrix />} />
|
||||||
|
<Route path="/timeline" element={<TimelineView />} />
|
||||||
|
<Route path="/playbooks" element={<PlaybookManager />} />
|
||||||
|
<Route path="/saved-searches" element={<SavedSearches />} />
|
||||||
</Routes>
|
</Routes>
|
||||||
|
</Suspense>
|
||||||
</Box>
|
</Box>
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -71,6 +71,20 @@ export interface Hunt {
|
|||||||
dataset_count: number; hypothesis_count: number;
|
dataset_count: number; hypothesis_count: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface HuntProgress {
|
||||||
|
hunt_id: string;
|
||||||
|
status: 'idle' | 'processing' | 'ready';
|
||||||
|
progress_percent: number;
|
||||||
|
dataset_total: number;
|
||||||
|
dataset_completed: number;
|
||||||
|
dataset_processing: number;
|
||||||
|
dataset_errors: number;
|
||||||
|
active_jobs: number;
|
||||||
|
queued_jobs: number;
|
||||||
|
network_status: 'none' | 'building' | 'ready';
|
||||||
|
stages: Record<string, any>;
|
||||||
|
}
|
||||||
|
|
||||||
export const hunts = {
|
export const hunts = {
|
||||||
list: (skip = 0, limit = 50) =>
|
list: (skip = 0, limit = 50) =>
|
||||||
api<{ hunts: Hunt[]; total: number }>(`/api/hunts?skip=${skip}&limit=${limit}`),
|
api<{ hunts: Hunt[]; total: number }>(`/api/hunts?skip=${skip}&limit=${limit}`),
|
||||||
@@ -80,6 +94,7 @@ export const hunts = {
|
|||||||
update: (id: string, data: Partial<{ name: string; description: string; status: string }>) =>
|
update: (id: string, data: Partial<{ name: string; description: string; status: string }>) =>
|
||||||
api<Hunt>(`/api/hunts/${id}`, { method: 'PUT', body: JSON.stringify(data) }),
|
api<Hunt>(`/api/hunts/${id}`, { method: 'PUT', body: JSON.stringify(data) }),
|
||||||
delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),
|
delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),
|
||||||
|
progress: (id: string) => api<HuntProgress>(`/api/hunts/${id}/progress`),
|
||||||
};
|
};
|
||||||
|
|
||||||
// -- Datasets --
|
// -- Datasets --
|
||||||
@@ -166,6 +181,8 @@ export interface AssistRequest {
|
|||||||
active_hypotheses?: string[]; annotations_summary?: string;
|
active_hypotheses?: string[]; annotations_summary?: string;
|
||||||
enrichment_summary?: string; mode?: 'quick' | 'deep' | 'debate';
|
enrichment_summary?: string; mode?: 'quick' | 'deep' | 'debate';
|
||||||
model_override?: string; conversation_id?: string; hunt_id?: string;
|
model_override?: string; conversation_id?: string; hunt_id?: string;
|
||||||
|
execution_preference?: 'auto' | 'force' | 'off';
|
||||||
|
learning_mode?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface AssistResponse {
|
export interface AssistResponse {
|
||||||
@@ -174,6 +191,15 @@ export interface AssistResponse {
|
|||||||
sans_references: string[]; model_used: string; node_used: string;
|
sans_references: string[]; model_used: string; node_used: string;
|
||||||
latency_ms: number; perspectives: Record<string, any>[] | null;
|
latency_ms: number; perspectives: Record<string, any>[] | null;
|
||||||
conversation_id: string | null;
|
conversation_id: string | null;
|
||||||
|
execution?: {
|
||||||
|
scope: string;
|
||||||
|
datasets_scanned: string[];
|
||||||
|
policy_hits: number;
|
||||||
|
result_count: number;
|
||||||
|
top_domains: string[];
|
||||||
|
top_user_hosts: string[];
|
||||||
|
elapsed_ms: number;
|
||||||
|
} | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface NodeInfo { url: string; available: boolean }
|
export interface NodeInfo { url: string; available: boolean }
|
||||||
@@ -326,10 +352,12 @@ export interface ScanHit {
|
|||||||
theme_name: string; theme_color: string; keyword: string;
|
theme_name: string; theme_color: string; keyword: string;
|
||||||
source_type: string; source_id: string | number; field: string;
|
source_type: string; source_id: string | number; field: string;
|
||||||
matched_value: string; row_index: number | null; dataset_name: string | null;
|
matched_value: string; row_index: number | null; dataset_name: string | null;
|
||||||
|
hostname?: string | null; username?: string | null;
|
||||||
}
|
}
|
||||||
export interface ScanResponse {
|
export interface ScanResponse {
|
||||||
total_hits: number; hits: ScanHit[]; themes_scanned: number;
|
total_hits: number; hits: ScanHit[]; themes_scanned: number;
|
||||||
keywords_scanned: number; rows_scanned: number;
|
keywords_scanned: number; rows_scanned: number;
|
||||||
|
cache_used?: boolean; cache_status?: string; cached_at?: string | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const keywords = {
|
export const keywords = {
|
||||||
@@ -363,6 +391,7 @@ export const keywords = {
|
|||||||
scan: (opts: {
|
scan: (opts: {
|
||||||
dataset_ids?: string[]; theme_ids?: string[];
|
dataset_ids?: string[]; theme_ids?: string[];
|
||||||
scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;
|
scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;
|
||||||
|
prefer_cache?: boolean; force_rescan?: boolean;
|
||||||
}) =>
|
}) =>
|
||||||
api<ScanResponse>('/api/keywords/scan', {
|
api<ScanResponse>('/api/keywords/scan', {
|
||||||
method: 'POST', body: JSON.stringify(opts),
|
method: 'POST', body: JSON.stringify(opts),
|
||||||
@@ -579,7 +608,213 @@ export interface HostInventory {
|
|||||||
stats: InventoryStats;
|
stats: InventoryStats;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface InventoryStatus {
|
||||||
|
hunt_id: string;
|
||||||
|
status: 'ready' | 'building' | 'none';
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface NetworkSummaryHost {
|
||||||
|
id: string;
|
||||||
|
hostname: string;
|
||||||
|
row_count: number;
|
||||||
|
ip_count: number;
|
||||||
|
user_count: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface NetworkSummary {
|
||||||
|
stats: InventoryStats;
|
||||||
|
top_hosts: NetworkSummaryHost[];
|
||||||
|
top_edges: InventoryConnection[];
|
||||||
|
status?: 'building' | 'deferred';
|
||||||
|
message?: string;
|
||||||
|
}
|
||||||
|
|
||||||
export const network = {
|
export const network = {
|
||||||
hostInventory: (huntId: string) =>
|
hostInventory: (huntId: string, force = false) =>
|
||||||
api<HostInventory>(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}`),
|
api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`),
|
||||||
|
summary: (huntId: string, topN = 20) =>
|
||||||
|
api<NetworkSummary | { status: 'building' | 'deferred'; message?: string }>(`/api/network/summary?hunt_id=${encodeURIComponent(huntId)}&top_n=${topN}`),
|
||||||
|
subgraph: (huntId: string, maxHosts = 250, maxEdges = 1500, nodeId?: string) => {
|
||||||
|
let qs = `/api/network/subgraph?hunt_id=${encodeURIComponent(huntId)}&max_hosts=${maxHosts}&max_edges=${maxEdges}`;
|
||||||
|
if (nodeId) qs += `&node_id=${encodeURIComponent(nodeId)}`;
|
||||||
|
return api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(qs);
|
||||||
|
},
|
||||||
|
inventoryStatus: (huntId: string) =>
|
||||||
|
api<InventoryStatus>(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`),
|
||||||
|
rebuildInventory: (huntId: string) =>
|
||||||
|
api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// -- MITRE ATT&CK Coverage (Feature 1) --
|
||||||
|
|
||||||
|
export interface MitreTechnique {
|
||||||
|
id: string;
|
||||||
|
tactic: string;
|
||||||
|
sources: { type: string; risk_score?: number; hostname?: string; title?: string }[];
|
||||||
|
count: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface MitreCoverage {
|
||||||
|
tactics: string[];
|
||||||
|
technique_count: number;
|
||||||
|
detection_count: number;
|
||||||
|
tactic_coverage: Record<string, { techniques: MitreTechnique[]; count: number }>;
|
||||||
|
all_techniques: MitreTechnique[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export const mitre = {
|
||||||
|
coverage: (huntId?: string) => {
|
||||||
|
const q = huntId ? `?hunt_id=${encodeURIComponent(huntId)}` : '';
|
||||||
|
return api<MitreCoverage>(`/api/mitre/coverage${q}`);
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// -- Timeline (Feature 2) --
|
||||||
|
|
||||||
|
export interface TimelineEvent {
|
||||||
|
timestamp: string;
|
||||||
|
dataset_id: string;
|
||||||
|
dataset_name: string;
|
||||||
|
artifact_type: string;
|
||||||
|
row_index: number;
|
||||||
|
hostname: string;
|
||||||
|
process: string;
|
||||||
|
summary: string;
|
||||||
|
data: Record<string, string>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface TimelineData {
|
||||||
|
hunt_id: string;
|
||||||
|
hunt_name: string;
|
||||||
|
event_count: number;
|
||||||
|
datasets: { id: string; name: string; artifact_type: string; row_count: number }[];
|
||||||
|
events: TimelineEvent[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export const timeline = {
|
||||||
|
getHuntTimeline: (huntId: string, limit = 2000) =>
|
||||||
|
api<TimelineData>(`/api/timeline/hunt/${huntId}?limit=${limit}`),
|
||||||
|
};
|
||||||
|
|
||||||
|
// -- Playbooks (Feature 3) --
|
||||||
|
|
||||||
|
export interface PlaybookStep {
|
||||||
|
id: number;
|
||||||
|
order_index: number;
|
||||||
|
title: string;
|
||||||
|
description: string | null;
|
||||||
|
step_type: string;
|
||||||
|
target_route: string | null;
|
||||||
|
is_completed: boolean;
|
||||||
|
completed_at: string | null;
|
||||||
|
notes: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PlaybookSummary {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
description: string | null;
|
||||||
|
is_template: boolean;
|
||||||
|
hunt_id: string | null;
|
||||||
|
status: string;
|
||||||
|
total_steps: number;
|
||||||
|
completed_steps: number;
|
||||||
|
created_at: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PlaybookDetail {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
description: string | null;
|
||||||
|
is_template: boolean;
|
||||||
|
hunt_id: string | null;
|
||||||
|
status: string;
|
||||||
|
created_at: string | null;
|
||||||
|
steps: PlaybookStep[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PlaybookTemplate {
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
steps: { title: string; description: string; step_type: string; target_route: string }[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export const playbooks = {
|
||||||
|
list: (huntId?: string) => {
|
||||||
|
const q = huntId ? `?hunt_id=${encodeURIComponent(huntId)}` : '';
|
||||||
|
return api<{ playbooks: PlaybookSummary[] }>(`/api/playbooks${q}`);
|
||||||
|
},
|
||||||
|
templates: () => api<{ templates: PlaybookTemplate[] }>('/api/playbooks/templates'),
|
||||||
|
get: (id: string) => api<PlaybookDetail>(`/api/playbooks/${id}`),
|
||||||
|
create: (data: { name: string; description?: string; hunt_id?: string; is_template?: boolean; steps?: { title: string; description?: string; step_type?: string; target_route?: string }[] }) =>
|
||||||
|
api<PlaybookDetail>('/api/playbooks', { method: 'POST', body: JSON.stringify(data) }),
|
||||||
|
update: (id: string, data: { name?: string; description?: string; status?: string }) =>
|
||||||
|
api(`/api/playbooks/${id}`, { method: 'PUT', body: JSON.stringify(data) }),
|
||||||
|
delete: (id: string) => api(`/api/playbooks/${id}`, { method: 'DELETE' }),
|
||||||
|
updateStep: (stepId: number, data: { is_completed?: boolean; notes?: string }) =>
|
||||||
|
api(`/api/playbooks/steps/${stepId}`, { method: 'PUT', body: JSON.stringify(data) }),
|
||||||
|
};
|
||||||
|
|
||||||
|
// -- Saved Searches (Feature 5) --
|
||||||
|
|
||||||
|
export interface SavedSearchData {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
description: string | null;
|
||||||
|
search_type: string;
|
||||||
|
query_params: Record<string, any>;
|
||||||
|
hunt_id?: string | null;
|
||||||
|
threshold: number | null;
|
||||||
|
last_run_at: string | null;
|
||||||
|
last_result_count: number | null;
|
||||||
|
created_at: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SearchRunResult {
|
||||||
|
search_id: string;
|
||||||
|
search_name: string;
|
||||||
|
search_type: string;
|
||||||
|
result_count: number;
|
||||||
|
previous_count: number;
|
||||||
|
delta: number;
|
||||||
|
results: any[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export const savedSearches = {
|
||||||
|
list: (searchType?: string) => {
|
||||||
|
const q = searchType ? `?search_type=${encodeURIComponent(searchType)}` : '';
|
||||||
|
return api<{ searches: SavedSearchData[] }>(`/api/searches${q}`);
|
||||||
|
},
|
||||||
|
get: (id: string) => api<SavedSearchData>(`/api/searches/${id}`),
|
||||||
|
create: (data: { name: string; description?: string; search_type: string; query_params: Record<string, any>; threshold?: number; hunt_id?: string }) =>
|
||||||
|
api<SavedSearchData>('/api/searches', { method: 'POST', body: JSON.stringify(data) }),
|
||||||
|
update: (id: string, data: { name?: string; description?: string; search_type?: string; query_params?: Record<string, any>; threshold?: number; hunt_id?: string }) =>
|
||||||
|
api(`/api/searches/${id}`, { method: 'PUT', body: JSON.stringify(data) }),
|
||||||
|
delete: (id: string) => api(`/api/searches/${id}`, { method: 'DELETE' }),
|
||||||
|
run: (id: string) => api<SearchRunResult>(`/api/searches/${id}/run`, { method: 'POST' }),
|
||||||
|
};
|
||||||
|
|
||||||
|
// -- STIX Export --
|
||||||
|
|
||||||
|
export const stixExport = {
|
||||||
|
/** Download a STIX 2.1 bundle JSON for a given hunt */
|
||||||
|
download: async (huntId: string): Promise<void> => {
|
||||||
|
const headers: Record<string, string> = {};
|
||||||
|
if (authToken) headers['Authorization'] = `Bearer ${authToken}`;
|
||||||
|
const res = await fetch(`${BASE}/api/export/stix/${huntId}`, { headers });
|
||||||
|
if (!res.ok) {
|
||||||
|
const body = await res.json().catch(() => ({}));
|
||||||
|
throw new Error(body.detail || `HTTP ${res.status}`);
|
||||||
|
}
|
||||||
|
const blob = await res.blob();
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = document.createElement('a');
|
||||||
|
a.href = url;
|
||||||
|
a.download = `hunt-${huntId}-stix-bundle.json`;
|
||||||
|
document.body.appendChild(a);
|
||||||
|
a.click();
|
||||||
|
document.body.removeChild(a);
|
||||||
|
URL.revokeObjectURL(url);
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -188,11 +188,13 @@ const RESULT_COLUMNS: GridColDef[] = [
|
|||||||
),
|
),
|
||||||
},
|
},
|
||||||
{ field: 'keyword', headerName: 'Keyword', width: 140 },
|
{ field: 'keyword', headerName: 'Keyword', width: 140 },
|
||||||
{ field: 'source_type', headerName: 'Source', width: 120 },
|
{ field: 'dataset_name', headerName: 'Dataset', width: 170 },
|
||||||
{ field: 'dataset_name', headerName: 'Dataset', width: 150 },
|
{ field: 'hostname', headerName: 'Hostname', width: 170, valueGetter: (v, row) => row.hostname || '' },
|
||||||
|
{ field: 'username', headerName: 'User', width: 160, valueGetter: (v, row) => row.username || '' },
|
||||||
|
{ field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 220 },
|
||||||
{ field: 'field', headerName: 'Field', width: 130 },
|
{ field: 'field', headerName: 'Field', width: 130 },
|
||||||
{ field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 200 },
|
{ field: 'source_type', headerName: 'Source', width: 120 },
|
||||||
{ field: 'row_index', headerName: 'Row #', width: 80, type: 'number' },
|
{ field: 'row_index', headerName: 'Row #', width: 90, type: 'number' },
|
||||||
];
|
];
|
||||||
|
|
||||||
export default function AUPScanner() {
|
export default function AUPScanner() {
|
||||||
@@ -210,9 +212,9 @@ export default function AUPScanner() {
|
|||||||
// Scan options
|
// Scan options
|
||||||
const [selectedDs, setSelectedDs] = useState<Set<string>>(new Set());
|
const [selectedDs, setSelectedDs] = useState<Set<string>>(new Set());
|
||||||
const [selectedThemes, setSelectedThemes] = useState<Set<string>>(new Set());
|
const [selectedThemes, setSelectedThemes] = useState<Set<string>>(new Set());
|
||||||
const [scanHunts, setScanHunts] = useState(true);
|
const [scanHunts, setScanHunts] = useState(false);
|
||||||
const [scanAnnotations, setScanAnnotations] = useState(true);
|
const [scanAnnotations, setScanAnnotations] = useState(false);
|
||||||
const [scanMessages, setScanMessages] = useState(true);
|
const [scanMessages, setScanMessages] = useState(false);
|
||||||
|
|
||||||
// Load themes + hunts
|
// Load themes + hunts
|
||||||
const loadData = useCallback(async () => {
|
const loadData = useCallback(async () => {
|
||||||
@@ -224,9 +226,13 @@ export default function AUPScanner() {
|
|||||||
]);
|
]);
|
||||||
setThemes(tRes.themes);
|
setThemes(tRes.themes);
|
||||||
setHuntList(hRes.hunts);
|
setHuntList(hRes.hunts);
|
||||||
|
if (!selectedHuntId && hRes.hunts.length > 0) {
|
||||||
|
const best = hRes.hunts.find(h => h.dataset_count > 0) || hRes.hunts[0];
|
||||||
|
setSelectedHuntId(best.id);
|
||||||
|
}
|
||||||
} catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
|
} catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}, [enqueueSnackbar]);
|
}, [enqueueSnackbar, selectedHuntId]);
|
||||||
|
|
||||||
useEffect(() => { loadData(); }, [loadData]);
|
useEffect(() => { loadData(); }, [loadData]);
|
||||||
|
|
||||||
@@ -237,7 +243,7 @@ export default function AUPScanner() {
|
|||||||
datasets.list(0, 500, selectedHuntId).then(res => {
|
datasets.list(0, 500, selectedHuntId).then(res => {
|
||||||
if (cancelled) return;
|
if (cancelled) return;
|
||||||
setDsList(res.datasets);
|
setDsList(res.datasets);
|
||||||
setSelectedDs(new Set(res.datasets.map(d => d.id)));
|
setSelectedDs(new Set(res.datasets.slice(0, 3).map(d => d.id)));
|
||||||
}).catch(() => {});
|
}).catch(() => {});
|
||||||
return () => { cancelled = true; };
|
return () => { cancelled = true; };
|
||||||
}, [selectedHuntId]);
|
}, [selectedHuntId]);
|
||||||
@@ -251,6 +257,15 @@ export default function AUPScanner() {
|
|||||||
|
|
||||||
// Run scan
|
// Run scan
|
||||||
const runScan = useCallback(async () => {
|
const runScan = useCallback(async () => {
|
||||||
|
if (!selectedHuntId) {
|
||||||
|
enqueueSnackbar('Please select a hunt before running AUP scan', { variant: 'warning' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (selectedDs.size === 0) {
|
||||||
|
enqueueSnackbar('No datasets selected for this hunt', { variant: 'warning' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
setScanning(true);
|
setScanning(true);
|
||||||
setScanResult(null);
|
setScanResult(null);
|
||||||
try {
|
try {
|
||||||
@@ -260,6 +275,7 @@ export default function AUPScanner() {
|
|||||||
scan_hunts: scanHunts,
|
scan_hunts: scanHunts,
|
||||||
scan_annotations: scanAnnotations,
|
scan_annotations: scanAnnotations,
|
||||||
scan_messages: scanMessages,
|
scan_messages: scanMessages,
|
||||||
|
prefer_cache: true,
|
||||||
});
|
});
|
||||||
setScanResult(res);
|
setScanResult(res);
|
||||||
enqueueSnackbar(`Scan complete — ${res.total_hits} hits found`, {
|
enqueueSnackbar(`Scan complete — ${res.total_hits} hits found`, {
|
||||||
@@ -267,7 +283,7 @@ export default function AUPScanner() {
|
|||||||
});
|
});
|
||||||
} catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
|
} catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
|
||||||
setScanning(false);
|
setScanning(false);
|
||||||
}, [selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);
|
}, [selectedHuntId, selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);
|
||||||
|
|
||||||
if (loading) return <Box sx={{ p: 4, textAlign: 'center' }}><CircularProgress /></Box>;
|
if (loading) return <Box sx={{ p: 4, textAlign: 'center' }}><CircularProgress /></Box>;
|
||||||
|
|
||||||
@@ -316,9 +332,38 @@ export default function AUPScanner() {
|
|||||||
)}
|
)}
|
||||||
{!selectedHuntId && (
|
{!selectedHuntId && (
|
||||||
<Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
|
<Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
|
||||||
All datasets will be scanned if no hunt is selected
|
Select a hunt to enable scoped scanning
|
||||||
</Typography>
|
</Typography>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
<FormControl size="small" fullWidth sx={{ mt: 1.2 }} disabled={!selectedHuntId || dsList.length === 0}>
|
||||||
|
<InputLabel id="aup-dataset-label">Datasets</InputLabel>
|
||||||
|
<Select
|
||||||
|
labelId="aup-dataset-label"
|
||||||
|
multiple
|
||||||
|
value={Array.from(selectedDs)}
|
||||||
|
label="Datasets"
|
||||||
|
renderValue={(selected) => `${(selected as string[]).length} selected`}
|
||||||
|
onChange={(e) => setSelectedDs(new Set(e.target.value as string[]))}
|
||||||
|
>
|
||||||
|
{dsList.map(d => (
|
||||||
|
<MenuItem key={d.id} value={d.id}>
|
||||||
|
<Checkbox size="small" checked={selectedDs.has(d.id)} />
|
||||||
|
<Typography variant="body2" sx={{ ml: 0.5 }}>
|
||||||
|
{d.name} ({d.row_count.toLocaleString()} rows)
|
||||||
|
</Typography>
|
||||||
|
</MenuItem>
|
||||||
|
))}
|
||||||
|
</Select>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
{selectedHuntId && dsList.length > 0 && (
|
||||||
|
<Stack direction="row" spacing={1} sx={{ mt: 1 }}>
|
||||||
|
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.slice(0, 3).map(d => d.id)))}>Top 3</Button>
|
||||||
|
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.map(d => d.id)))}>All</Button>
|
||||||
|
<Button size="small" onClick={() => setSelectedDs(new Set())}>Clear</Button>
|
||||||
|
</Stack>
|
||||||
|
)}
|
||||||
</Box>
|
</Box>
|
||||||
|
|
||||||
{/* Theme selector */}
|
{/* Theme selector */}
|
||||||
@@ -372,7 +417,7 @@ export default function AUPScanner() {
|
|||||||
<Button
|
<Button
|
||||||
variant="contained" color="warning" size="large"
|
variant="contained" color="warning" size="large"
|
||||||
startIcon={scanning ? <CircularProgress size={20} color="inherit" /> : <PlayArrowIcon />}
|
startIcon={scanning ? <CircularProgress size={20} color="inherit" /> : <PlayArrowIcon />}
|
||||||
onClick={runScan} disabled={scanning}
|
onClick={runScan} disabled={scanning || !selectedHuntId || selectedDs.size === 0}
|
||||||
>
|
>
|
||||||
{scanning ? 'Scanning…' : 'Run Scan'}
|
{scanning ? 'Scanning…' : 'Run Scan'}
|
||||||
</Button>
|
</Button>
|
||||||
@@ -392,6 +437,15 @@ export default function AUPScanner() {
|
|||||||
<strong>{scanResult.total_hits}</strong> hits across{' '}
|
<strong>{scanResult.total_hits}</strong> hits across{' '}
|
||||||
<strong>{scanResult.rows_scanned}</strong> rows |{' '}
|
<strong>{scanResult.rows_scanned}</strong> rows |{' '}
|
||||||
{scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned
|
{scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned
|
||||||
|
{scanResult.cache_status && (
|
||||||
|
<Chip
|
||||||
|
size="small"
|
||||||
|
label={scanResult.cache_status === 'hit' ? 'Cached' : 'Live'}
|
||||||
|
sx={{ ml: 1, height: 20 }}
|
||||||
|
color={scanResult.cache_status === 'hit' ? 'success' : 'default'}
|
||||||
|
variant="outlined"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</Alert>
|
</Alert>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/**
|
/**
|
||||||
* AgentPanel — analyst-assist chat with quick / deep / debate modes,
|
* AgentPanel - analyst-assist chat with quick / deep / debate modes,
|
||||||
* streaming support, SANS references, and conversation persistence.
|
* SSE streaming, SANS references, and conversation persistence.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import React, { useState, useRef, useEffect, useCallback } from 'react';
|
import React, { useState, useRef, useEffect, useCallback } from 'react';
|
||||||
@@ -8,7 +8,7 @@ import {
|
|||||||
Box, Typography, Paper, TextField, Button, Stack, Chip,
|
Box, Typography, Paper, TextField, Button, Stack, Chip,
|
||||||
ToggleButtonGroup, ToggleButton, CircularProgress, Alert,
|
ToggleButtonGroup, ToggleButton, CircularProgress, Alert,
|
||||||
Accordion, AccordionSummary, AccordionDetails, Divider, Select,
|
Accordion, AccordionSummary, AccordionDetails, Divider, Select,
|
||||||
MenuItem, FormControl, InputLabel, LinearProgress,
|
MenuItem, FormControl, InputLabel, LinearProgress, FormControlLabel, Switch,
|
||||||
} from '@mui/material';
|
} from '@mui/material';
|
||||||
import SendIcon from '@mui/icons-material/Send';
|
import SendIcon from '@mui/icons-material/Send';
|
||||||
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
|
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
|
||||||
@@ -16,26 +16,31 @@ import SchoolIcon from '@mui/icons-material/School';
|
|||||||
import PsychologyIcon from '@mui/icons-material/Psychology';
|
import PsychologyIcon from '@mui/icons-material/Psychology';
|
||||||
import ForumIcon from '@mui/icons-material/Forum';
|
import ForumIcon from '@mui/icons-material/Forum';
|
||||||
import SpeedIcon from '@mui/icons-material/Speed';
|
import SpeedIcon from '@mui/icons-material/Speed';
|
||||||
|
import StopIcon from '@mui/icons-material/Stop';
|
||||||
import { useSnackbar } from 'notistack';
|
import { useSnackbar } from 'notistack';
|
||||||
import {
|
import {
|
||||||
agent, datasets, hunts, type AssistRequest, type AssistResponse,
|
agent, datasets, hunts, type AssistRequest, type AssistResponse,
|
||||||
type DatasetSummary, type Hunt,
|
type DatasetSummary, type Hunt,
|
||||||
} from '../api/client';
|
} from '../api/client';
|
||||||
|
|
||||||
interface Message { role: 'user' | 'assistant'; content: string; meta?: AssistResponse }
|
interface Message { role: 'user' | 'assistant'; content: string; meta?: AssistResponse; streaming?: boolean }
|
||||||
|
|
||||||
export default function AgentPanel() {
|
export default function AgentPanel() {
|
||||||
const { enqueueSnackbar } = useSnackbar();
|
const { enqueueSnackbar } = useSnackbar();
|
||||||
const [messages, setMessages] = useState<Message[]>([]);
|
const [messages, setMessages] = useState<Message[]>([]);
|
||||||
const [query, setQuery] = useState('');
|
const [query, setQuery] = useState('');
|
||||||
const [mode, setMode] = useState<'quick' | 'deep' | 'debate'>('quick');
|
const [mode, setMode] = useState<'quick' | 'deep' | 'debate'>('quick');
|
||||||
|
const [executionPreference, setExecutionPreference] = useState<'auto' | 'force' | 'off'>('auto');
|
||||||
|
const [learningMode, setLearningMode] = useState(false);
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [streaming, setStreaming] = useState(false);
|
||||||
const [conversationId, setConversationId] = useState<string | null>(null);
|
const [conversationId, setConversationId] = useState<string | null>(null);
|
||||||
const [datasetList, setDatasets] = useState<DatasetSummary[]>([]);
|
const [datasetList, setDatasets] = useState<DatasetSummary[]>([]);
|
||||||
const [huntList, setHunts] = useState<Hunt[]>([]);
|
const [huntList, setHunts] = useState<Hunt[]>([]);
|
||||||
const [selectedDataset, setSelectedDataset] = useState('');
|
const [selectedDataset, setSelectedDataset] = useState('');
|
||||||
const [selectedHunt, setSelectedHunt] = useState('');
|
const [selectedHunt, setSelectedHunt] = useState('');
|
||||||
const bottomRef = useRef<HTMLDivElement>(null);
|
const bottomRef = useRef<HTMLDivElement>(null);
|
||||||
|
const abortRef = useRef<AbortController | null>(null);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
datasets.list(0, 100).then(r => setDatasets(r.datasets)).catch(() => {});
|
datasets.list(0, 100).then(r => setDatasets(r.datasets)).catch(() => {});
|
||||||
@@ -44,6 +49,12 @@ export default function AgentPanel() {
|
|||||||
|
|
||||||
useEffect(() => { bottomRef.current?.scrollIntoView({ behavior: 'smooth' }); }, [messages]);
|
useEffect(() => { bottomRef.current?.scrollIntoView({ behavior: 'smooth' }); }, [messages]);
|
||||||
|
|
||||||
|
const stopStreaming = () => {
|
||||||
|
abortRef.current?.abort();
|
||||||
|
setStreaming(false);
|
||||||
|
setLoading(false);
|
||||||
|
};
|
||||||
|
|
||||||
const send = useCallback(async () => {
|
const send = useCallback(async () => {
|
||||||
if (!query.trim() || loading) return;
|
if (!query.trim() || loading) return;
|
||||||
const userMsg: Message = { role: 'user', content: query };
|
const userMsg: Message = { role: 'user', content: query };
|
||||||
@@ -59,8 +70,93 @@ export default function AgentPanel() {
|
|||||||
hunt_id: selectedHunt || undefined,
|
hunt_id: selectedHunt || undefined,
|
||||||
dataset_name: ds?.name,
|
dataset_name: ds?.name,
|
||||||
data_summary: ds ? `${ds.row_count} rows, columns: ${Object.keys(ds.column_schema || {}).join(', ')}` : undefined,
|
data_summary: ds ? `${ds.row_count} rows, columns: ${Object.keys(ds.column_schema || {}).join(', ')}` : undefined,
|
||||||
|
execution_preference: executionPreference,
|
||||||
|
learning_mode: learningMode,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Try SSE streaming first, fall back to regular request
|
||||||
|
try {
|
||||||
|
const controller = new AbortController();
|
||||||
|
abortRef.current = controller;
|
||||||
|
setStreaming(true);
|
||||||
|
|
||||||
|
const res = await agent.assistStream(req);
|
||||||
|
if (!res.ok || !res.body) throw new Error('Stream unavailable');
|
||||||
|
|
||||||
|
setMessages(prev => [...prev, { role: 'assistant', content: '', streaming: true }]);
|
||||||
|
|
||||||
|
const reader = res.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let fullText = '';
|
||||||
|
let metaData: AssistResponse | undefined;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (controller.signal.aborted) break;
|
||||||
|
const { done, value } = await reader.read();
|
||||||
|
if (done) break;
|
||||||
|
|
||||||
|
const chunk = decoder.decode(value, { stream: true });
|
||||||
|
// Parse SSE lines
|
||||||
|
for (const line of chunk.split('\n')) {
|
||||||
|
if (line.startsWith('data: ')) {
|
||||||
|
const data = line.slice(6);
|
||||||
|
if (data === '[DONE]') continue;
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(data);
|
||||||
|
if (parsed.token) {
|
||||||
|
fullText += parsed.token;
|
||||||
|
const nextText = fullText;
|
||||||
|
setMessages(prev => {
|
||||||
|
const updated = [...prev];
|
||||||
|
const last = updated[updated.length - 1];
|
||||||
|
if (last?.role === 'assistant') {
|
||||||
|
updated[updated.length - 1] = { ...last, content: nextText };
|
||||||
|
}
|
||||||
|
return updated;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if (parsed.meta || parsed.confidence) {
|
||||||
|
metaData = parsed.meta || parsed;
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Non-JSON data line, treat as text token
|
||||||
|
fullText += data;
|
||||||
|
const nextText = fullText;
|
||||||
|
setMessages(prev => {
|
||||||
|
const updated = [...prev];
|
||||||
|
const last = updated[updated.length - 1];
|
||||||
|
if (last?.role === 'assistant') {
|
||||||
|
updated[updated.length - 1] = { ...last, content: nextText };
|
||||||
|
}
|
||||||
|
return updated;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finalize the streamed message
|
||||||
|
setMessages(prev => {
|
||||||
|
const updated = [...prev];
|
||||||
|
const last = updated[updated.length - 1];
|
||||||
|
if (last?.role === 'assistant') {
|
||||||
|
updated[updated.length - 1] = { ...last, content: fullText || 'No response received.', streaming: false, meta: metaData };
|
||||||
|
}
|
||||||
|
return updated;
|
||||||
|
});
|
||||||
|
if (metaData?.conversation_id) setConversationId(metaData.conversation_id);
|
||||||
|
|
||||||
|
} catch (streamErr: any) {
|
||||||
|
// Streaming failed or unavailable, fall back to regular request
|
||||||
|
setStreaming(false);
|
||||||
|
// Remove the empty streaming message if one was added
|
||||||
|
setMessages(prev => {
|
||||||
|
if (prev.length > 0 && prev[prev.length - 1].streaming && prev[prev.length - 1].content === '') {
|
||||||
|
return prev.slice(0, -1);
|
||||||
|
}
|
||||||
|
return prev;
|
||||||
|
});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const resp = await agent.assist(req);
|
const resp = await agent.assist(req);
|
||||||
setConversationId(resp.conversation_id || null);
|
setConversationId(resp.conversation_id || null);
|
||||||
@@ -69,8 +165,23 @@ export default function AgentPanel() {
|
|||||||
enqueueSnackbar(e.message, { variant: 'error' });
|
enqueueSnackbar(e.message, { variant: 'error' });
|
||||||
setMessages(prev => [...prev, { role: 'assistant', content: `Error: ${e.message}` }]);
|
setMessages(prev => [...prev, { role: 'assistant', content: `Error: ${e.message}` }]);
|
||||||
}
|
}
|
||||||
|
} finally {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}, [query, mode, loading, conversationId, selectedDataset, selectedHunt, datasetList, enqueueSnackbar]);
|
setStreaming(false);
|
||||||
|
abortRef.current = null;
|
||||||
|
}
|
||||||
|
}, [
|
||||||
|
query,
|
||||||
|
mode,
|
||||||
|
executionPreference,
|
||||||
|
learningMode,
|
||||||
|
loading,
|
||||||
|
conversationId,
|
||||||
|
selectedDataset,
|
||||||
|
selectedHunt,
|
||||||
|
datasetList,
|
||||||
|
enqueueSnackbar,
|
||||||
|
]);
|
||||||
|
|
||||||
const handleKeyDown = (e: React.KeyboardEvent) => {
|
const handleKeyDown = (e: React.KeyboardEvent) => {
|
||||||
if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); send(); }
|
if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); send(); }
|
||||||
@@ -112,6 +223,25 @@ export default function AgentPanel() {
|
|||||||
{huntList.map(h => <MenuItem key={h.id} value={h.id}>{h.name}</MenuItem>)}
|
{huntList.map(h => <MenuItem key={h.id} value={h.id}>{h.name}</MenuItem>)}
|
||||||
</Select>
|
</Select>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
|
|
||||||
|
<FormControl size="small" sx={{ minWidth: 180 }}>
|
||||||
|
<InputLabel>Execution</InputLabel>
|
||||||
|
<Select
|
||||||
|
label="Execution"
|
||||||
|
value={executionPreference}
|
||||||
|
onChange={e => setExecutionPreference(e.target.value as 'auto' | 'force' | 'off')}
|
||||||
|
>
|
||||||
|
<MenuItem value="auto">Auto</MenuItem>
|
||||||
|
<MenuItem value="force">Force execute</MenuItem>
|
||||||
|
<MenuItem value="off">Advisory only</MenuItem>
|
||||||
|
</Select>
|
||||||
|
</FormControl>
|
||||||
|
|
||||||
|
<FormControlLabel
|
||||||
|
control={<Switch checked={learningMode} onChange={(_, v) => setLearningMode(v)} size="small" />}
|
||||||
|
label={<Typography variant="caption">Learning mode</Typography>}
|
||||||
|
sx={{ ml: 0.5 }}
|
||||||
|
/>
|
||||||
</Stack>
|
</Stack>
|
||||||
</Paper>
|
</Paper>
|
||||||
|
|
||||||
@@ -124,7 +254,7 @@ export default function AgentPanel() {
|
|||||||
Ask a question about your threat hunt data.
|
Ask a question about your threat hunt data.
|
||||||
</Typography>
|
</Typography>
|
||||||
<Typography variant="caption" color="text.secondary">
|
<Typography variant="caption" color="text.secondary">
|
||||||
The agent provides advisory guidance — all decisions remain with the analyst.
|
Agent can provide advisory guidance or execute policy scans based on execution mode.
|
||||||
</Typography>
|
</Typography>
|
||||||
</Box>
|
</Box>
|
||||||
)}
|
)}
|
||||||
@@ -132,20 +262,24 @@ export default function AgentPanel() {
|
|||||||
<Box key={i} sx={{ mb: 2 }}>
|
<Box key={i} sx={{ mb: 2 }}>
|
||||||
<Typography variant="caption" color="text.secondary" fontWeight={700}>
|
<Typography variant="caption" color="text.secondary" fontWeight={700}>
|
||||||
{m.role === 'user' ? 'You' : 'Agent'}
|
{m.role === 'user' ? 'You' : 'Agent'}
|
||||||
|
{m.streaming && <Chip label="streaming" size="small" color="info" sx={{ ml: 1, height: 16, fontSize: '0.65rem' }} />}
|
||||||
</Typography>
|
</Typography>
|
||||||
<Paper sx={{
|
<Paper sx={{
|
||||||
p: 1.5, mt: 0.5,
|
p: 1.5, mt: 0.5,
|
||||||
bgcolor: m.role === 'user' ? 'primary.dark' : 'background.default',
|
bgcolor: m.role === 'user' ? 'primary.dark' : 'background.default',
|
||||||
borderColor: m.role === 'user' ? 'primary.main' : 'divider',
|
borderColor: m.role === 'user' ? 'primary.main' : 'divider',
|
||||||
}}>
|
}}>
|
||||||
<Typography variant="body2" sx={{ whiteSpace: 'pre-wrap' }}>{m.content}</Typography>
|
<Typography variant="body2" sx={{ whiteSpace: 'pre-wrap' }}>
|
||||||
|
{m.content}
|
||||||
|
{m.streaming && <span className="cursor-blink">|</span>}
|
||||||
|
</Typography>
|
||||||
</Paper>
|
</Paper>
|
||||||
|
|
||||||
{/* Response metadata */}
|
{/* Response metadata */}
|
||||||
{m.meta && (
|
{m.meta && (
|
||||||
<Box sx={{ mt: 0.5 }}>
|
<Box sx={{ mt: 0.5 }}>
|
||||||
<Stack direction="row" spacing={0.5} flexWrap="wrap" sx={{ mb: 0.5 }}>
|
<Stack direction="row" spacing={0.5} flexWrap="wrap" sx={{ mb: 0.5 }}>
|
||||||
<Chip label={`${m.meta.confidence * 100}% confidence`} size="small"
|
<Chip label={`${Math.round(m.meta.confidence * 100)}% confidence`} size="small"
|
||||||
color={m.meta.confidence >= 0.7 ? 'success' : m.meta.confidence >= 0.4 ? 'warning' : 'error'} variant="outlined" />
|
color={m.meta.confidence >= 0.7 ? 'success' : m.meta.confidence >= 0.4 ? 'warning' : 'error'} variant="outlined" />
|
||||||
<Chip label={m.meta.model_used} size="small" variant="outlined" />
|
<Chip label={m.meta.model_used} size="small" variant="outlined" />
|
||||||
<Chip label={m.meta.node_used} size="small" variant="outlined" />
|
<Chip label={m.meta.node_used} size="small" variant="outlined" />
|
||||||
@@ -190,7 +324,7 @@ export default function AgentPanel() {
|
|||||||
</AccordionSummary>
|
</AccordionSummary>
|
||||||
<AccordionDetails>
|
<AccordionDetails>
|
||||||
{m.meta.sans_references.map((r, j) => (
|
{m.meta.sans_references.map((r, j) => (
|
||||||
<Typography key={j} variant="body2" sx={{ mb: 0.5 }}>• {r}</Typography>
|
<Typography key={j} variant="body2" sx={{ mb: 0.5 }}>{r}</Typography>
|
||||||
))}
|
))}
|
||||||
</AccordionDetails>
|
</AccordionDetails>
|
||||||
</Accordion>
|
</Accordion>
|
||||||
@@ -214,6 +348,32 @@ export default function AgentPanel() {
|
|||||||
</Accordion>
|
</Accordion>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{/* Execution summary */}
|
||||||
|
{m.meta.execution && (
|
||||||
|
<Accordion disableGutters sx={{ mt: 0.5 }}>
|
||||||
|
<AccordionSummary expandIcon={<ExpandMoreIcon />}>
|
||||||
|
<Typography variant="caption">
|
||||||
|
Execution Results ({m.meta.execution.policy_hits} hits in {m.meta.execution.elapsed_ms}ms)
|
||||||
|
</Typography>
|
||||||
|
</AccordionSummary>
|
||||||
|
<AccordionDetails>
|
||||||
|
<Typography variant="body2" sx={{ mb: 0.5 }}>
|
||||||
|
Scope: {m.meta.execution.scope}
|
||||||
|
</Typography>
|
||||||
|
<Typography variant="body2" sx={{ mb: 0.5 }}>
|
||||||
|
Datasets: {m.meta.execution.datasets_scanned.join(', ') || 'None'}
|
||||||
|
</Typography>
|
||||||
|
{m.meta.execution.top_domains.length > 0 && (
|
||||||
|
<Stack direction="row" spacing={0.5} flexWrap="wrap" sx={{ mt: 0.5 }}>
|
||||||
|
{m.meta.execution.top_domains.map((d, j) => (
|
||||||
|
<Chip key={j} label={d} size="small" color="success" variant="outlined" />
|
||||||
|
))}
|
||||||
|
</Stack>
|
||||||
|
)}
|
||||||
|
</AccordionDetails>
|
||||||
|
</Accordion>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* Caveats */}
|
{/* Caveats */}
|
||||||
{m.meta.caveats && (
|
{m.meta.caveats && (
|
||||||
<Alert severity="warning" sx={{ mt: 0.5, py: 0 }}>
|
<Alert severity="warning" sx={{ mt: 0.5, py: 0 }}>
|
||||||
@@ -224,7 +384,7 @@ export default function AgentPanel() {
|
|||||||
)}
|
)}
|
||||||
</Box>
|
</Box>
|
||||||
))}
|
))}
|
||||||
{loading && <LinearProgress sx={{ mb: 1 }} />}
|
{loading && !streaming && <LinearProgress sx={{ mb: 1 }} />}
|
||||||
<div ref={bottomRef} />
|
<div ref={bottomRef} />
|
||||||
</Paper>
|
</Paper>
|
||||||
|
|
||||||
@@ -237,10 +397,17 @@ export default function AgentPanel() {
|
|||||||
onKeyDown={handleKeyDown}
|
onKeyDown={handleKeyDown}
|
||||||
disabled={loading}
|
disabled={loading}
|
||||||
/>
|
/>
|
||||||
|
{streaming ? (
|
||||||
|
<Button variant="outlined" color="error" onClick={stopStreaming}>
|
||||||
|
<StopIcon />
|
||||||
|
</Button>
|
||||||
|
) : (
|
||||||
<Button variant="contained" onClick={send} disabled={loading || !query.trim()}>
|
<Button variant="contained" onClick={send} disabled={loading || !query.trim()}>
|
||||||
{loading ? <CircularProgress size={20} /> : <SendIcon />}
|
{loading ? <CircularProgress size={20} /> : <SendIcon />}
|
||||||
</Button>
|
</Button>
|
||||||
|
)}
|
||||||
</Stack>
|
</Stack>
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user