mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 05:50:21 -05:00
Compare commits
4 Commits
7c454036c7
...
483176c06b
| Author | SHA1 | Date | |
|---|---|---|---|
| 483176c06b | |||
| 13bd9ec9e0 | |||
| 5a2ad8ec1c | |||
| 37a9584d0c |
26
.gitignore
vendored
26
.gitignore
vendored
@@ -1,4 +1,4 @@
|
||||
# ── Python ────────────────────────────────────
|
||||
# ── Python ────────────────────────────────────
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
@@ -8,34 +8,34 @@ build/
|
||||
*.egg
|
||||
.eggs/
|
||||
|
||||
# ── Virtual environments ─────────────────────
|
||||
# ── Virtual environments ─────────────────────
|
||||
venv/
|
||||
.venv/
|
||||
env/
|
||||
|
||||
# ── IDE / Editor ─────────────────────────────
|
||||
# ── IDE / Editor ─────────────────────────────
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# ── OS ────────────────────────────────────────
|
||||
# ── OS ────────────────────────────────────────
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# ── Environment / Secrets ────────────────────
|
||||
# ── Environment / Secrets ────────────────────
|
||||
.env
|
||||
*.env.local
|
||||
|
||||
# ── Database ─────────────────────────────────
|
||||
# ── Database ─────────────────────────────────
|
||||
*.db
|
||||
*.sqlite3
|
||||
|
||||
# ── Uploads ──────────────────────────────────
|
||||
# ── Uploads ──────────────────────────────────
|
||||
uploads/
|
||||
|
||||
# ── Node / Frontend ──────────────────────────
|
||||
# ── Node / Frontend ──────────────────────────
|
||||
node_modules/
|
||||
frontend/build/
|
||||
frontend/.env.local
|
||||
@@ -43,14 +43,18 @@ npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
|
||||
# ── Docker ───────────────────────────────────
|
||||
# ── Docker ───────────────────────────────────
|
||||
docker-compose.override.yml
|
||||
|
||||
# ── Test / Coverage ──────────────────────────
|
||||
# ── Test / Coverage ──────────────────────────
|
||||
.coverage
|
||||
htmlcov/
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
|
||||
# ── Alembic ──────────────────────────────────
|
||||
# ── Alembic ──────────────────────────────────
|
||||
alembic/versions/*.pyc
|
||||
|
||||
*.db-wal
|
||||
*.db-shm
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ COPY frontend/tsconfig.json ./
|
||||
# Build application
|
||||
RUN npm run build
|
||||
|
||||
# Production stage — nginx reverse-proxy + static files
|
||||
# Production stage — nginx reverse-proxy + static files
|
||||
FROM nginx:alpine
|
||||
|
||||
# Copy built React app
|
||||
|
||||
148
_add_label_filter_networkmap.py
Normal file
148
_add_label_filter_networkmap.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
|
||||
# 1) Add label mode type near graph types
|
||||
marker="interface GEdge { source: string; target: string; weight: number }\ninterface Graph { nodes: GNode[]; edges: GEdge[] }\n"
|
||||
if marker in t and "type LabelMode" not in t:
|
||||
t=t.replace(marker, marker+"\ntype LabelMode = 'all' | 'highlight' | 'none';\n")
|
||||
|
||||
# 2) extend drawLabels signature
|
||||
old_sig="""function drawLabels(
|
||||
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||
hovered: string | null, selected: string | null,
|
||||
search: string, matchSet: Set<string>, vp: Viewport,
|
||||
simplify: boolean,
|
||||
) {
|
||||
"""
|
||||
new_sig="""function drawLabels(
|
||||
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||
hovered: string | null, selected: string | null,
|
||||
search: string, matchSet: Set<string>, vp: Viewport,
|
||||
simplify: boolean, labelMode: LabelMode,
|
||||
) {
|
||||
"""
|
||||
if old_sig in t:
|
||||
t=t.replace(old_sig,new_sig)
|
||||
|
||||
# 3) label mode guards inside drawLabels
|
||||
old_guard=""" const dimmed = search.length > 0;
|
||||
if (simplify && !search && !hovered && !selected) {
|
||||
return;
|
||||
}
|
||||
"""
|
||||
new_guard=""" if (labelMode === 'none') return;
|
||||
const dimmed = search.length > 0;
|
||||
if (labelMode === 'highlight' && !search && !hovered && !selected) return;
|
||||
if (simplify && labelMode !== 'all' && !search && !hovered && !selected) {
|
||||
return;
|
||||
}
|
||||
"""
|
||||
if old_guard in t:
|
||||
t=t.replace(old_guard,new_guard)
|
||||
|
||||
old_show=""" const isHighlight = hovered === n.id || selected === n.id || matchSet.has(n.id);
|
||||
const show = isHighlight || n.meta.type === 'host' || n.count >= 2;
|
||||
if (!show) continue;
|
||||
"""
|
||||
new_show=""" const isHighlight = hovered === n.id || selected === n.id || matchSet.has(n.id);
|
||||
const show = labelMode === 'all'
|
||||
? (isHighlight || n.meta.type === 'host' || n.count >= 2)
|
||||
: isHighlight;
|
||||
if (!show) continue;
|
||||
"""
|
||||
if old_show in t:
|
||||
t=t.replace(old_show,new_show)
|
||||
|
||||
# 4) drawGraph signature and call site
|
||||
old_graph_sig="""function drawGraph(
|
||||
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||
hovered: string | null, selected: string | null, search: string,
|
||||
vp: Viewport, animTime: number, dpr: number,
|
||||
) {
|
||||
"""
|
||||
new_graph_sig="""function drawGraph(
|
||||
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||
hovered: string | null, selected: string | null, search: string,
|
||||
vp: Viewport, animTime: number, dpr: number, labelMode: LabelMode,
|
||||
) {
|
||||
"""
|
||||
if old_graph_sig in t:
|
||||
t=t.replace(old_graph_sig,new_graph_sig)
|
||||
|
||||
old_drawlabels_call="drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify);"
|
||||
new_drawlabels_call="drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify, labelMode);"
|
||||
if old_drawlabels_call in t:
|
||||
t=t.replace(old_drawlabels_call,new_drawlabels_call)
|
||||
|
||||
# 5) state for label mode
|
||||
state_anchor=" const [selectedNode, setSelectedNode] = useState<GNode | null>(null);\n const [search, setSearch] = useState('');\n"
|
||||
state_new=" const [selectedNode, setSelectedNode] = useState<GNode | null>(null);\n const [search, setSearch] = useState('');\n const [labelMode, setLabelMode] = useState<LabelMode>('highlight');\n"
|
||||
if state_anchor in t:
|
||||
t=t.replace(state_anchor,state_new)
|
||||
|
||||
# 6) pass labelMode in draw calls
|
||||
old_tick_draw="drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? null, searchRef.current, vpRef.current, ts, dpr);"
|
||||
new_tick_draw="drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? null, searchRef.current, vpRef.current, ts, dpr, labelMode);"
|
||||
if old_tick_draw in t:
|
||||
t=t.replace(old_tick_draw,new_tick_draw)
|
||||
|
||||
old_redraw_draw="if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? null, search, vpRef.current, animTimeRef.current, dpr);"
|
||||
new_redraw_draw="if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? null, search, vpRef.current, animTimeRef.current, dpr, labelMode);"
|
||||
if old_redraw_draw in t:
|
||||
t=t.replace(old_redraw_draw,new_redraw_draw)
|
||||
|
||||
# 7) include labelMode in redraw deps
|
||||
old_redraw_dep="] , [graph, hovered, selectedNode, search]);"
|
||||
if old_redraw_dep in t:
|
||||
t=t.replace(old_redraw_dep, "] , [graph, hovered, selectedNode, search, labelMode]);")
|
||||
else:
|
||||
t=t.replace(" }, [graph, hovered, selectedNode, search]);"," }, [graph, hovered, selectedNode, search, labelMode]);")
|
||||
|
||||
# 8) Add toolbar selector after search field
|
||||
search_block=""" <TextField
|
||||
size="small"
|
||||
placeholder="Search hosts, IPs, users\u2026"
|
||||
value={search}
|
||||
onChange={e => setSearch(e.target.value)}
|
||||
sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }}
|
||||
slotProps={{
|
||||
input: {
|
||||
startAdornment: <SearchIcon sx={{ mr: 0.5, fontSize: 18, color: 'text.secondary' }} />,
|
||||
},
|
||||
}}
|
||||
/>
|
||||
"""
|
||||
label_block=""" <TextField
|
||||
size="small"
|
||||
placeholder="Search hosts, IPs, users\u2026"
|
||||
value={search}
|
||||
onChange={e => setSearch(e.target.value)}
|
||||
sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }}
|
||||
slotProps={{
|
||||
input: {
|
||||
startAdornment: <SearchIcon sx={{ mr: 0.5, fontSize: 18, color: 'text.secondary' }} />,
|
||||
},
|
||||
}}
|
||||
/>
|
||||
|
||||
<FormControl size="small" sx={{ minWidth: 140 }}>
|
||||
<InputLabel id="label-mode-selector">Labels</InputLabel>
|
||||
<Select
|
||||
labelId="label-mode-selector"
|
||||
value={labelMode}
|
||||
label="Labels"
|
||||
onChange={e => setLabelMode(e.target.value as LabelMode)}
|
||||
sx={{ '& .MuiSelect-select': { py: 0.8 } }}
|
||||
>
|
||||
<MenuItem value="none">None</MenuItem>
|
||||
<MenuItem value="highlight">Selected/Search</MenuItem>
|
||||
<MenuItem value="all">All</MenuItem>
|
||||
</Select>
|
||||
</FormControl>
|
||||
"""
|
||||
if search_block in t:
|
||||
t=t.replace(search_block,label_block)
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('added network map label filter control and renderer modes')
|
||||
18
_add_scanner_budget_config.py
Normal file
18
_add_scanner_budget_config.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
old=''' # -- Scanner settings -----------------------------------------------
|
||||
SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch")
|
||||
'''
|
||||
new=''' # -- Scanner settings -----------------------------------------------
|
||||
SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch")
|
||||
SCANNER_MAX_ROWS_PER_SCAN: int = Field(
|
||||
default=300000,
|
||||
description="Global row budget for a single AUP scan request (0 = unlimited)",
|
||||
)
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('scanner settings block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('added SCANNER_MAX_ROWS_PER_SCAN config')
|
||||
46
_apply_frontend_scale_patch.py
Normal file
46
_apply_frontend_scale_patch.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from pathlib import Path
|
||||
|
||||
root = Path(r"d:\Projects\Dev\ThreatHunt")
|
||||
|
||||
# -------- client.ts --------
|
||||
client = root / "frontend/src/api/client.ts"
|
||||
text = client.read_text(encoding="utf-8")
|
||||
|
||||
if "export interface NetworkSummary" not in text:
|
||||
insert_after = "export interface InventoryStatus {\n hunt_id: string;\n status: 'ready' | 'building' | 'none';\n}\n"
|
||||
addition = insert_after + "\nexport interface NetworkSummaryHost {\n id: string;\n hostname: string;\n row_count: number;\n ip_count: number;\n user_count: number;\n}\n\nexport interface NetworkSummary {\n stats: InventoryStats;\n top_hosts: NetworkSummaryHost[];\n top_edges: InventoryConnection[];\n status?: 'building' | 'deferred';\n message?: string;\n}\n"
|
||||
text = text.replace(insert_after, addition)
|
||||
|
||||
net_old = """export const network = {\n hostInventory: (huntId: string, force = false) =>\n api<HostInventory>(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`),\n inventoryStatus: (huntId: string) =>\n api<InventoryStatus>(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`),\n rebuildInventory: (huntId: string) =>\n api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }),\n};"""
|
||||
net_new = """export const network = {\n hostInventory: (huntId: string, force = false) =>\n api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`),\n summary: (huntId: string, topN = 20) =>\n api<NetworkSummary | { status: 'building' | 'deferred'; message?: string }>(`/api/network/summary?hunt_id=${encodeURIComponent(huntId)}&top_n=${topN}`),\n subgraph: (huntId: string, maxHosts = 250, maxEdges = 1500, nodeId?: string) => {\n let qs = `/api/network/subgraph?hunt_id=${encodeURIComponent(huntId)}&max_hosts=${maxHosts}&max_edges=${maxEdges}`;\n if (nodeId) qs += `&node_id=${encodeURIComponent(nodeId)}`;\n return api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(qs);\n },\n inventoryStatus: (huntId: string) =>\n api<InventoryStatus>(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`),\n rebuildInventory: (huntId: string) =>\n api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }),\n};"""
|
||||
if net_old in text:
|
||||
text = text.replace(net_old, net_new)
|
||||
|
||||
client.write_text(text, encoding="utf-8")
|
||||
|
||||
# -------- NetworkMap.tsx --------
|
||||
nm = root / "frontend/src/components/NetworkMap.tsx"
|
||||
text = nm.read_text(encoding="utf-8")
|
||||
|
||||
# add constants
|
||||
if "LARGE_HUNT_HOST_THRESHOLD" not in text:
|
||||
text = text.replace("let lastSelectedHuntId = '';\n", "let lastSelectedHuntId = '';\nconst LARGE_HUNT_HOST_THRESHOLD = 400;\nconst LARGE_HUNT_SUBGRAPH_HOSTS = 350;\nconst LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n")
|
||||
|
||||
# inject helper in component after sleep
|
||||
marker = " const sleep = (ms: number) => new Promise<void>(resolve => setTimeout(resolve, ms));\n"
|
||||
if "loadScaleAwareGraph" not in text:
|
||||
helper = marker + "\n const loadScaleAwareGraph = useCallback(async (huntId: string, forceRefresh = false) => {\n setLoading(true); setError(''); setGraph(null); setStats(null);\n setSelectedNode(null); setPopoverAnchor(null);\n\n const waitReadyThen = async <T,>(fn: () => Promise<T>): Promise<T> => {\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const out: any = await fn();\n if (out && !out.status) return out as T;\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') {\n const out2: any = await fn();\n if (out2 && !out2.status) return out2 as T;\n }\n if (Date.now() - startedAt > 5 * 60 * 1000) throw new Error('Network data build timed out after 5 minutes');\n const jitter = Math.floor(Math.random() * 250);\n await sleep(delayMs + jitter);\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n };\n\n try {\n setProgress('Loading network summary');\n const summary: any = await waitReadyThen(() => network.summary(huntId, 20));\n const totalHosts = summary?.stats?.total_hosts || 0;\n\n if (totalHosts > LARGE_HUNT_HOST_THRESHOLD) {\n setProgress(`Large hunt detected (${totalHosts} hosts). Loading focused subgraph`);\n const sub: any = await waitReadyThen(() => network.subgraph(huntId, LARGE_HUNT_SUBGRAPH_HOSTS, LARGE_HUNT_SUBGRAPH_EDGES));\n if (!sub?.hosts || sub.hosts.length === 0) {\n setError('No hosts found for subgraph.');\n return;\n }\n const { w, h } = canvasSizeRef.current;\n const g = buildGraphFromInventory(sub.hosts, sub.connections || [], w, h);\n simulate(g, w / 2, h / 2, 60);\n simAlphaRef.current = 0.3;\n setStats(summary.stats);\n graphCache.set(huntId, { graph: g, stats: summary.stats, ts: Date.now() });\n setGraph(g);\n return;\n }\n\n // Small/medium hunts: load full inventory\n setProgress('Loading host inventory');\n const inv: any = await waitReadyThen(() => network.hostInventory(huntId, forceRefresh));\n if (!inv?.hosts || inv.hosts.length === 0) {\n setError('No hosts found. Upload CSV files with host-identifying columns (ClientId, Fqdn, Hostname) to this hunt.');\n return;\n }\n const { w, h } = canvasSizeRef.current;\n const g = buildGraphFromInventory(inv.hosts, inv.connections || [], w, h);\n simulate(g, w / 2, h / 2, 60);\n simAlphaRef.current = 0.3;\n setStats(summary.stats || inv.stats);\n graphCache.set(huntId, { graph: g, stats: summary.stats || inv.stats, ts: Date.now() });\n setGraph(g);\n } catch (e: any) {\n console.error('[NetworkMap] scale-aware load error:', e);\n setError(e.message || 'Failed to load network data');\n } finally {\n setLoading(false);\n setProgress('');\n }\n }, []);\n"
|
||||
text = text.replace(marker, helper)
|
||||
|
||||
# simplify existing loadGraph function body to delegate
|
||||
pattern_start = text.find(" // Load host inventory for selected hunt (with cache).")
|
||||
if pattern_start != -1:
|
||||
# replace the whole loadGraph useCallback block by simple delegator
|
||||
import re
|
||||
block_re = re.compile(r" // Load host inventory for selected hunt \(with cache\)\.[\s\S]*?\n \}, \[\]\); // Stable - reads canvasSizeRef, no state deps\n", re.M)
|
||||
repl = " // Load graph data for selected hunt (delegates to scale-aware loader).\n const loadGraph = useCallback(async (huntId: string, forceRefresh = false) => {\n if (!huntId) return;\n\n // Check module-level cache first (5 min TTL)\n if (!forceRefresh) {\n const cached = graphCache.get(huntId);\n if (cached && Date.now() - cached.ts < 5 * 60 * 1000) {\n setGraph(cached.graph);\n setStats(cached.stats);\n setError('');\n simAlphaRef.current = 0;\n return;\n }\n }\n\n await loadScaleAwareGraph(huntId, forceRefresh);\n // eslint-disable-next-line react-hooks/exhaustive-deps\n }, []); // Stable - reads canvasSizeRef, no state deps\n"
|
||||
text = block_re.sub(repl, text, count=1)
|
||||
|
||||
nm.write_text(text, encoding="utf-8")
|
||||
|
||||
print("Patched frontend client + NetworkMap for scale-aware loading")
|
||||
206
_apply_phase1_patch.py
Normal file
206
_apply_phase1_patch.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from pathlib import Path
|
||||
|
||||
root = Path(r"d:\Projects\Dev\ThreatHunt")
|
||||
|
||||
# 1) config.py additions
|
||||
cfg = root / "backend/app/config.py"
|
||||
text = cfg.read_text(encoding="utf-8")
|
||||
needle = " # -- Scanner settings -----------------------------------------------\n SCANNER_BATCH_SIZE: int = Field(default=500, description=\"Rows per scanner batch\")\n"
|
||||
insert = " # -- Scanner settings -----------------------------------------------\n SCANNER_BATCH_SIZE: int = Field(default=500, description=\"Rows per scanner batch\")\n\n # -- Job queue settings ----------------------------------------------\n JOB_QUEUE_MAX_BACKLOG: int = Field(\n default=2000, description=\"Soft cap for queued background jobs\"\n )\n JOB_QUEUE_RETAIN_COMPLETED: int = Field(\n default=3000, description=\"Maximum completed/failed jobs to retain in memory\"\n )\n JOB_QUEUE_CLEANUP_INTERVAL_SECONDS: int = Field(\n default=60, description=\"How often to run in-memory job cleanup\"\n )\n JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(\n default=3600, description=\"Age threshold for in-memory completed job cleanup\"\n )\n"
|
||||
if needle in text:
|
||||
text = text.replace(needle, insert)
|
||||
cfg.write_text(text, encoding="utf-8")
|
||||
|
||||
# 2) scanner.py default scope = dataset-only
|
||||
scanner = root / "backend/app/services/scanner.py"
|
||||
text = scanner.read_text(encoding="utf-8")
|
||||
text = text.replace(" scan_hunts: bool = True,", " scan_hunts: bool = False,")
|
||||
text = text.replace(" scan_annotations: bool = True,", " scan_annotations: bool = False,")
|
||||
text = text.replace(" scan_messages: bool = True,", " scan_messages: bool = False,")
|
||||
scanner.write_text(text, encoding="utf-8")
|
||||
|
||||
# 3) keywords.py defaults = dataset-only
|
||||
kw = root / "backend/app/api/routes/keywords.py"
|
||||
text = kw.read_text(encoding="utf-8")
|
||||
text = text.replace(" scan_hunts: bool = True", " scan_hunts: bool = False")
|
||||
text = text.replace(" scan_annotations: bool = True", " scan_annotations: bool = False")
|
||||
text = text.replace(" scan_messages: bool = True", " scan_messages: bool = False")
|
||||
kw.write_text(text, encoding="utf-8")
|
||||
|
||||
# 4) job_queue.py dedupe + periodic cleanup
|
||||
jq = root / "backend/app/services/job_queue.py"
|
||||
text = jq.read_text(encoding="utf-8")
|
||||
|
||||
text = text.replace(
|
||||
"from typing import Any, Callable, Coroutine, Optional\n",
|
||||
"from typing import Any, Callable, Coroutine, Optional\n\nfrom app.config import settings\n"
|
||||
)
|
||||
|
||||
text = text.replace(
|
||||
" self._completion_callbacks: list[Callable[[Job], Coroutine]] = []\n",
|
||||
" self._completion_callbacks: list[Callable[[Job], Coroutine]] = []\n self._cleanup_task: asyncio.Task | None = None\n"
|
||||
)
|
||||
|
||||
start_old = ''' async def start(self):
|
||||
if self._started:
|
||||
return
|
||||
self._started = True
|
||||
for i in range(self._max_workers):
|
||||
task = asyncio.create_task(self._worker(i))
|
||||
self._workers.append(task)
|
||||
logger.info(f"Job queue started with {self._max_workers} workers")
|
||||
'''
|
||||
start_new = ''' async def start(self):
|
||||
if self._started:
|
||||
return
|
||||
self._started = True
|
||||
for i in range(self._max_workers):
|
||||
task = asyncio.create_task(self._worker(i))
|
||||
self._workers.append(task)
|
||||
if not self._cleanup_task or self._cleanup_task.done():
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info(f"Job queue started with {self._max_workers} workers")
|
||||
'''
|
||||
text = text.replace(start_old, start_new)
|
||||
|
||||
stop_old = ''' async def stop(self):
|
||||
self._started = False
|
||||
for w in self._workers:
|
||||
w.cancel()
|
||||
await asyncio.gather(*self._workers, return_exceptions=True)
|
||||
self._workers.clear()
|
||||
logger.info("Job queue stopped")
|
||||
'''
|
||||
stop_new = ''' async def stop(self):
|
||||
self._started = False
|
||||
for w in self._workers:
|
||||
w.cancel()
|
||||
await asyncio.gather(*self._workers, return_exceptions=True)
|
||||
self._workers.clear()
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
await asyncio.gather(self._cleanup_task, return_exceptions=True)
|
||||
self._cleanup_task = None
|
||||
logger.info("Job queue stopped")
|
||||
'''
|
||||
text = text.replace(stop_old, stop_new)
|
||||
|
||||
submit_old = ''' def submit(self, job_type: JobType, **params) -> Job:
|
||||
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put_nowait(job.id)
|
||||
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
|
||||
return job
|
||||
'''
|
||||
submit_new = ''' def submit(self, job_type: JobType, **params) -> Job:
|
||||
# Soft backpressure: prefer dedupe over queue amplification
|
||||
dedupe_job = self._find_active_duplicate(job_type, params)
|
||||
if dedupe_job is not None:
|
||||
logger.info(
|
||||
f"Job deduped: reusing {dedupe_job.id} ({job_type.value}) params={params}"
|
||||
)
|
||||
return dedupe_job
|
||||
|
||||
if self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG:
|
||||
logger.warning(
|
||||
"Job queue backlog high (%d >= %d). Accepting job but system may be degraded.",
|
||||
self._queue.qsize(), settings.JOB_QUEUE_MAX_BACKLOG,
|
||||
)
|
||||
|
||||
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put_nowait(job.id)
|
||||
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
|
||||
return job
|
||||
'''
|
||||
text = text.replace(submit_old, submit_new)
|
||||
|
||||
insert_methods_after = " def get_job(self, job_id: str) -> Job | None:\n return self._jobs.get(job_id)\n"
|
||||
new_methods = ''' def get_job(self, job_id: str) -> Job | None:
|
||||
return self._jobs.get(job_id)
|
||||
|
||||
def _find_active_duplicate(self, job_type: JobType, params: dict) -> Job | None:
|
||||
"""Return queued/running job with same key workload to prevent duplicate storms."""
|
||||
key_fields = ["dataset_id", "hunt_id", "hostname", "question", "mode"]
|
||||
sig = tuple((k, params.get(k)) for k in key_fields if params.get(k) is not None)
|
||||
if not sig:
|
||||
return None
|
||||
for j in self._jobs.values():
|
||||
if j.job_type != job_type:
|
||||
continue
|
||||
if j.status not in (JobStatus.QUEUED, JobStatus.RUNNING):
|
||||
continue
|
||||
other_sig = tuple((k, j.params.get(k)) for k in key_fields if j.params.get(k) is not None)
|
||||
if sig == other_sig:
|
||||
return j
|
||||
return None
|
||||
'''
|
||||
text = text.replace(insert_methods_after, new_methods)
|
||||
|
||||
cleanup_old = ''' def cleanup(self, max_age_seconds: float = 3600):
|
||||
now = time.time()
|
||||
to_remove = [
|
||||
jid for jid, j in self._jobs.items()
|
||||
if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||
and (now - j.created_at) > max_age_seconds
|
||||
]
|
||||
for jid in to_remove:
|
||||
del self._jobs[jid]
|
||||
if to_remove:
|
||||
logger.info(f"Cleaned up {len(to_remove)} old jobs")
|
||||
'''
|
||||
cleanup_new = ''' def cleanup(self, max_age_seconds: float = 3600):
|
||||
now = time.time()
|
||||
terminal_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||
to_remove = [
|
||||
jid for jid, j in self._jobs.items()
|
||||
if j.status in terminal_states and (now - j.created_at) > max_age_seconds
|
||||
]
|
||||
|
||||
# Also cap retained terminal jobs to avoid unbounded memory growth
|
||||
terminal_jobs = sorted(
|
||||
[j for j in self._jobs.values() if j.status in terminal_states],
|
||||
key=lambda j: j.created_at,
|
||||
reverse=True,
|
||||
)
|
||||
overflow = terminal_jobs[settings.JOB_QUEUE_RETAIN_COMPLETED :]
|
||||
to_remove.extend([j.id for j in overflow])
|
||||
|
||||
removed = 0
|
||||
for jid in set(to_remove):
|
||||
if jid in self._jobs:
|
||||
del self._jobs[jid]
|
||||
removed += 1
|
||||
if removed:
|
||||
logger.info(f"Cleaned up {removed} old jobs")
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
interval = max(10, settings.JOB_QUEUE_CLEANUP_INTERVAL_SECONDS)
|
||||
while self._started:
|
||||
try:
|
||||
self.cleanup(max_age_seconds=settings.JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS)
|
||||
except Exception as e:
|
||||
logger.warning(f"Job queue cleanup loop error: {e}")
|
||||
await asyncio.sleep(interval)
|
||||
'''
|
||||
text = text.replace(cleanup_old, cleanup_new)
|
||||
|
||||
jq.write_text(text, encoding="utf-8")
|
||||
|
||||
# 5) NetworkMap polling backoff/jitter max wait
|
||||
nm = root / "frontend/src/components/NetworkMap.tsx"
|
||||
text = nm.read_text(encoding="utf-8")
|
||||
|
||||
text = text.replace(
|
||||
" // Poll until ready, then re-fetch\n for (;;) {\n await new Promise(r => setTimeout(r, 2000));\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') break;\n }\n",
|
||||
" // Poll until ready (exponential backoff), then re-fetch\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const jitter = Math.floor(Math.random() * 250);\n await new Promise(r => setTimeout(r, delayMs + jitter));\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') break;\n if (Date.now() - startedAt > 5 * 60 * 1000) {\n throw new Error('Host inventory build timed out after 5 minutes');\n }\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n"
|
||||
)
|
||||
|
||||
text = text.replace(
|
||||
" const waitUntilReady = async (): Promise<boolean> => {\n // Poll inventory-status every 2s until 'ready' (or cancelled)\n setProgress('Host inventory is being prepared in the background');\n setLoading(true);\n for (;;) {\n await new Promise(r => setTimeout(r, 2000));\n if (cancelled) return false;\n try {\n const st = await network.inventoryStatus(selectedHuntId);\n if (cancelled) return false;\n if (st.status === 'ready') return true;\n // still building or none (job may not have started yet) - keep polling\n } catch { if (cancelled) return false; }\n }\n };\n",
|
||||
" const waitUntilReady = async (): Promise<boolean> => {\n // Poll inventory-status with exponential backoff until 'ready' (or cancelled)\n setProgress('Host inventory is being prepared in the background');\n setLoading(true);\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const jitter = Math.floor(Math.random() * 250);\n await new Promise(r => setTimeout(r, delayMs + jitter));\n if (cancelled) return false;\n try {\n const st = await network.inventoryStatus(selectedHuntId);\n if (cancelled) return false;\n if (st.status === 'ready') return true;\n if (Date.now() - startedAt > 5 * 60 * 1000) {\n setError('Host inventory build timed out. Please retry.');\n return false;\n }\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n // still building or none (job may not have started yet) - keep polling\n } catch {\n if (cancelled) return false;\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n }\n };\n"
|
||||
)
|
||||
|
||||
nm.write_text(text, encoding="utf-8")
|
||||
|
||||
print("Patched: config.py, scanner.py, keywords.py, job_queue.py, NetworkMap.tsx")
|
||||
207
_apply_phase2_patch.py
Normal file
207
_apply_phase2_patch.py
Normal file
@@ -0,0 +1,207 @@
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
root = Path(r"d:\Projects\Dev\ThreatHunt")
|
||||
|
||||
# ---------- config.py ----------
|
||||
cfg = root / "backend/app/config.py"
|
||||
text = cfg.read_text(encoding="utf-8")
|
||||
marker = " JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(\n default=3600, description=\"Age threshold for in-memory completed job cleanup\"\n )\n"
|
||||
add = marker + "\n # -- Startup throttling ------------------------------------------------\n STARTUP_WARMUP_MAX_HUNTS: int = Field(\n default=5, description=\"Max hunts to warm inventory cache for at startup\"\n )\n STARTUP_REPROCESS_MAX_DATASETS: int = Field(\n default=25, description=\"Max unprocessed datasets to enqueue at startup\"\n )\n\n # -- Network API scale guards -----------------------------------------\n NETWORK_SUBGRAPH_MAX_HOSTS: int = Field(\n default=400, description=\"Hard cap for hosts returned by network subgraph endpoint\"\n )\n NETWORK_SUBGRAPH_MAX_EDGES: int = Field(\n default=3000, description=\"Hard cap for edges returned by network subgraph endpoint\"\n )\n"
|
||||
if marker in text and "STARTUP_WARMUP_MAX_HUNTS" not in text:
|
||||
text = text.replace(marker, add)
|
||||
cfg.write_text(text, encoding="utf-8")
|
||||
|
||||
# ---------- job_queue.py ----------
|
||||
jq = root / "backend/app/services/job_queue.py"
|
||||
text = jq.read_text(encoding="utf-8")
|
||||
|
||||
# add helper methods after get_stats
|
||||
anchor = " def get_stats(self) -> dict:\n by_status = {}\n for j in self._jobs.values():\n by_status[j.status.value] = by_status.get(j.status.value, 0) + 1\n return {\n \"total\": len(self._jobs),\n \"queued\": self._queue.qsize(),\n \"by_status\": by_status,\n \"workers\": self._max_workers,\n \"active_workers\": sum(1 for j in self._jobs.values() if j.status == JobStatus.RUNNING),\n }\n"
|
||||
if "def is_backlogged(" not in text:
|
||||
insert = anchor + "\n def is_backlogged(self) -> bool:\n return self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG\n\n def can_accept(self, reserve: int = 0) -> bool:\n return (self._queue.qsize() + max(0, reserve)) < settings.JOB_QUEUE_MAX_BACKLOG\n"
|
||||
text = text.replace(anchor, insert)
|
||||
|
||||
jq.write_text(text, encoding="utf-8")
|
||||
|
||||
# ---------- host_inventory.py keyset pagination ----------
|
||||
hi = root / "backend/app/services/host_inventory.py"
|
||||
text = hi.read_text(encoding="utf-8")
|
||||
old = ''' batch_size = 5000
|
||||
offset = 0
|
||||
while True:
|
||||
rr = await db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == ds.id)
|
||||
.order_by(DatasetRow.row_index)
|
||||
.offset(offset).limit(batch_size)
|
||||
)
|
||||
rows = rr.scalars().all()
|
||||
if not rows:
|
||||
break
|
||||
'''
|
||||
new = ''' batch_size = 5000
|
||||
last_row_index = -1
|
||||
while True:
|
||||
rr = await db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == ds.id)
|
||||
.where(DatasetRow.row_index > last_row_index)
|
||||
.order_by(DatasetRow.row_index)
|
||||
.limit(batch_size)
|
||||
)
|
||||
rows = rr.scalars().all()
|
||||
if not rows:
|
||||
break
|
||||
'''
|
||||
if old in text:
|
||||
text = text.replace(old, new)
|
||||
text = text.replace(" offset += batch_size\n if len(rows) < batch_size:\n break\n", " last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n")
|
||||
hi.write_text(text, encoding="utf-8")
|
||||
|
||||
# ---------- network.py add summary/subgraph + backpressure ----------
|
||||
net = root / "backend/app/api/routes/network.py"
|
||||
text = net.read_text(encoding="utf-8")
|
||||
text = text.replace("from fastapi import APIRouter, Depends, HTTPException, Query", "from fastapi import APIRouter, Depends, HTTPException, Query")
|
||||
if "from app.config import settings" not in text:
|
||||
text = text.replace("from app.db import get_db\n", "from app.config import settings\nfrom app.db import get_db\n")
|
||||
|
||||
# add helpers and endpoints before inventory-status endpoint
|
||||
if "def _build_summary" not in text:
|
||||
helper_block = '''
|
||||
|
||||
def _build_summary(inv: dict, top_n: int = 20) -> dict:
|
||||
hosts = inv.get("hosts", [])
|
||||
conns = inv.get("connections", [])
|
||||
top_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:top_n]
|
||||
top_edges = sorted(conns, key=lambda c: c.get("count", 0), reverse=True)[:top_n]
|
||||
return {
|
||||
"stats": inv.get("stats", {}),
|
||||
"top_hosts": [
|
||||
{
|
||||
"id": h.get("id"),
|
||||
"hostname": h.get("hostname"),
|
||||
"row_count": h.get("row_count", 0),
|
||||
"ip_count": len(h.get("ips", [])),
|
||||
"user_count": len(h.get("users", [])),
|
||||
}
|
||||
for h in top_hosts
|
||||
],
|
||||
"top_edges": top_edges,
|
||||
}
|
||||
|
||||
|
||||
def _build_subgraph(inv: dict, node_id: str | None, max_hosts: int, max_edges: int) -> dict:
|
||||
hosts = inv.get("hosts", [])
|
||||
conns = inv.get("connections", [])
|
||||
|
||||
max_hosts = max(1, min(max_hosts, settings.NETWORK_SUBGRAPH_MAX_HOSTS))
|
||||
max_edges = max(1, min(max_edges, settings.NETWORK_SUBGRAPH_MAX_EDGES))
|
||||
|
||||
if node_id:
|
||||
rel_edges = [c for c in conns if c.get("source") == node_id or c.get("target") == node_id]
|
||||
rel_edges = sorted(rel_edges, key=lambda c: c.get("count", 0), reverse=True)[:max_edges]
|
||||
ids = {node_id}
|
||||
for c in rel_edges:
|
||||
ids.add(c.get("source"))
|
||||
ids.add(c.get("target"))
|
||||
rel_hosts = [h for h in hosts if h.get("id") in ids][:max_hosts]
|
||||
else:
|
||||
rel_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:max_hosts]
|
||||
allowed = {h.get("id") for h in rel_hosts}
|
||||
rel_edges = [
|
||||
c for c in sorted(conns, key=lambda c: c.get("count", 0), reverse=True)
|
||||
if c.get("source") in allowed and c.get("target") in allowed
|
||||
][:max_edges]
|
||||
|
||||
return {
|
||||
"hosts": rel_hosts,
|
||||
"connections": rel_edges,
|
||||
"stats": {
|
||||
**inv.get("stats", {}),
|
||||
"subgraph_hosts": len(rel_hosts),
|
||||
"subgraph_connections": len(rel_edges),
|
||||
"truncated": len(rel_hosts) < len(hosts) or len(rel_edges) < len(conns),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_inventory_summary(
|
||||
hunt_id: str = Query(..., description="Hunt ID"),
|
||||
top_n: int = Query(20, ge=1, le=200),
|
||||
):
|
||||
"""Return a lightweight summary view for large hunts."""
|
||||
cached = inventory_cache.get(hunt_id)
|
||||
if cached is None:
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "deferred", "message": "Queue busy, retry shortly"},
|
||||
)
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return JSONResponse(status_code=202, content={"status": "building"})
|
||||
return _build_summary(cached, top_n=top_n)
|
||||
|
||||
|
||||
@router.get("/subgraph")
|
||||
async def get_inventory_subgraph(
|
||||
hunt_id: str = Query(..., description="Hunt ID"),
|
||||
node_id: str | None = Query(None, description="Optional focal node"),
|
||||
max_hosts: int = Query(200, ge=1, le=5000),
|
||||
max_edges: int = Query(1500, ge=1, le=20000),
|
||||
):
|
||||
"""Return a bounded subgraph for scale-safe rendering."""
|
||||
cached = inventory_cache.get(hunt_id)
|
||||
if cached is None:
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "deferred", "message": "Queue busy, retry shortly"},
|
||||
)
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return JSONResponse(status_code=202, content={"status": "building"})
|
||||
return _build_subgraph(cached, node_id=node_id, max_hosts=max_hosts, max_edges=max_edges)
|
||||
'''
|
||||
text = text.replace("\n\n@router.get(\"/inventory-status\")", helper_block + "\n\n@router.get(\"/inventory-status\")")
|
||||
|
||||
# add backpressure in host-inventory enqueue points
|
||||
text = text.replace(
|
||||
" if not inventory_cache.is_building(hunt_id):\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)",
|
||||
" if not inventory_cache.is_building(hunt_id):\n if job_queue.is_backlogged():\n return JSONResponse(status_code=202, content={\"status\": \"deferred\", \"message\": \"Queue busy, retry shortly\"})\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)"
|
||||
)
|
||||
text = text.replace(
|
||||
" if not inventory_cache.is_building(hunt_id):\n logger.info(f\"Cache miss for {hunt_id}, triggering background build\")\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)",
|
||||
" if not inventory_cache.is_building(hunt_id):\n logger.info(f\"Cache miss for {hunt_id}, triggering background build\")\n if job_queue.is_backlogged():\n return JSONResponse(status_code=202, content={\"status\": \"deferred\", \"message\": \"Queue busy, retry shortly\"})\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)"
|
||||
)
|
||||
|
||||
net.write_text(text, encoding="utf-8")
|
||||
|
||||
# ---------- analysis.py backpressure on manual submit ----------
|
||||
analysis = root / "backend/app/api/routes/analysis.py"
|
||||
text = analysis.read_text(encoding="utf-8")
|
||||
text = text.replace(
|
||||
" job = job_queue.submit(jt, **params)\n return {\"job_id\": job.id, \"status\": job.status.value, \"job_type\": job_type}",
|
||||
" if not job_queue.can_accept():\n raise HTTPException(status_code=429, detail=\"Job queue is busy. Retry shortly.\")\n job = job_queue.submit(jt, **params)\n return {\"job_id\": job.id, \"status\": job.status.value, \"job_type\": job_type}"
|
||||
)
|
||||
analysis.write_text(text, encoding="utf-8")
|
||||
|
||||
# ---------- main.py startup throttles ----------
|
||||
main = root / "backend/app/main.py"
|
||||
text = main.read_text(encoding="utf-8")
|
||||
|
||||
text = text.replace(
|
||||
" for hid in hunt_ids:\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)\n if hunt_ids:\n logger.info(f\"Queued host inventory warm-up for {len(hunt_ids)} hunts\")",
|
||||
" warm_hunts = hunt_ids[: settings.STARTUP_WARMUP_MAX_HUNTS]\n for hid in warm_hunts:\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)\n if warm_hunts:\n logger.info(f\"Queued host inventory warm-up for {len(warm_hunts)} hunts (total hunts with data: {len(hunt_ids)})\")"
|
||||
)
|
||||
|
||||
text = text.replace(
|
||||
" if unprocessed_ids:\n for ds_id in unprocessed_ids:\n job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)\n job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)\n job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)\n job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)\n logger.info(f\"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets\")\n async with async_session_factory() as update_db:\n from sqlalchemy import update\n from app.db.models import Dataset\n await update_db.execute(\n update(Dataset)\n .where(Dataset.id.in_(unprocessed_ids))\n .values(processing_status=\"processing\")\n )\n await update_db.commit()",
|
||||
" if unprocessed_ids:\n to_reprocess = unprocessed_ids[: settings.STARTUP_REPROCESS_MAX_DATASETS]\n for ds_id in to_reprocess:\n job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)\n job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)\n job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)\n job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)\n logger.info(f\"Queued processing pipeline for {len(to_reprocess)} datasets at startup (unprocessed total: {len(unprocessed_ids)})\")\n async with async_session_factory() as update_db:\n from sqlalchemy import update\n from app.db.models import Dataset\n await update_db.execute(\n update(Dataset)\n .where(Dataset.id.in_(to_reprocess))\n .values(processing_status=\"processing\")\n )\n await update_db.commit()"
|
||||
)
|
||||
|
||||
main.write_text(text, encoding="utf-8")
|
||||
|
||||
print("Patched Phase 2 files")
|
||||
75
_aup_add_dataset_scope_ui.py
Normal file
75
_aup_add_dataset_scope_ui.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
|
||||
# default selection when hunt changes: first 3 datasets instead of all
|
||||
old=''' datasets.list(0, 500, selectedHuntId).then(res => {
|
||||
if (cancelled) return;
|
||||
setDsList(res.datasets);
|
||||
setSelectedDs(new Set(res.datasets.map(d => d.id)));
|
||||
}).catch(() => {});
|
||||
'''
|
||||
new=''' datasets.list(0, 500, selectedHuntId).then(res => {
|
||||
if (cancelled) return;
|
||||
setDsList(res.datasets);
|
||||
setSelectedDs(new Set(res.datasets.slice(0, 3).map(d => d.id)));
|
||||
}).catch(() => {});
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('hunt-change dataset init block not found')
|
||||
t=t.replace(old,new)
|
||||
|
||||
# insert dataset scope multi-select under hunt info
|
||||
anchor=''' {!selectedHuntId && (
|
||||
<Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
|
||||
All datasets will be scanned if no hunt is selected
|
||||
</Typography>
|
||||
)}
|
||||
</Box>
|
||||
|
||||
{/* Theme selector */}
|
||||
'''
|
||||
insert=''' {!selectedHuntId && (
|
||||
<Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
|
||||
Select a hunt to enable scoped scanning
|
||||
</Typography>
|
||||
)}
|
||||
|
||||
<FormControl size="small" fullWidth sx={{ mt: 1.2 }} disabled={!selectedHuntId || dsList.length === 0}>
|
||||
<InputLabel id="aup-dataset-label">Datasets</InputLabel>
|
||||
<Select
|
||||
labelId="aup-dataset-label"
|
||||
multiple
|
||||
value={Array.from(selectedDs)}
|
||||
label="Datasets"
|
||||
renderValue={(selected) => `${(selected as string[]).length} selected`}
|
||||
onChange={(e) => setSelectedDs(new Set(e.target.value as string[]))}
|
||||
>
|
||||
{dsList.map(d => (
|
||||
<MenuItem key={d.id} value={d.id}>
|
||||
<Checkbox size="small" checked={selectedDs.has(d.id)} />
|
||||
<Typography variant="body2" sx={{ ml: 0.5 }}>
|
||||
{d.name} ({d.row_count.toLocaleString()} rows)
|
||||
</Typography>
|
||||
</MenuItem>
|
||||
))}
|
||||
</Select>
|
||||
</FormControl>
|
||||
|
||||
{selectedHuntId && dsList.length > 0 && (
|
||||
<Stack direction="row" spacing={1} sx={{ mt: 1 }}>
|
||||
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.slice(0, 3).map(d => d.id)))}>Top 3</Button>
|
||||
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.map(d => d.id)))}>All</Button>
|
||||
<Button size="small" onClick={() => setSelectedDs(new Set())}>Clear</Button>
|
||||
</Stack>
|
||||
)}
|
||||
</Box>
|
||||
|
||||
{/* Theme selector */}
|
||||
'''
|
||||
if anchor not in t:
|
||||
raise SystemExit('dataset scope anchor not found')
|
||||
t=t.replace(anchor,insert)
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('added AUP dataset multi-select scoping and safer defaults')
|
||||
182
_aup_add_host_user_to_hits.py
Normal file
182
_aup_add_host_user_to_hits.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
|
||||
# 1) Extend ScanHit dataclass
|
||||
old='''@dataclass
|
||||
class ScanHit:
|
||||
theme_name: str
|
||||
theme_color: str
|
||||
keyword: str
|
||||
source_type: str # dataset_row | hunt | annotation | message
|
||||
source_id: str | int
|
||||
field: str
|
||||
matched_value: str
|
||||
row_index: int | None = None
|
||||
dataset_name: str | None = None
|
||||
'''
|
||||
new='''@dataclass
|
||||
class ScanHit:
|
||||
theme_name: str
|
||||
theme_color: str
|
||||
keyword: str
|
||||
source_type: str # dataset_row | hunt | annotation | message
|
||||
source_id: str | int
|
||||
field: str
|
||||
matched_value: str
|
||||
row_index: int | None = None
|
||||
dataset_name: str | None = None
|
||||
hostname: str | None = None
|
||||
username: str | None = None
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('ScanHit dataclass block not found')
|
||||
t=t.replace(old,new)
|
||||
|
||||
# 2) Add helper to infer hostname/user from a row
|
||||
insert_after='''BATCH_SIZE = 200
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanHit:
|
||||
'''
|
||||
helper='''BATCH_SIZE = 200
|
||||
|
||||
|
||||
def _infer_hostname_and_user(data: dict) -> tuple[str | None, str | None]:
|
||||
"""Best-effort extraction of hostname and user from a dataset row."""
|
||||
if not data:
|
||||
return None, None
|
||||
|
||||
host_keys = (
|
||||
'hostname', 'host_name', 'host', 'computer_name', 'computer',
|
||||
'fqdn', 'client_id', 'agent_id', 'endpoint_id',
|
||||
)
|
||||
user_keys = (
|
||||
'username', 'user_name', 'user', 'account_name',
|
||||
'logged_in_user', 'samaccountname', 'sam_account_name',
|
||||
)
|
||||
|
||||
def pick(keys):
|
||||
for k in keys:
|
||||
for actual_key, v in data.items():
|
||||
if actual_key.lower() == k and v not in (None, ''):
|
||||
return str(v)
|
||||
return None
|
||||
|
||||
return pick(host_keys), pick(user_keys)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanHit:
|
||||
'''
|
||||
if insert_after in t and '_infer_hostname_and_user' not in t:
|
||||
t=t.replace(insert_after,helper)
|
||||
|
||||
# 3) Extend _match_text signature and ScanHit construction
|
||||
old_sig=''' def _match_text(
|
||||
self,
|
||||
text: str,
|
||||
patterns: dict,
|
||||
source_type: str,
|
||||
source_id: str | int,
|
||||
field_name: str,
|
||||
hits: list[ScanHit],
|
||||
row_index: int | None = None,
|
||||
dataset_name: str | None = None,
|
||||
) -> None:
|
||||
'''
|
||||
new_sig=''' def _match_text(
|
||||
self,
|
||||
text: str,
|
||||
patterns: dict,
|
||||
source_type: str,
|
||||
source_id: str | int,
|
||||
field_name: str,
|
||||
hits: list[ScanHit],
|
||||
row_index: int | None = None,
|
||||
dataset_name: str | None = None,
|
||||
hostname: str | None = None,
|
||||
username: str | None = None,
|
||||
) -> None:
|
||||
'''
|
||||
if old_sig not in t:
|
||||
raise SystemExit('_match_text signature not found')
|
||||
t=t.replace(old_sig,new_sig)
|
||||
|
||||
old_hit=''' hits.append(ScanHit(
|
||||
theme_name=theme_name,
|
||||
theme_color=theme_color,
|
||||
keyword=kw_value,
|
||||
source_type=source_type,
|
||||
source_id=source_id,
|
||||
field=field_name,
|
||||
matched_value=matched_preview,
|
||||
row_index=row_index,
|
||||
dataset_name=dataset_name,
|
||||
))
|
||||
'''
|
||||
new_hit=''' hits.append(ScanHit(
|
||||
theme_name=theme_name,
|
||||
theme_color=theme_color,
|
||||
keyword=kw_value,
|
||||
source_type=source_type,
|
||||
source_id=source_id,
|
||||
field=field_name,
|
||||
matched_value=matched_preview,
|
||||
row_index=row_index,
|
||||
dataset_name=dataset_name,
|
||||
hostname=hostname,
|
||||
username=username,
|
||||
))
|
||||
'''
|
||||
if old_hit not in t:
|
||||
raise SystemExit('ScanHit append block not found')
|
||||
t=t.replace(old_hit,new_hit)
|
||||
|
||||
# 4) Pass inferred hostname/username in dataset scan path
|
||||
old_call=''' for row in rows:
|
||||
result.rows_scanned += 1
|
||||
data = row.data or {}
|
||||
for col_name, cell_value in data.items():
|
||||
if cell_value is None:
|
||||
continue
|
||||
text = str(cell_value)
|
||||
self._match_text(
|
||||
text,
|
||||
patterns,
|
||||
"dataset_row",
|
||||
row.id,
|
||||
col_name,
|
||||
result.hits,
|
||||
row_index=row.row_index,
|
||||
dataset_name=ds_name,
|
||||
)
|
||||
'''
|
||||
new_call=''' for row in rows:
|
||||
result.rows_scanned += 1
|
||||
data = row.data or {}
|
||||
hostname, username = _infer_hostname_and_user(data)
|
||||
for col_name, cell_value in data.items():
|
||||
if cell_value is None:
|
||||
continue
|
||||
text = str(cell_value)
|
||||
self._match_text(
|
||||
text,
|
||||
patterns,
|
||||
"dataset_row",
|
||||
row.id,
|
||||
col_name,
|
||||
result.hits,
|
||||
row_index=row.row_index,
|
||||
dataset_name=ds_name,
|
||||
hostname=hostname,
|
||||
username=username,
|
||||
)
|
||||
'''
|
||||
if old_call not in t:
|
||||
raise SystemExit('dataset _match_text call block not found')
|
||||
t=t.replace(old_call,new_call)
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated scanner hits with hostname+username context')
|
||||
32
_aup_extend_scanhit_api.py
Normal file
32
_aup_extend_scanhit_api.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
old='''class ScanHit(BaseModel):
|
||||
theme_name: str
|
||||
theme_color: str
|
||||
keyword: str
|
||||
source_type: str
|
||||
source_id: str | int
|
||||
field: str
|
||||
matched_value: str
|
||||
row_index: int | None = None
|
||||
dataset_name: str | None = None
|
||||
'''
|
||||
new='''class ScanHit(BaseModel):
|
||||
theme_name: str
|
||||
theme_color: str
|
||||
keyword: str
|
||||
source_type: str
|
||||
source_id: str | int
|
||||
field: str
|
||||
matched_value: str
|
||||
row_index: int | None = None
|
||||
dataset_name: str | None = None
|
||||
hostname: str | None = None
|
||||
username: str | None = None
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('ScanHit pydantic model block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('extended API ScanHit model with hostname+username')
|
||||
21
_aup_extend_scanhit_frontend_type.py
Normal file
21
_aup_extend_scanhit_frontend_type.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/api/client.ts')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
old='''export interface ScanHit {
|
||||
theme_name: string; theme_color: string; keyword: string;
|
||||
source_type: string; source_id: string | number; field: string;
|
||||
matched_value: string; row_index: number | null; dataset_name: string | null;
|
||||
}
|
||||
'''
|
||||
new='''export interface ScanHit {
|
||||
theme_name: string; theme_color: string; keyword: string;
|
||||
source_type: string; source_id: string | number; field: string;
|
||||
matched_value: string; row_index: number | null; dataset_name: string | null;
|
||||
hostname?: string | null; username?: string | null;
|
||||
}
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('frontend ScanHit interface block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('extended frontend ScanHit type with hostname+username')
|
||||
57
_aup_keywords_scope_and_missing.py
Normal file
57
_aup_keywords_scope_and_missing.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
|
||||
# add fast guard against unscoped global dataset scans
|
||||
insert_after='''async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):\n scanner = KeywordScanner(db)\n\n'''
|
||||
if insert_after not in t:
|
||||
raise SystemExit('run_scan header block not found')
|
||||
if 'Select at least one dataset' not in t:
|
||||
guard=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:\n raise HTTPException(400, "Select at least one dataset or enable additional sources (hunts/annotations/messages)")\n\n'''
|
||||
t=t.replace(insert_after, insert_after+guard)
|
||||
|
||||
old=''' if missing:
|
||||
missing_entries: list[dict] = []
|
||||
for dataset_id in missing:
|
||||
partial = await scanner.scan(dataset_ids=[dataset_id], theme_ids=body.theme_ids)
|
||||
keyword_scan_cache.put(dataset_id, partial)
|
||||
missing_entries.append({"result": partial, "built_at": None})
|
||||
|
||||
merged = _merge_cached_results(
|
||||
cached_entries + missing_entries,
|
||||
allowed_theme_names if body.theme_ids else None,
|
||||
)
|
||||
return {
|
||||
"total_hits": merged["total_hits"],
|
||||
"hits": merged["hits"],
|
||||
"themes_scanned": len(themes),
|
||||
"keywords_scanned": keywords_scanned,
|
||||
"rows_scanned": merged["rows_scanned"],
|
||||
"cache_used": len(cached_entries) > 0,
|
||||
"cache_status": "partial" if cached_entries else "miss",
|
||||
"cached_at": merged["cached_at"],
|
||||
}
|
||||
'''
|
||||
new=''' if missing:
|
||||
partial = await scanner.scan(dataset_ids=missing, theme_ids=body.theme_ids)
|
||||
merged = _merge_cached_results(
|
||||
cached_entries + [{"result": partial, "built_at": None}],
|
||||
allowed_theme_names if body.theme_ids else None,
|
||||
)
|
||||
return {
|
||||
"total_hits": merged["total_hits"],
|
||||
"hits": merged["hits"],
|
||||
"themes_scanned": len(themes),
|
||||
"keywords_scanned": keywords_scanned,
|
||||
"rows_scanned": merged["rows_scanned"],
|
||||
"cache_used": len(cached_entries) > 0,
|
||||
"cache_status": "partial" if cached_entries else "miss",
|
||||
"cached_at": merged["cached_at"],
|
||||
}
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('partial-cache missing block not found')
|
||||
t=t.replace(old,new)
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('hardened keywords scan scope + optimized missing-cache path')
|
||||
18
_aup_reduce_budget.py
Normal file
18
_aup_reduce_budget.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
old=''' SCANNER_MAX_ROWS_PER_SCAN: int = Field(
|
||||
default=300000,
|
||||
description="Global row budget for a single AUP scan request (0 = unlimited)",
|
||||
)
|
||||
'''
|
||||
new=''' SCANNER_MAX_ROWS_PER_SCAN: int = Field(
|
||||
default=120000,
|
||||
description="Global row budget for a single AUP scan request (0 = unlimited)",
|
||||
)
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('SCANNER_MAX_ROWS_PER_SCAN block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('reduced SCANNER_MAX_ROWS_PER_SCAN default to 120000')
|
||||
42
_aup_update_grid_columns.py
Normal file
42
_aup_update_grid_columns.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
old='''const RESULT_COLUMNS: GridColDef[] = [
|
||||
{
|
||||
field: 'theme_name', headerName: 'Theme', width: 140,
|
||||
renderCell: (params) => (
|
||||
<Chip label={params.value} size="small"
|
||||
sx={{ bgcolor: params.row.theme_color, color: '#fff', fontWeight: 600 }} />
|
||||
),
|
||||
},
|
||||
{ field: 'keyword', headerName: 'Keyword', width: 140 },
|
||||
{ field: 'source_type', headerName: 'Source', width: 120 },
|
||||
{ field: 'dataset_name', headerName: 'Dataset', width: 150 },
|
||||
{ field: 'field', headerName: 'Field', width: 130 },
|
||||
{ field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 200 },
|
||||
{ field: 'row_index', headerName: 'Row #', width: 80, type: 'number' },
|
||||
];
|
||||
'''
|
||||
new='''const RESULT_COLUMNS: GridColDef[] = [
|
||||
{
|
||||
field: 'theme_name', headerName: 'Theme', width: 140,
|
||||
renderCell: (params) => (
|
||||
<Chip label={params.value} size="small"
|
||||
sx={{ bgcolor: params.row.theme_color, color: '#fff', fontWeight: 600 }} />
|
||||
),
|
||||
},
|
||||
{ field: 'keyword', headerName: 'Keyword', width: 140 },
|
||||
{ field: 'dataset_name', headerName: 'Dataset', width: 170 },
|
||||
{ field: 'hostname', headerName: 'Hostname', width: 170, valueGetter: (v, row) => row.hostname || '' },
|
||||
{ field: 'username', headerName: 'User', width: 160, valueGetter: (v, row) => row.username || '' },
|
||||
{ field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 220 },
|
||||
{ field: 'field', headerName: 'Field', width: 130 },
|
||||
{ field: 'source_type', headerName: 'Source', width: 120 },
|
||||
{ field: 'row_index', headerName: 'Row #', width: 90, type: 'number' },
|
||||
];
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('RESULT_COLUMNS block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated AUP results grid columns with dataset/hostname/user/matched value focus')
|
||||
40
_edit_aup.py
Normal file
40
_edit_aup.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
t=t.replace(' const [scanHunts, setScanHunts] = useState(true);',' const [scanHunts, setScanHunts] = useState(false);')
|
||||
t=t.replace(' const [scanAnnotations, setScanAnnotations] = useState(true);',' const [scanAnnotations, setScanAnnotations] = useState(false);')
|
||||
t=t.replace(' const [scanMessages, setScanMessages] = useState(true);',' const [scanMessages, setScanMessages] = useState(false);')
|
||||
t=t.replace(' scan_messages: scanMessages,\n });',' scan_messages: scanMessages,\n prefer_cache: true,\n });')
|
||||
# add cache chip in summary alert
|
||||
old=''' {scanResult && (
|
||||
<Alert severity={scanResult.total_hits > 0 ? 'warning' : 'success'} sx={{ py: 0.5 }}>
|
||||
<strong>{scanResult.total_hits}</strong> hits across{' '}
|
||||
<strong>{scanResult.rows_scanned}</strong> rows |{' '}
|
||||
{scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned
|
||||
</Alert>
|
||||
)}
|
||||
'''
|
||||
new=''' {scanResult && (
|
||||
<Alert severity={scanResult.total_hits > 0 ? 'warning' : 'success'} sx={{ py: 0.5 }}>
|
||||
<strong>{scanResult.total_hits}</strong> hits across{' '}
|
||||
<strong>{scanResult.rows_scanned}</strong> rows |{' '}
|
||||
{scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned
|
||||
{scanResult.cache_status && (
|
||||
<Chip
|
||||
size="small"
|
||||
label={scanResult.cache_status === 'hit' ? 'Cached' : 'Live'}
|
||||
sx={{ ml: 1, height: 20 }}
|
||||
color={scanResult.cache_status === 'hit' ? 'success' : 'default'}
|
||||
variant="outlined"
|
||||
/>
|
||||
)}
|
||||
</Alert>
|
||||
)}
|
||||
'''
|
||||
if old in t:
|
||||
t=t.replace(old,new)
|
||||
else:
|
||||
print('warning: summary block not replaced')
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated AUPScanner.tsx')
|
||||
36
_edit_client.py
Normal file
36
_edit_client.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from pathlib import Path
|
||||
import re
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/api/client.ts')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
# Add HuntProgress interface after Hunt interface
|
||||
if 'export interface HuntProgress' not in t:
|
||||
insert = '''export interface HuntProgress {
|
||||
hunt_id: string;
|
||||
status: 'idle' | 'processing' | 'ready';
|
||||
progress_percent: number;
|
||||
dataset_total: number;
|
||||
dataset_completed: number;
|
||||
dataset_processing: number;
|
||||
dataset_errors: number;
|
||||
active_jobs: number;
|
||||
queued_jobs: number;
|
||||
network_status: 'none' | 'building' | 'ready';
|
||||
stages: Record<string, any>;
|
||||
}
|
||||
|
||||
'''
|
||||
t=t.replace('export interface Hunt {\n id: string; name: string; description: string | null; status: string;\n owner_id: string | null; created_at: string; updated_at: string;\n dataset_count: number; hypothesis_count: number;\n}\n\n', 'export interface Hunt {\n id: string; name: string; description: string | null; status: string;\n owner_id: string | null; created_at: string; updated_at: string;\n dataset_count: number; hypothesis_count: number;\n}\n\n'+insert)
|
||||
|
||||
# Add hunts.progress method
|
||||
if 'progress: (id: string)' not in t:
|
||||
t=t.replace(" delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),\n};", " delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),\n progress: (id: string) => api<HuntProgress>(`/api/hunts/${id}/progress`),\n};")
|
||||
|
||||
# Extend ScanResponse
|
||||
if 'cache_used?: boolean' not in t:
|
||||
t=t.replace('export interface ScanResponse {\n total_hits: number; hits: ScanHit[]; themes_scanned: number;\n keywords_scanned: number; rows_scanned: number;\n}\n', 'export interface ScanResponse {\n total_hits: number; hits: ScanHit[]; themes_scanned: number;\n keywords_scanned: number; rows_scanned: number;\n cache_used?: boolean; cache_status?: string; cached_at?: string | null;\n}\n')
|
||||
|
||||
# Extend keywords.scan opts
|
||||
t=t.replace(' scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;\n }) =>', ' scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;\n prefer_cache?: boolean; force_rescan?: boolean;\n }) =>')
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated client.ts')
|
||||
20
_edit_config_reconcile.py
Normal file
20
_edit_config_reconcile.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
anchor=''' STARTUP_REPROCESS_MAX_DATASETS: int = Field(
|
||||
default=25, description="Max unprocessed datasets to enqueue at startup"
|
||||
)
|
||||
'''
|
||||
insert=''' STARTUP_REPROCESS_MAX_DATASETS: int = Field(
|
||||
default=25, description="Max unprocessed datasets to enqueue at startup"
|
||||
)
|
||||
STARTUP_RECONCILE_STALE_TASKS: bool = Field(
|
||||
default=True,
|
||||
description="Mark stale queued/running processing tasks as failed on startup",
|
||||
)
|
||||
'''
|
||||
if anchor not in t:
|
||||
raise SystemExit('startup anchor not found')
|
||||
t=t.replace(anchor,insert)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated config with STARTUP_RECONCILE_STALE_TASKS')
|
||||
39
_edit_datasets.py
Normal file
39
_edit_datasets.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
if 'from app.services.scanner import keyword_scan_cache' not in t:
|
||||
t=t.replace('from app.services.host_inventory import inventory_cache','from app.services.host_inventory import inventory_cache\nfrom app.services.scanner import keyword_scan_cache')
|
||||
old='''@router.delete(
|
||||
"/{dataset_id}",
|
||||
summary="Delete a dataset",
|
||||
)
|
||||
async def delete_dataset(
|
||||
dataset_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
repo = DatasetRepository(db)
|
||||
deleted = await repo.delete_dataset(dataset_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
return {"message": "Dataset deleted", "id": dataset_id}
|
||||
'''
|
||||
new='''@router.delete(
|
||||
"/{dataset_id}",
|
||||
summary="Delete a dataset",
|
||||
)
|
||||
async def delete_dataset(
|
||||
dataset_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
repo = DatasetRepository(db)
|
||||
deleted = await repo.delete_dataset(dataset_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
keyword_scan_cache.invalidate_dataset(dataset_id)
|
||||
return {"message": "Dataset deleted", "id": dataset_id}
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('delete block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated datasets.py')
|
||||
110
_edit_datasets_tasks.py
Normal file
110
_edit_datasets_tasks.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
if 'ProcessingTask' not in t:
|
||||
t=t.replace('from app.db.models import', 'from app.db.models import ProcessingTask\n# from app.db.models import')
|
||||
t=t.replace('from app.services.scanner import keyword_scan_cache','from app.services.scanner import keyword_scan_cache')
|
||||
# clean import replacement to proper single line
|
||||
if '# from app.db.models import' in t:
|
||||
t=t.replace('from app.db.models import ProcessingTask\n# from app.db.models import', 'from app.db.models import ProcessingTask')
|
||||
|
||||
old=''' # 1. AI Triage (chains to HOST_PROFILE automatically on completion)
|
||||
job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
|
||||
jobs_queued.append("triage")
|
||||
|
||||
# 2. Anomaly detection (embedding-based outlier detection)
|
||||
job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
|
||||
jobs_queued.append("anomaly")
|
||||
|
||||
# 3. AUP keyword scan
|
||||
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
|
||||
jobs_queued.append("keyword_scan")
|
||||
|
||||
# 4. IOC extraction
|
||||
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
|
||||
jobs_queued.append("ioc_extract")
|
||||
|
||||
# 5. Host inventory (network map) - requires hunt_id
|
||||
if hunt_id:
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
jobs_queued.append("host_inventory")
|
||||
'''
|
||||
new=''' task_rows: list[ProcessingTask] = []
|
||||
|
||||
# 1. AI Triage (chains to HOST_PROFILE automatically on completion)
|
||||
triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
|
||||
jobs_queued.append("triage")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=triage_job.id,
|
||||
stage="triage",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 2. Anomaly detection (embedding-based outlier detection)
|
||||
anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
|
||||
jobs_queued.append("anomaly")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=anomaly_job.id,
|
||||
stage="anomaly",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 3. AUP keyword scan
|
||||
kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
|
||||
jobs_queued.append("keyword_scan")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=kw_job.id,
|
||||
stage="keyword_scan",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 4. IOC extraction
|
||||
ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
|
||||
jobs_queued.append("ioc_extract")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=ioc_job.id,
|
||||
stage="ioc_extract",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 5. Host inventory (network map) - requires hunt_id
|
||||
if hunt_id:
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
jobs_queued.append("host_inventory")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=inv_job.id,
|
||||
stage="host_inventory",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
if task_rows:
|
||||
db.add_all(task_rows)
|
||||
await db.flush()
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('queue block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated datasets upload queue + processing tasks')
|
||||
254
_edit_hunts.py
Normal file
254
_edit_hunts.py
Normal file
@@ -0,0 +1,254 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/hunts.py')
|
||||
new='''"""API routes for hunt management."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Hunt, Dataset
|
||||
from app.services.job_queue import job_queue
|
||||
from app.services.host_inventory import inventory_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/hunts", tags=["hunts"])
|
||||
|
||||
|
||||
class HuntCreate(BaseModel):
|
||||
name: str = Field(..., max_length=256)
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class HuntUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
status: str | None = None
|
||||
|
||||
|
||||
class HuntResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None
|
||||
status: str
|
||||
owner_id: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
dataset_count: int = 0
|
||||
hypothesis_count: int = 0
|
||||
|
||||
|
||||
class HuntListResponse(BaseModel):
|
||||
hunts: list[HuntResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class HuntProgressResponse(BaseModel):
|
||||
hunt_id: str
|
||||
status: str
|
||||
progress_percent: float
|
||||
dataset_total: int
|
||||
dataset_completed: int
|
||||
dataset_processing: int
|
||||
dataset_errors: int
|
||||
active_jobs: int
|
||||
queued_jobs: int
|
||||
network_status: str
|
||||
stages: dict
|
||||
|
||||
|
||||
@router.post("", response_model=HuntResponse, summary="Create a new hunt")
|
||||
async def create_hunt(body: HuntCreate, db: AsyncSession = Depends(get_db)):
|
||||
hunt = Hunt(name=body.name, description=body.description)
|
||||
db.add(hunt)
|
||||
await db.flush()
|
||||
return HuntResponse(
|
||||
id=hunt.id,
|
||||
name=hunt.name,
|
||||
description=hunt.description,
|
||||
status=hunt.status,
|
||||
owner_id=hunt.owner_id,
|
||||
created_at=hunt.created_at.isoformat(),
|
||||
updated_at=hunt.updated_at.isoformat(),
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=HuntListResponse, summary="List hunts")
|
||||
async def list_hunts(
|
||||
status: str | None = Query(None),
|
||||
limit: int = Query(50, ge=1, le=500),
|
||||
offset: int = Query(0, ge=0),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
stmt = select(Hunt).order_by(Hunt.updated_at.desc())
|
||||
if status:
|
||||
stmt = stmt.where(Hunt.status == status)
|
||||
stmt = stmt.limit(limit).offset(offset)
|
||||
result = await db.execute(stmt)
|
||||
hunts = result.scalars().all()
|
||||
|
||||
count_stmt = select(func.count(Hunt.id))
|
||||
if status:
|
||||
count_stmt = count_stmt.where(Hunt.status == status)
|
||||
total = (await db.execute(count_stmt)).scalar_one()
|
||||
|
||||
return HuntListResponse(
|
||||
hunts=[
|
||||
HuntResponse(
|
||||
id=h.id,
|
||||
name=h.name,
|
||||
description=h.description,
|
||||
status=h.status,
|
||||
owner_id=h.owner_id,
|
||||
created_at=h.created_at.isoformat(),
|
||||
updated_at=h.updated_at.isoformat(),
|
||||
dataset_count=len(h.datasets) if h.datasets else 0,
|
||||
hypothesis_count=len(h.hypotheses) if h.hypotheses else 0,
|
||||
)
|
||||
for h in hunts
|
||||
],
|
||||
total=total,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{hunt_id}", response_model=HuntResponse, summary="Get hunt details")
|
||||
async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
|
||||
hunt = result.scalar_one_or_none()
|
||||
if not hunt:
|
||||
raise HTTPException(status_code=404, detail="Hunt not found")
|
||||
return HuntResponse(
|
||||
id=hunt.id,
|
||||
name=hunt.name,
|
||||
description=hunt.description,
|
||||
status=hunt.status,
|
||||
owner_id=hunt.owner_id,
|
||||
created_at=hunt.created_at.isoformat(),
|
||||
updated_at=hunt.updated_at.isoformat(),
|
||||
dataset_count=len(hunt.datasets) if hunt.datasets else 0,
|
||||
hypothesis_count=len(hunt.hypotheses) if hunt.hypotheses else 0,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{hunt_id}/progress", response_model=HuntProgressResponse, summary="Get hunt processing progress")
|
||||
async def get_hunt_progress(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
hunt = await db.get(Hunt, hunt_id)
|
||||
if not hunt:
|
||||
raise HTTPException(status_code=404, detail="Hunt not found")
|
||||
|
||||
ds_rows = await db.execute(
|
||||
select(Dataset.id, Dataset.processing_status)
|
||||
.where(Dataset.hunt_id == hunt_id)
|
||||
)
|
||||
datasets = ds_rows.all()
|
||||
dataset_ids = {row[0] for row in datasets}
|
||||
|
||||
dataset_total = len(datasets)
|
||||
dataset_completed = sum(1 for _, st in datasets if st == "completed")
|
||||
dataset_errors = sum(1 for _, st in datasets if st == "completed_with_errors")
|
||||
dataset_processing = max(0, dataset_total - dataset_completed - dataset_errors)
|
||||
|
||||
jobs = job_queue.list_jobs(limit=5000)
|
||||
relevant_jobs = [
|
||||
j for j in jobs
|
||||
if j.get("params", {}).get("hunt_id") == hunt_id
|
||||
or j.get("params", {}).get("dataset_id") in dataset_ids
|
||||
]
|
||||
active_jobs = sum(1 for j in relevant_jobs if j.get("status") == "running")
|
||||
queued_jobs = sum(1 for j in relevant_jobs if j.get("status") == "queued")
|
||||
|
||||
if inventory_cache.get(hunt_id) is not None:
|
||||
network_status = "ready"
|
||||
network_ratio = 1.0
|
||||
elif inventory_cache.is_building(hunt_id):
|
||||
network_status = "building"
|
||||
network_ratio = 0.5
|
||||
else:
|
||||
network_status = "none"
|
||||
network_ratio = 0.0
|
||||
|
||||
dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
|
||||
overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
|
||||
progress_percent = round(overall_ratio * 100.0, 1)
|
||||
|
||||
status = "ready"
|
||||
if dataset_total == 0:
|
||||
status = "idle"
|
||||
elif progress_percent < 100:
|
||||
status = "processing"
|
||||
|
||||
stages = {
|
||||
"datasets": {
|
||||
"total": dataset_total,
|
||||
"completed": dataset_completed,
|
||||
"processing": dataset_processing,
|
||||
"errors": dataset_errors,
|
||||
"percent": round(dataset_ratio * 100.0, 1),
|
||||
},
|
||||
"network": {
|
||||
"status": network_status,
|
||||
"percent": round(network_ratio * 100.0, 1),
|
||||
},
|
||||
"jobs": {
|
||||
"active": active_jobs,
|
||||
"queued": queued_jobs,
|
||||
"total_seen": len(relevant_jobs),
|
||||
},
|
||||
}
|
||||
|
||||
return HuntProgressResponse(
|
||||
hunt_id=hunt_id,
|
||||
status=status,
|
||||
progress_percent=progress_percent,
|
||||
dataset_total=dataset_total,
|
||||
dataset_completed=dataset_completed,
|
||||
dataset_processing=dataset_processing,
|
||||
dataset_errors=dataset_errors,
|
||||
active_jobs=active_jobs,
|
||||
queued_jobs=queued_jobs,
|
||||
network_status=network_status,
|
||||
stages=stages,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
|
||||
async def update_hunt(
|
||||
hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
|
||||
hunt = result.scalar_one_or_none()
|
||||
if not hunt:
|
||||
raise HTTPException(status_code=404, detail="Hunt not found")
|
||||
if body.name is not None:
|
||||
hunt.name = body.name
|
||||
if body.description is not None:
|
||||
hunt.description = body.description
|
||||
if body.status is not None:
|
||||
hunt.status = body.status
|
||||
await db.flush()
|
||||
return HuntResponse(
|
||||
id=hunt.id,
|
||||
name=hunt.name,
|
||||
description=hunt.description,
|
||||
status=hunt.status,
|
||||
owner_id=hunt.owner_id,
|
||||
created_at=hunt.created_at.isoformat(),
|
||||
updated_at=hunt.updated_at.isoformat(),
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{hunt_id}", summary="Delete a hunt")
|
||||
async def delete_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
|
||||
hunt = result.scalar_one_or_none()
|
||||
if not hunt:
|
||||
raise HTTPException(status_code=404, detail="Hunt not found")
|
||||
await db.delete(hunt)
|
||||
return {"message": "Hunt deleted", "id": hunt_id}
|
||||
'''
|
||||
p.write_text(new,encoding='utf-8')
|
||||
print('updated hunts.py')
|
||||
102
_edit_hunts_progress_tasks.py
Normal file
102
_edit_hunts_progress_tasks.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/hunts.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
if 'ProcessingTask' not in t:
|
||||
t=t.replace('from app.db.models import Hunt, Dataset','from app.db.models import Hunt, Dataset, ProcessingTask')
|
||||
|
||||
old=''' jobs = job_queue.list_jobs(limit=5000)
|
||||
relevant_jobs = [
|
||||
j for j in jobs
|
||||
if j.get("params", {}).get("hunt_id") == hunt_id
|
||||
or j.get("params", {}).get("dataset_id") in dataset_ids
|
||||
]
|
||||
active_jobs = sum(1 for j in relevant_jobs if j.get("status") == "running")
|
||||
queued_jobs = sum(1 for j in relevant_jobs if j.get("status") == "queued")
|
||||
|
||||
if inventory_cache.get(hunt_id) is not None:
|
||||
'''
|
||||
new=''' jobs = job_queue.list_jobs(limit=5000)
|
||||
relevant_jobs = [
|
||||
j for j in jobs
|
||||
if j.get("params", {}).get("hunt_id") == hunt_id
|
||||
or j.get("params", {}).get("dataset_id") in dataset_ids
|
||||
]
|
||||
active_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "running")
|
||||
queued_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "queued")
|
||||
|
||||
task_rows = await db.execute(
|
||||
select(ProcessingTask.stage, ProcessingTask.status, ProcessingTask.progress)
|
||||
.where(ProcessingTask.hunt_id == hunt_id)
|
||||
)
|
||||
tasks = task_rows.all()
|
||||
|
||||
task_total = len(tasks)
|
||||
task_done = sum(1 for _, st, _ in tasks if st in ("completed", "failed", "cancelled"))
|
||||
task_running = sum(1 for _, st, _ in tasks if st == "running")
|
||||
task_queued = sum(1 for _, st, _ in tasks if st == "queued")
|
||||
task_ratio = (task_done / task_total) if task_total > 0 else None
|
||||
|
||||
active_jobs = max(active_jobs_mem, task_running)
|
||||
queued_jobs = max(queued_jobs_mem, task_queued)
|
||||
|
||||
stage_rollup: dict[str, dict] = {}
|
||||
for stage, status, progress in tasks:
|
||||
bucket = stage_rollup.setdefault(stage, {"total": 0, "done": 0, "running": 0, "queued": 0, "progress_sum": 0.0})
|
||||
bucket["total"] += 1
|
||||
if status in ("completed", "failed", "cancelled"):
|
||||
bucket["done"] += 1
|
||||
elif status == "running":
|
||||
bucket["running"] += 1
|
||||
elif status == "queued":
|
||||
bucket["queued"] += 1
|
||||
bucket["progress_sum"] += float(progress or 0.0)
|
||||
|
||||
for stage_name, bucket in stage_rollup.items():
|
||||
total = max(1, bucket["total"])
|
||||
bucket["percent"] = round(bucket["progress_sum"] / total, 1)
|
||||
|
||||
if inventory_cache.get(hunt_id) is not None:
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('job block not found')
|
||||
t=t.replace(old,new)
|
||||
|
||||
old2=''' dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
|
||||
overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
|
||||
progress_percent = round(overall_ratio * 100.0, 1)
|
||||
'''
|
||||
new2=''' dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
|
||||
if task_ratio is None:
|
||||
overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
|
||||
else:
|
||||
overall_ratio = min(1.0, (dataset_ratio * 0.50) + (task_ratio * 0.35) + (network_ratio * 0.15))
|
||||
progress_percent = round(overall_ratio * 100.0, 1)
|
||||
'''
|
||||
if old2 not in t:
|
||||
raise SystemExit('ratio block not found')
|
||||
t=t.replace(old2,new2)
|
||||
|
||||
old3=''' "jobs": {
|
||||
"active": active_jobs,
|
||||
"queued": queued_jobs,
|
||||
"total_seen": len(relevant_jobs),
|
||||
},
|
||||
}
|
||||
'''
|
||||
new3=''' "jobs": {
|
||||
"active": active_jobs,
|
||||
"queued": queued_jobs,
|
||||
"total_seen": len(relevant_jobs),
|
||||
"task_total": task_total,
|
||||
"task_done": task_done,
|
||||
"task_percent": round((task_ratio or 0.0) * 100.0, 1) if task_total else None,
|
||||
},
|
||||
"task_stages": stage_rollup,
|
||||
}
|
||||
'''
|
||||
if old3 not in t:
|
||||
raise SystemExit('stages jobs block not found')
|
||||
t=t.replace(old3,new3)
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated hunt progress to merge persistent processing tasks')
|
||||
46
_edit_job_queue.py
Normal file
46
_edit_job_queue.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
old='''async def _handle_keyword_scan(job: Job):
|
||||
"""AUP keyword scan handler."""
|
||||
from app.db import async_session_factory
|
||||
from app.services.scanner import KeywordScanner
|
||||
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
job.message = f"Running AUP keyword scan on dataset {dataset_id}"
|
||||
|
||||
async with async_session_factory() as db:
|
||||
scanner = KeywordScanner(db)
|
||||
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||
|
||||
hits = result.get("total_hits", 0)
|
||||
job.message = f"Keyword scan complete: {hits} hits"
|
||||
logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
|
||||
return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
|
||||
'''
|
||||
new='''async def _handle_keyword_scan(job: Job):
|
||||
"""AUP keyword scan handler."""
|
||||
from app.db import async_session_factory
|
||||
from app.services.scanner import KeywordScanner, keyword_scan_cache
|
||||
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
job.message = f"Running AUP keyword scan on dataset {dataset_id}"
|
||||
|
||||
async with async_session_factory() as db:
|
||||
scanner = KeywordScanner(db)
|
||||
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||
|
||||
# Cache dataset-only result for fast API reuse
|
||||
if dataset_id:
|
||||
keyword_scan_cache.put(dataset_id, result)
|
||||
|
||||
hits = result.get("total_hits", 0)
|
||||
job.message = f"Keyword scan complete: {hits} hits"
|
||||
logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
|
||||
return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('target block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated job_queue keyword scan handler')
|
||||
13
_edit_jobqueue_reconcile.py
Normal file
13
_edit_jobqueue_reconcile.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
marker='''def register_all_handlers():
|
||||
"""Register all job handlers and completion callbacks."""
|
||||
'''
|
||||
ins='''\n\nasync def reconcile_stale_processing_tasks() -> int:\n """Mark queued/running processing tasks from prior runs as failed."""\n from datetime import datetime, timezone\n from sqlalchemy import update\n\n try:\n from app.db import async_session_factory\n from app.db.models import ProcessingTask\n\n now = datetime.now(timezone.utc)\n async with async_session_factory() as db:\n result = await db.execute(\n update(ProcessingTask)\n .where(ProcessingTask.status.in_([\"queued\", \"running\"]))\n .values(\n status=\"failed\",\n error=\"Recovered after service restart before task completion\",\n message=\"Recovered stale task after restart\",\n completed_at=now,\n )\n )\n await db.commit()\n updated = int(result.rowcount or 0)\n\n if updated:\n logger.warning(\n \"Reconciled %d stale processing tasks (queued/running -> failed) during startup\",\n updated,\n )\n return updated\n except Exception as e:\n logger.warning(f\"Failed to reconcile stale processing tasks: {e}\")\n return 0\n\n\n'''
|
||||
if ins.strip() not in t:
|
||||
if marker not in t:
|
||||
raise SystemExit('register marker not found')
|
||||
t=t.replace(marker,ins+marker)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('added reconcile_stale_processing_tasks to job_queue')
|
||||
64
_edit_jobqueue_sync.py
Normal file
64
_edit_jobqueue_sync.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
ins='''\n\nasync def _sync_processing_task(job: Job):\n """Persist latest job state into processing_tasks (if linked by job_id)."""\n from datetime import datetime, timezone\n from sqlalchemy import update\n\n try:\n from app.db import async_session_factory\n from app.db.models import ProcessingTask\n\n values = {\n "status": job.status.value,\n "progress": float(job.progress),\n "message": job.message,\n "error": job.error,\n }\n if job.started_at:\n values["started_at"] = datetime.fromtimestamp(job.started_at, tz=timezone.utc)\n if job.completed_at:\n values["completed_at"] = datetime.fromtimestamp(job.completed_at, tz=timezone.utc)\n\n async with async_session_factory() as db:\n await db.execute(\n update(ProcessingTask)\n .where(ProcessingTask.job_id == job.id)\n .values(**values)\n )\n await db.commit()\n except Exception as e:\n logger.warning(f"Failed to sync processing task for job {job.id}: {e}")\n'''
|
||||
marker='\n\n# -- Singleton + job handlers --\n'
|
||||
if ins.strip() not in t:
|
||||
t=t.replace(marker, ins+marker)
|
||||
|
||||
old=''' job.status = JobStatus.RUNNING
|
||||
job.started_at = time.time()
|
||||
job.message = "Running..."
|
||||
logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
|
||||
|
||||
try:
|
||||
'''
|
||||
new=''' job.status = JobStatus.RUNNING
|
||||
job.started_at = time.time()
|
||||
if job.progress <= 0:
|
||||
job.progress = 5.0
|
||||
job.message = "Running..."
|
||||
await _sync_processing_task(job)
|
||||
logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
|
||||
|
||||
try:
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('worker running block not found')
|
||||
t=t.replace(old,new)
|
||||
|
||||
old2=''' job.completed_at = time.time()
|
||||
logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
|
||||
except Exception as e:
|
||||
if not job.is_cancelled:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error = str(e)
|
||||
job.message = f"Failed: {e}"
|
||||
job.completed_at = time.time()
|
||||
logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
|
||||
|
||||
# Fire completion callbacks
|
||||
'''
|
||||
new2=''' job.completed_at = time.time()
|
||||
logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
|
||||
except Exception as e:
|
||||
if not job.is_cancelled:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error = str(e)
|
||||
job.message = f"Failed: {e}"
|
||||
job.completed_at = time.time()
|
||||
logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
|
||||
|
||||
if job.is_cancelled and not job.completed_at:
|
||||
job.completed_at = time.time()
|
||||
|
||||
await _sync_processing_task(job)
|
||||
|
||||
# Fire completion callbacks
|
||||
'''
|
||||
if old2 not in t:
|
||||
raise SystemExit('worker completion block not found')
|
||||
t=t.replace(old2,new2)
|
||||
|
||||
p.write_text(t, encoding='utf-8')
|
||||
print('updated job_queue persistent task syncing')
|
||||
39
_edit_jobqueue_triage_task.py
Normal file
39
_edit_jobqueue_triage_task.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
old=''' if hunt_id:
|
||||
job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
|
||||
logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
|
||||
except Exception as e:
|
||||
'''
|
||||
new=''' if hunt_id:
|
||||
hp_job = job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
from app.db.models import ProcessingTask
|
||||
async with async_session_factory() as db:
|
||||
existing = await db.execute(
|
||||
select(ProcessingTask.id).where(ProcessingTask.job_id == hp_job.id)
|
||||
)
|
||||
if existing.first() is None:
|
||||
db.add(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset_id,
|
||||
job_id=hp_job.id,
|
||||
stage="host_profile",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
await db.commit()
|
||||
except Exception as persist_err:
|
||||
logger.warning(f"Failed to persist chained HOST_PROFILE task: {persist_err}")
|
||||
|
||||
logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
|
||||
except Exception as e:
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('triage chain block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated triage chain to persist host_profile task row')
|
||||
321
_edit_keywords.py
Normal file
321
_edit_keywords.py
Normal file
@@ -0,0 +1,321 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
|
||||
new_text='''"""API routes for AUP keyword themes, keyword CRUD, and scanning."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import KeywordTheme, Keyword
|
||||
from app.services.scanner import KeywordScanner, keyword_scan_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/keywords", tags=["keywords"])
|
||||
|
||||
|
||||
class ThemeCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=128)
|
||||
color: str = Field(default="#9e9e9e", max_length=16)
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class ThemeUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
color: str | None = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class KeywordOut(BaseModel):
|
||||
id: int
|
||||
theme_id: str
|
||||
value: str
|
||||
is_regex: bool
|
||||
created_at: str
|
||||
|
||||
|
||||
class ThemeOut(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
color: str
|
||||
enabled: bool
|
||||
is_builtin: bool
|
||||
created_at: str
|
||||
keyword_count: int
|
||||
keywords: list[KeywordOut]
|
||||
|
||||
|
||||
class ThemeListResponse(BaseModel):
|
||||
themes: list[ThemeOut]
|
||||
total: int
|
||||
|
||||
|
||||
class KeywordCreate(BaseModel):
|
||||
value: str = Field(..., min_length=1, max_length=256)
|
||||
is_regex: bool = False
|
||||
|
||||
|
||||
class KeywordBulkCreate(BaseModel):
|
||||
values: list[str] = Field(..., min_items=1)
|
||||
is_regex: bool = False
|
||||
|
||||
|
||||
class ScanRequest(BaseModel):
|
||||
dataset_ids: list[str] | None = None
|
||||
theme_ids: list[str] | None = None
|
||||
scan_hunts: bool = False
|
||||
scan_annotations: bool = False
|
||||
scan_messages: bool = False
|
||||
prefer_cache: bool = True
|
||||
force_rescan: bool = False
|
||||
|
||||
|
||||
class ScanHit(BaseModel):
|
||||
theme_name: str
|
||||
theme_color: str
|
||||
keyword: str
|
||||
source_type: str
|
||||
source_id: str | int
|
||||
field: str
|
||||
matched_value: str
|
||||
row_index: int | None = None
|
||||
dataset_name: str | None = None
|
||||
|
||||
|
||||
class ScanResponse(BaseModel):
|
||||
total_hits: int
|
||||
hits: list[ScanHit]
|
||||
themes_scanned: int
|
||||
keywords_scanned: int
|
||||
rows_scanned: int
|
||||
cache_used: bool = False
|
||||
cache_status: str = "miss"
|
||||
cached_at: str | None = None
|
||||
|
||||
|
||||
def _theme_to_out(t: KeywordTheme) -> ThemeOut:
|
||||
return ThemeOut(
|
||||
id=t.id,
|
||||
name=t.name,
|
||||
color=t.color,
|
||||
enabled=t.enabled,
|
||||
is_builtin=t.is_builtin,
|
||||
created_at=t.created_at.isoformat(),
|
||||
keyword_count=len(t.keywords),
|
||||
keywords=[
|
||||
KeywordOut(
|
||||
id=k.id,
|
||||
theme_id=k.theme_id,
|
||||
value=k.value,
|
||||
is_regex=k.is_regex,
|
||||
created_at=k.created_at.isoformat(),
|
||||
)
|
||||
for k in t.keywords
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _merge_cached_results(entries: list[dict], allowed_theme_names: set[str] | None = None) -> dict:
|
||||
hits: list[dict] = []
|
||||
total_rows = 0
|
||||
cached_at: str | None = None
|
||||
|
||||
for entry in entries:
|
||||
result = entry["result"]
|
||||
total_rows += int(result.get("rows_scanned", 0) or 0)
|
||||
if entry.get("built_at"):
|
||||
if not cached_at or entry["built_at"] > cached_at:
|
||||
cached_at = entry["built_at"]
|
||||
for h in result.get("hits", []):
|
||||
if allowed_theme_names is not None and h.get("theme_name") not in allowed_theme_names:
|
||||
continue
|
||||
hits.append(h)
|
||||
|
||||
return {
|
||||
"total_hits": len(hits),
|
||||
"hits": hits,
|
||||
"rows_scanned": total_rows,
|
||||
"cached_at": cached_at,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/themes", response_model=ThemeListResponse)
|
||||
async def list_themes(db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(KeywordTheme).order_by(KeywordTheme.name))
|
||||
themes = result.scalars().all()
|
||||
return ThemeListResponse(themes=[_theme_to_out(t) for t in themes], total=len(themes))
|
||||
|
||||
|
||||
@router.post("/themes", response_model=ThemeOut, status_code=201)
|
||||
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
|
||||
exists = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == body.name))
|
||||
if exists:
|
||||
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
||||
theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
|
||||
db.add(theme)
|
||||
await db.flush()
|
||||
await db.refresh(theme)
|
||||
keyword_scan_cache.clear()
|
||||
return _theme_to_out(theme)
|
||||
|
||||
|
||||
@router.put("/themes/{theme_id}", response_model=ThemeOut)
|
||||
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
if body.name is not None:
|
||||
dup = await db.scalar(
|
||||
select(KeywordTheme.id).where(KeywordTheme.name == body.name, KeywordTheme.id != theme_id)
|
||||
)
|
||||
if dup:
|
||||
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
||||
theme.name = body.name
|
||||
if body.color is not None:
|
||||
theme.color = body.color
|
||||
if body.enabled is not None:
|
||||
theme.enabled = body.enabled
|
||||
await db.flush()
|
||||
await db.refresh(theme)
|
||||
keyword_scan_cache.clear()
|
||||
return _theme_to_out(theme)
|
||||
|
||||
|
||||
@router.delete("/themes/{theme_id}", status_code=204)
|
||||
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
await db.delete(theme)
|
||||
keyword_scan_cache.clear()
|
||||
|
||||
|
||||
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
|
||||
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
kw = Keyword(theme_id=theme_id, value=body.value, is_regex=body.is_regex)
|
||||
db.add(kw)
|
||||
await db.flush()
|
||||
await db.refresh(kw)
|
||||
keyword_scan_cache.clear()
|
||||
return KeywordOut(
|
||||
id=kw.id, theme_id=kw.theme_id, value=kw.value,
|
||||
is_regex=kw.is_regex, created_at=kw.created_at.isoformat(),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
|
||||
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
added = 0
|
||||
for val in body.values:
|
||||
val = val.strip()
|
||||
if not val:
|
||||
continue
|
||||
db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex))
|
||||
added += 1
|
||||
await db.flush()
|
||||
keyword_scan_cache.clear()
|
||||
return {"added": added, "theme_id": theme_id}
|
||||
|
||||
|
||||
@router.delete("/keywords/{keyword_id}", status_code=204)
|
||||
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
|
||||
kw = await db.get(Keyword, keyword_id)
|
||||
if not kw:
|
||||
raise HTTPException(404, "Keyword not found")
|
||||
await db.delete(kw)
|
||||
keyword_scan_cache.clear()
|
||||
|
||||
|
||||
@router.post("/scan", response_model=ScanResponse)
|
||||
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
||||
scanner = KeywordScanner(db)
|
||||
|
||||
can_use_cache = (
|
||||
body.prefer_cache
|
||||
and not body.force_rescan
|
||||
and bool(body.dataset_ids)
|
||||
and not body.scan_hunts
|
||||
and not body.scan_annotations
|
||||
and not body.scan_messages
|
||||
)
|
||||
|
||||
if can_use_cache:
|
||||
themes = await scanner._load_themes(body.theme_ids)
|
||||
allowed_theme_names = {t.name for t in themes}
|
||||
keywords_scanned = sum(len(theme.keywords) for theme in themes)
|
||||
|
||||
cached_entries: list[dict] = []
|
||||
missing: list[str] = []
|
||||
for dataset_id in (body.dataset_ids or []):
|
||||
entry = keyword_scan_cache.get(dataset_id)
|
||||
if not entry:
|
||||
missing.append(dataset_id)
|
||||
continue
|
||||
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
|
||||
|
||||
if not missing and cached_entries:
|
||||
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
|
||||
return {
|
||||
"total_hits": merged["total_hits"],
|
||||
"hits": merged["hits"],
|
||||
"themes_scanned": len(themes),
|
||||
"keywords_scanned": keywords_scanned,
|
||||
"rows_scanned": merged["rows_scanned"],
|
||||
"cache_used": True,
|
||||
"cache_status": "hit",
|
||||
"cached_at": merged["cached_at"],
|
||||
}
|
||||
|
||||
result = await scanner.scan(
|
||||
dataset_ids=body.dataset_ids,
|
||||
theme_ids=body.theme_ids,
|
||||
scan_hunts=body.scan_hunts,
|
||||
scan_annotations=body.scan_annotations,
|
||||
scan_messages=body.scan_messages,
|
||||
)
|
||||
|
||||
return {
|
||||
**result,
|
||||
"cache_used": False,
|
||||
"cache_status": "miss",
|
||||
"cached_at": None,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/scan/quick", response_model=ScanResponse)
|
||||
async def quick_scan(
|
||||
dataset_id: str = Query(..., description="Dataset to scan"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
entry = keyword_scan_cache.get(dataset_id)
|
||||
if entry is not None:
|
||||
result = entry.result
|
||||
return {
|
||||
**result,
|
||||
"cache_used": True,
|
||||
"cache_status": "hit",
|
||||
"cached_at": entry.built_at,
|
||||
}
|
||||
|
||||
scanner = KeywordScanner(db)
|
||||
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||
keyword_scan_cache.put(dataset_id, result)
|
||||
return {
|
||||
**result,
|
||||
"cache_used": False,
|
||||
"cache_status": "miss",
|
||||
"cached_at": None,
|
||||
}
|
||||
'''
|
||||
p.write_text(new_text,encoding='utf-8')
|
||||
print('updated keywords.py')
|
||||
31
_edit_main_reconcile.py
Normal file
31
_edit_main_reconcile.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/main.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
old=''' # Start job queue
|
||||
from app.services.job_queue import job_queue, register_all_handlers, JobType
|
||||
register_all_handlers()
|
||||
await job_queue.start()
|
||||
logger.info("Job queue started (%d workers)", job_queue._max_workers)
|
||||
'''
|
||||
new=''' # Start job queue
|
||||
from app.services.job_queue import (
|
||||
job_queue,
|
||||
register_all_handlers,
|
||||
reconcile_stale_processing_tasks,
|
||||
JobType,
|
||||
)
|
||||
|
||||
if settings.STARTUP_RECONCILE_STALE_TASKS:
|
||||
reconciled = await reconcile_stale_processing_tasks()
|
||||
if reconciled:
|
||||
logger.info("Startup reconciliation marked %d stale tasks", reconciled)
|
||||
|
||||
register_all_handlers()
|
||||
await job_queue.start()
|
||||
logger.info("Job queue started (%d workers)", job_queue._max_workers)
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('startup queue block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('wired startup reconciliation in main lifespan')
|
||||
45
_edit_models_processing.py
Normal file
45
_edit_models_processing.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/db/models.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
if 'class ProcessingTask(Base):' in t:
|
||||
print('processing task model already exists')
|
||||
raise SystemExit(0)
|
||||
insert='''
|
||||
|
||||
# -- Persistent Processing Tasks (Phase 2) ---
|
||||
|
||||
class ProcessingTask(Base):
|
||||
__tablename__ = "processing_tasks"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
hunt_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True, index=True
|
||||
)
|
||||
dataset_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True, index=True
|
||||
)
|
||||
job_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
|
||||
stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="queued", index=True)
|
||||
progress: Mapped[float] = mapped_column(Float, default=0.0)
|
||||
message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_processing_tasks_hunt_stage", "hunt_id", "stage"),
|
||||
Index("ix_processing_tasks_dataset_stage", "dataset_id", "stage"),
|
||||
)
|
||||
'''
|
||||
# insert before Playbook section
|
||||
marker='\n\n# -- Playbook / Investigation Templates (Feature 3) ---\n'
|
||||
if marker not in t:
|
||||
raise SystemExit('marker not found for insertion')
|
||||
t=t.replace(marker, insert+marker)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('added ProcessingTask model')
|
||||
59
_edit_networkmap_hit.py
Normal file
59
_edit_networkmap_hit.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
insert='''
|
||||
function isPointOnNodeLabel(node: GNode, wx: number, wy: number, vp: Viewport): boolean {
|
||||
const fontSize = Math.max(9, Math.round(12 / vp.scale));
|
||||
const approxCharW = Math.max(5, fontSize * 0.58);
|
||||
const line1 = node.label || '';
|
||||
const line2 = node.meta.ips.length > 0 ? node.meta.ips[0] : '';
|
||||
const tw = Math.max(line1.length * approxCharW, line2 ? line2.length * approxCharW : 0);
|
||||
const px = 5, py = 2;
|
||||
const totalH = line2 ? fontSize * 2 + py * 2 : fontSize + py * 2;
|
||||
const lx = node.x, ly = node.y - node.radius - 6;
|
||||
const rx = lx - tw / 2 - px;
|
||||
const ry = ly - totalH;
|
||||
const rw = tw + px * 2;
|
||||
const rh = totalH;
|
||||
return wx >= rx && wx <= (rx + rw) && wy >= ry && wy <= (ry + rh);
|
||||
}
|
||||
|
||||
'''
|
||||
if 'function isPointOnNodeLabel' not in t:
|
||||
t=t.replace('// == Hit-test =============================================================\n', '// == Hit-test =============================================================\n'+insert)
|
||||
|
||||
old='''function hitTest(
|
||||
graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport,
|
||||
): GNode | null {
|
||||
const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp);
|
||||
for (const n of graph.nodes) {
|
||||
const dx = n.x - wx, dy = n.y - wy;
|
||||
if (dx * dx + dy * dy < (n.radius + 5) ** 2) return n;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
'''
|
||||
new='''function hitTest(
|
||||
graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport,
|
||||
): GNode | null {
|
||||
const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp);
|
||||
|
||||
// Node-circle hit has priority
|
||||
for (const n of graph.nodes) {
|
||||
const dx = n.x - wx, dy = n.y - wy;
|
||||
if (dx * dx + dy * dy < (n.radius + 5) ** 2) return n;
|
||||
}
|
||||
|
||||
// Then label hit (so clicking text works too)
|
||||
for (const n of graph.nodes) {
|
||||
if (isPointOnNodeLabel(n, wx, wy, vp)) return n;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('hitTest block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated NetworkMap hit-test for labels')
|
||||
272
_edit_scanner.py
Normal file
272
_edit_scanner.py
Normal file
@@ -0,0 +1,272 @@
|
||||
from pathlib import Path
|
||||
|
||||
p = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
|
||||
text = p.read_text(encoding='utf-8')
|
||||
new_text = '''"""AUP Keyword Scanner searches dataset rows, hunts, annotations, and
|
||||
messages for keyword matches.
|
||||
|
||||
Scanning is done in Python (not SQL LIKE on JSON columns) for portability
|
||||
across SQLite / PostgreSQL and to provide per-cell match context.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import (
|
||||
KeywordTheme,
|
||||
DatasetRow,
|
||||
Dataset,
|
||||
Hunt,
|
||||
Annotation,
|
||||
Message,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BATCH_SIZE = 200
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanHit:
|
||||
theme_name: str
|
||||
theme_color: str
|
||||
keyword: str
|
||||
source_type: str # dataset_row | hunt | annotation | message
|
||||
source_id: str | int
|
||||
field: str
|
||||
matched_value: str
|
||||
row_index: int | None = None
|
||||
dataset_name: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanResult:
|
||||
total_hits: int = 0
|
||||
hits: list[ScanHit] = field(default_factory=list)
|
||||
themes_scanned: int = 0
|
||||
keywords_scanned: int = 0
|
||||
rows_scanned: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeywordScanCacheEntry:
|
||||
dataset_id: str
|
||||
result: dict
|
||||
built_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
|
||||
|
||||
class KeywordScanCache:
|
||||
"""In-memory per-dataset cache for dataset-only keyword scans.
|
||||
|
||||
This enables fast-path reads when users run AUP scans against datasets that
|
||||
were already scanned during upload pipeline processing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._entries: dict[str, KeywordScanCacheEntry] = {}
|
||||
|
||||
def put(self, dataset_id: str, result: dict):
|
||||
self._entries[dataset_id] = KeywordScanCacheEntry(dataset_id=dataset_id, result=result)
|
||||
|
||||
def get(self, dataset_id: str) -> KeywordScanCacheEntry | None:
|
||||
return self._entries.get(dataset_id)
|
||||
|
||||
def invalidate_dataset(self, dataset_id: str):
|
||||
self._entries.pop(dataset_id, None)
|
||||
|
||||
def clear(self):
|
||||
self._entries.clear()
|
||||
|
||||
|
||||
keyword_scan_cache = KeywordScanCache()
|
||||
|
||||
|
||||
class KeywordScanner:
|
||||
"""Scans multiple data sources for keyword/regex matches."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
# Public API
|
||||
|
||||
async def scan(
|
||||
self,
|
||||
dataset_ids: list[str] | None = None,
|
||||
theme_ids: list[str] | None = None,
|
||||
scan_hunts: bool = False,
|
||||
scan_annotations: bool = False,
|
||||
scan_messages: bool = False,
|
||||
) -> dict:
|
||||
"""Run a full AUP scan and return dict matching ScanResponse."""
|
||||
# Load themes + keywords
|
||||
themes = await self._load_themes(theme_ids)
|
||||
if not themes:
|
||||
return ScanResult().__dict__
|
||||
|
||||
# Pre-compile patterns per theme
|
||||
patterns = self._compile_patterns(themes)
|
||||
result = ScanResult(
|
||||
themes_scanned=len(themes),
|
||||
keywords_scanned=sum(len(kws) for kws in patterns.values()),
|
||||
)
|
||||
|
||||
# Scan dataset rows
|
||||
await self._scan_datasets(patterns, result, dataset_ids)
|
||||
|
||||
# Scan hunts
|
||||
if scan_hunts:
|
||||
await self._scan_hunts(patterns, result)
|
||||
|
||||
# Scan annotations
|
||||
if scan_annotations:
|
||||
await self._scan_annotations(patterns, result)
|
||||
|
||||
# Scan messages
|
||||
if scan_messages:
|
||||
await self._scan_messages(patterns, result)
|
||||
|
||||
result.total_hits = len(result.hits)
|
||||
return {
|
||||
"total_hits": result.total_hits,
|
||||
"hits": [h.__dict__ for h in result.hits],
|
||||
"themes_scanned": result.themes_scanned,
|
||||
"keywords_scanned": result.keywords_scanned,
|
||||
"rows_scanned": result.rows_scanned,
|
||||
}
|
||||
|
||||
# Internal
|
||||
|
||||
async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
|
||||
q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712
|
||||
if theme_ids:
|
||||
q = q.where(KeywordTheme.id.in_(theme_ids))
|
||||
result = await self.db.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
def _compile_patterns(
|
||||
self, themes: list[KeywordTheme]
|
||||
) -> dict[tuple[str, str, str], list[tuple[str, re.Pattern]]]:
|
||||
"""Returns {(theme_id, theme_name, theme_color): [(keyword_value, compiled_pattern), ...]}"""
|
||||
patterns: dict[tuple[str, str, str], list[tuple[str, re.Pattern]]] = {}
|
||||
for theme in themes:
|
||||
key = (theme.id, theme.name, theme.color)
|
||||
compiled = []
|
||||
for kw in theme.keywords:
|
||||
try:
|
||||
if kw.is_regex:
|
||||
pat = re.compile(kw.value, re.IGNORECASE)
|
||||
else:
|
||||
pat = re.compile(re.escape(kw.value), re.IGNORECASE)
|
||||
compiled.append((kw.value, pat))
|
||||
except re.error:
|
||||
logger.warning("Invalid regex pattern '%s' in theme '%s', skipping",
|
||||
kw.value, theme.name)
|
||||
patterns[key] = compiled
|
||||
return patterns
|
||||
|
||||
def _match_text(
|
||||
self,
|
||||
text: str,
|
||||
patterns: dict,
|
||||
source_type: str,
|
||||
source_id: str | int,
|
||||
field_name: str,
|
||||
hits: list[ScanHit],
|
||||
row_index: int | None = None,
|
||||
dataset_name: str | None = None,
|
||||
) -> None:
|
||||
"""Check text against all compiled patterns, append hits."""
|
||||
if not text:
|
||||
return
|
||||
for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
|
||||
for kw_value, pat in keyword_patterns:
|
||||
if pat.search(text):
|
||||
matched_preview = text[:200] + ("" if len(text) > 200 else "")
|
||||
hits.append(ScanHit(
|
||||
theme_name=theme_name,
|
||||
theme_color=theme_color,
|
||||
keyword=kw_value,
|
||||
source_type=source_type,
|
||||
source_id=source_id,
|
||||
field=field_name,
|
||||
matched_value=matched_preview,
|
||||
row_index=row_index,
|
||||
dataset_name=dataset_name,
|
||||
))
|
||||
|
||||
async def _scan_datasets(
|
||||
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
|
||||
) -> None:
|
||||
"""Scan dataset rows in batches."""
|
||||
ds_q = select(Dataset.id, Dataset.name)
|
||||
if dataset_ids:
|
||||
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
|
||||
ds_result = await self.db.execute(ds_q)
|
||||
ds_map = {r[0]: r[1] for r in ds_result.fetchall()}
|
||||
|
||||
if not ds_map:
|
||||
return
|
||||
|
||||
offset = 0
|
||||
row_q_base = select(DatasetRow).where(
|
||||
DatasetRow.dataset_id.in_(list(ds_map.keys()))
|
||||
).order_by(DatasetRow.id)
|
||||
|
||||
while True:
|
||||
rows_result = await self.db.execute(
|
||||
row_q_base.offset(offset).limit(BATCH_SIZE)
|
||||
)
|
||||
rows = rows_result.scalars().all()
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for row in rows:
|
||||
result.rows_scanned += 1
|
||||
data = row.data or {}
|
||||
for col_name, cell_value in data.items():
|
||||
if cell_value is None:
|
||||
continue
|
||||
text = str(cell_value)
|
||||
self._match_text(
|
||||
text, patterns, "dataset_row", row.id,
|
||||
col_name, result.hits,
|
||||
row_index=row.row_index,
|
||||
dataset_name=ds_map.get(row.dataset_id),
|
||||
)
|
||||
|
||||
offset += BATCH_SIZE
|
||||
import asyncio
|
||||
await asyncio.sleep(0)
|
||||
if len(rows) < BATCH_SIZE:
|
||||
break
|
||||
|
||||
async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
|
||||
"""Scan hunt names and descriptions."""
|
||||
hunts_result = await self.db.execute(select(Hunt))
|
||||
for hunt in hunts_result.scalars().all():
|
||||
self._match_text(hunt.name, patterns, "hunt", hunt.id, "name", result.hits)
|
||||
if hunt.description:
|
||||
self._match_text(hunt.description, patterns, "hunt", hunt.id, "description", result.hits)
|
||||
|
||||
async def _scan_annotations(self, patterns: dict, result: ScanResult) -> None:
|
||||
"""Scan annotation text."""
|
||||
ann_result = await self.db.execute(select(Annotation))
|
||||
for ann in ann_result.scalars().all():
|
||||
self._match_text(ann.text, patterns, "annotation", ann.id, "text", result.hits)
|
||||
|
||||
async def _scan_messages(self, patterns: dict, result: ScanResult) -> None:
|
||||
"""Scan conversation messages (user messages only)."""
|
||||
msg_result = await self.db.execute(
|
||||
select(Message).where(Message.role == "user")
|
||||
)
|
||||
for msg in msg_result.scalars().all():
|
||||
self._match_text(msg.content, patterns, "message", msg.id, "content", result.hits)
|
||||
'''
|
||||
|
||||
p.write_text(new_text, encoding='utf-8')
|
||||
print('updated scanner.py')
|
||||
31
_edit_test_api.py
Normal file
31
_edit_test_api.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/tests/test_api.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
insert='''
|
||||
async def test_hunt_progress(self, client):
|
||||
create = await client.post("/api/hunts", json={"name": "Progress Hunt"})
|
||||
hunt_id = create.json()["id"]
|
||||
|
||||
# attach one dataset so progress has scope
|
||||
from tests.conftest import SAMPLE_CSV
|
||||
import io
|
||||
files = {"file": ("progress.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
|
||||
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
|
||||
assert up.status_code == 200
|
||||
|
||||
res = await client.get(f"/api/hunts/{hunt_id}/progress")
|
||||
assert res.status_code == 200
|
||||
body = res.json()
|
||||
assert body["hunt_id"] == hunt_id
|
||||
assert "progress_percent" in body
|
||||
assert "dataset_total" in body
|
||||
assert "network_status" in body
|
||||
'''
|
||||
needle=''' async def test_get_nonexistent_hunt(self, client):
|
||||
resp = await client.get("/api/hunts/nonexistent-id")
|
||||
assert resp.status_code == 404
|
||||
'''
|
||||
if needle in t and 'test_hunt_progress' not in t:
|
||||
t=t.replace(needle, needle+'\n'+insert)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated test_api.py')
|
||||
32
_edit_test_keywords.py
Normal file
32
_edit_test_keywords.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/tests/test_keywords.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
add='''
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quick_scan_cache_hit(client: AsyncClient):
|
||||
"""Second quick scan should return cache hit metadata."""
|
||||
theme_res = await client.post("/api/keywords/themes", json={"name": "Quick Cache Theme", "color": "#00aa00"})
|
||||
tid = theme_res.json()["id"]
|
||||
await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "chrome.exe"})
|
||||
|
||||
from tests.conftest import SAMPLE_CSV
|
||||
import io
|
||||
files = {"file": ("cache_quick.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
|
||||
upload = await client.post("/api/datasets/upload", files=files)
|
||||
ds_id = upload.json()["id"]
|
||||
|
||||
first = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
|
||||
assert first.status_code == 200
|
||||
assert first.json().get("cache_status") in ("miss", "hit")
|
||||
|
||||
second = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
|
||||
assert second.status_code == 200
|
||||
body = second.json()
|
||||
assert body.get("cache_used") is True
|
||||
assert body.get("cache_status") == "hit"
|
||||
'''
|
||||
if 'test_quick_scan_cache_hit' not in t:
|
||||
t=t + add
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated test_keywords.py')
|
||||
26
_edit_upload.py
Normal file
26
_edit_upload.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/FileUpload.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
# import useEffect
|
||||
t=t.replace("import React, { useState, useCallback, useRef } from 'react';","import React, { useState, useCallback, useRef, useEffect } from 'react';")
|
||||
# import HuntProgress type
|
||||
t=t.replace("import { datasets, hunts, type UploadResult, type Hunt } from '../api/client';","import { datasets, hunts, type UploadResult, type Hunt, type HuntProgress } from '../api/client';")
|
||||
# add state
|
||||
if 'const [huntProgress, setHuntProgress]' not in t:
|
||||
t=t.replace(" const [huntList, setHuntList] = useState<Hunt[]>([]);\n const [huntId, setHuntId] = useState('');"," const [huntList, setHuntList] = useState<Hunt[]>([]);\n const [huntId, setHuntId] = useState('');\n const [huntProgress, setHuntProgress] = useState<HuntProgress | null>(null);")
|
||||
# add polling effect after hunts list effect
|
||||
marker=" React.useEffect(() => {\n hunts.list(0, 100).then(r => setHuntList(r.hunts)).catch(() => {});\n }, []);\n"
|
||||
if marker in t and 'setInterval' not in t.split(marker,1)[1][:500]:
|
||||
add='''\n useEffect(() => {\n let timer: any = null;\n let cancelled = false;\n\n const pull = async () => {\n if (!huntId) {\n if (!cancelled) setHuntProgress(null);\n return;\n }\n try {\n const p = await hunts.progress(huntId);\n if (!cancelled) setHuntProgress(p);\n } catch {\n if (!cancelled) setHuntProgress(null);\n }\n };\n\n pull();\n if (huntId) timer = setInterval(pull, 2000);\n return () => { cancelled = true; if (timer) clearInterval(timer); };\n }, [huntId, jobs.length]);\n'''
|
||||
t=t.replace(marker, marker+add)
|
||||
|
||||
# insert master progress UI after overall summary
|
||||
insert_after=''' {overallTotal > 0 && (\n <Stack direction="row" alignItems="center" spacing={1} sx={{ mt: 2 }}>\n <Typography variant="body2" color="text.secondary">\n {overallDone + overallErr} / {overallTotal} files processed\n {overallErr > 0 && ` ({overallErr} failed)`}\n </Typography>\n <Box sx={{ flexGrow: 1 }} />\n {overallDone + overallErr === overallTotal && overallTotal > 0 && (\n <Tooltip title="Clear completed">\n <IconButton size="small" onClick={clearCompleted}><ClearIcon fontSize="small" /></IconButton>\n </Tooltip>\n )}\n </Stack>\n )}\n'''
|
||||
add_block='''\n {huntId && huntProgress && (\n <Paper sx={{ p: 1.5, mt: 1.5 }}>\n <Stack direction="row" alignItems="center" spacing={1} sx={{ mb: 0.8 }}>\n <Typography variant="body2" sx={{ fontWeight: 600 }}>\n Master Processing Progress\n </Typography>\n <Chip\n size="small"\n label={huntProgress.status.toUpperCase()}\n color={huntProgress.status === 'ready' ? 'success' : huntProgress.status === 'processing' ? 'warning' : 'default'}\n variant="outlined"\n />\n <Box sx={{ flexGrow: 1 }} />\n <Typography variant="caption" color="text.secondary">\n {huntProgress.progress_percent.toFixed(1)}%\n </Typography>\n </Stack>\n <LinearProgress\n variant="determinate"\n value={Math.max(0, Math.min(100, huntProgress.progress_percent))}\n sx={{ height: 8, borderRadius: 4 }}\n />\n <Stack direction="row" spacing={1} sx={{ mt: 1 }} flexWrap="wrap" useFlexGap>\n <Chip size="small" label={`Datasets ${huntProgress.dataset_completed}/${huntProgress.dataset_total}`} variant="outlined" />\n <Chip size="small" label={`Active jobs ${huntProgress.active_jobs}`} variant="outlined" />\n <Chip size="small" label={`Queued jobs ${huntProgress.queued_jobs}`} variant="outlined" />\n <Chip size="small" label={`Network ${huntProgress.network_status}`} variant="outlined" />\n </Stack>\n </Paper>\n )}\n'''
|
||||
if insert_after in t:
|
||||
t=t.replace(insert_after, insert_after+add_block)
|
||||
else:
|
||||
print('warning: summary block not found')
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated FileUpload.tsx')
|
||||
42
_edit_upload2.py
Normal file
42
_edit_upload2.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/FileUpload.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
marker=''' {/* Per-file progress list */}
|
||||
'''
|
||||
add=''' {huntId && huntProgress && (
|
||||
<Paper sx={{ p: 1.5, mt: 1.5 }}>
|
||||
<Stack direction="row" alignItems="center" spacing={1} sx={{ mb: 0.8 }}>
|
||||
<Typography variant="body2" sx={{ fontWeight: 600 }}>
|
||||
Master Processing Progress
|
||||
</Typography>
|
||||
<Chip
|
||||
size="small"
|
||||
label={huntProgress.status.toUpperCase()}
|
||||
color={huntProgress.status === 'ready' ? 'success' : huntProgress.status === 'processing' ? 'warning' : 'default'}
|
||||
variant="outlined"
|
||||
/>
|
||||
<Box sx={{ flexGrow: 1 }} />
|
||||
<Typography variant="caption" color="text.secondary">
|
||||
{huntProgress.progress_percent.toFixed(1)}%
|
||||
</Typography>
|
||||
</Stack>
|
||||
<LinearProgress
|
||||
variant="determinate"
|
||||
value={Math.max(0, Math.min(100, huntProgress.progress_percent))}
|
||||
sx={{ height: 8, borderRadius: 4 }}
|
||||
/>
|
||||
<Stack direction="row" spacing={1} sx={{ mt: 1 }} flexWrap="wrap" useFlexGap>
|
||||
<Chip size="small" label={`Datasets ${huntProgress.dataset_completed}/${huntProgress.dataset_total}`} variant="outlined" />
|
||||
<Chip size="small" label={`Active jobs ${huntProgress.active_jobs}`} variant="outlined" />
|
||||
<Chip size="small" label={`Queued jobs ${huntProgress.queued_jobs}`} variant="outlined" />
|
||||
<Chip size="small" label={`Network ${huntProgress.network_status}`} variant="outlined" />
|
||||
</Stack>
|
||||
</Paper>
|
||||
)}
|
||||
|
||||
'''
|
||||
if marker not in t:
|
||||
raise SystemExit('marker not found')
|
||||
t=t.replace(marker, add+marker)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('inserted master progress block')
|
||||
55
_enforce_scanner_budget.py
Normal file
55
_enforce_scanner_budget.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
if 'from app.config import settings' not in t:
|
||||
t=t.replace('from sqlalchemy.ext.asyncio import AsyncSession\n','from sqlalchemy.ext.asyncio import AsyncSession\n\nfrom app.config import settings\n')
|
||||
|
||||
old=''' import asyncio
|
||||
|
||||
for ds_id, ds_name in ds_map.items():
|
||||
last_id = 0
|
||||
while True:
|
||||
'''
|
||||
new=''' import asyncio
|
||||
|
||||
max_rows = max(0, int(settings.SCANNER_MAX_ROWS_PER_SCAN))
|
||||
budget_reached = False
|
||||
|
||||
for ds_id, ds_name in ds_map.items():
|
||||
if max_rows and result.rows_scanned >= max_rows:
|
||||
budget_reached = True
|
||||
break
|
||||
|
||||
last_id = 0
|
||||
while True:
|
||||
if max_rows and result.rows_scanned >= max_rows:
|
||||
budget_reached = True
|
||||
break
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('scanner loop block not found')
|
||||
t=t.replace(old,new)
|
||||
|
||||
old2=''' if len(rows) < BATCH_SIZE:
|
||||
break
|
||||
|
||||
'''
|
||||
new2=''' if len(rows) < BATCH_SIZE:
|
||||
break
|
||||
|
||||
if budget_reached:
|
||||
break
|
||||
|
||||
if budget_reached:
|
||||
logger.warning(
|
||||
"AUP scan row budget reached (%d rows). Returning partial results.",
|
||||
result.rows_scanned,
|
||||
)
|
||||
|
||||
'''
|
||||
if old2 not in t:
|
||||
raise SystemExit('scanner break block not found')
|
||||
t=t.replace(old2,new2,1)
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('added scanner global row budget enforcement')
|
||||
12
_fix_aup_dep.py
Normal file
12
_fix_aup_dep.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
old=''' }, [selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);
|
||||
'''
|
||||
new=''' }, [selectedHuntId, selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('runScan deps block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('fixed AUPScanner runScan dependency list')
|
||||
7
_fix_import_datasets.py
Normal file
7
_fix_import_datasets.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
if 'from app.db.models import ProcessingTask' not in t:
|
||||
t=t.replace('from app.db import get_db\n', 'from app.db import get_db\nfrom app.db.models import ProcessingTask\n')
|
||||
p.write_text(t, encoding='utf-8')
|
||||
print('added ProcessingTask import')
|
||||
25
_fix_keywords_empty_guard.py
Normal file
25
_fix_keywords_empty_guard.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
old=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
|
||||
raise HTTPException(400, "Select at least one dataset or enable additional sources (hunts/annotations/messages)")
|
||||
|
||||
'''
|
||||
new=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
|
||||
return {
|
||||
"total_hits": 0,
|
||||
"hits": [],
|
||||
"themes_scanned": 0,
|
||||
"keywords_scanned": 0,
|
||||
"rows_scanned": 0,
|
||||
"cache_used": False,
|
||||
"cache_status": "miss",
|
||||
"cached_at": None,
|
||||
}
|
||||
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('scope guard block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('adjusted empty scan guard to return fast empty result (200)')
|
||||
47
_fix_label_selector_networkmap.py
Normal file
47
_fix_label_selector_networkmap.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
|
||||
# Add label selector in toolbar before refresh button
|
||||
insert_after=""" <TextField
|
||||
size=\"small\"
|
||||
placeholder=\"Search hosts, IPs, users\\u2026\"
|
||||
value={search}
|
||||
onChange={e => setSearch(e.target.value)}
|
||||
sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }}
|
||||
slotProps={{
|
||||
input: {
|
||||
startAdornment: <SearchIcon sx={{ mr: 0.5, fontSize: 18, color: 'text.secondary' }} />,
|
||||
},
|
||||
}}
|
||||
/>
|
||||
"""
|
||||
label_ctrl="""
|
||||
<FormControl size=\"small\" sx={{ minWidth: 150 }}>
|
||||
<InputLabel id=\"label-mode-selector\">Labels</InputLabel>
|
||||
<Select
|
||||
labelId=\"label-mode-selector\"
|
||||
value={labelMode}
|
||||
label=\"Labels\"
|
||||
onChange={e => setLabelMode(e.target.value as LabelMode)}
|
||||
sx={{ '& .MuiSelect-select': { py: 0.8 } }}
|
||||
>
|
||||
<MenuItem value=\"none\">None</MenuItem>
|
||||
<MenuItem value=\"highlight\">Selected/Search</MenuItem>
|
||||
<MenuItem value=\"all\">All</MenuItem>
|
||||
</Select>
|
||||
</FormControl>
|
||||
"""
|
||||
if 'label-mode-selector' not in t:
|
||||
if insert_after not in t:
|
||||
raise SystemExit('search block not found for label selector insertion')
|
||||
t=t.replace(insert_after, insert_after+label_ctrl)
|
||||
|
||||
# Fix useCallback dependency for startAnimLoop
|
||||
old=' }, [canvasSize]);'
|
||||
new=' }, [canvasSize, labelMode]);'
|
||||
if old in t:
|
||||
t=t.replace(old,new,1)
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('inserted label selector UI and fixed callback dependency')
|
||||
10
_fix_last_dep_networkmap.py
Normal file
10
_fix_last_dep_networkmap.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
count=t.count('}, [canvasSize]);')
|
||||
if count:
|
||||
t=t.replace('}, [canvasSize]);','}, [canvasSize, labelMode]);')
|
||||
# In case formatter created spaced variant
|
||||
t=t.replace('}, [canvasSize ]);','}, [canvasSize, labelMode]);')
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('patched remaining canvasSize callback deps:', count)
|
||||
71
_harden_aup_scope_ui.py
Normal file
71
_harden_aup_scope_ui.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
|
||||
# Auto-select first hunt with datasets after load
|
||||
old=''' const [tRes, hRes] = await Promise.all([
|
||||
keywords.listThemes(),
|
||||
hunts.list(0, 200),
|
||||
]);
|
||||
setThemes(tRes.themes);
|
||||
setHuntList(hRes.hunts);
|
||||
'''
|
||||
new=''' const [tRes, hRes] = await Promise.all([
|
||||
keywords.listThemes(),
|
||||
hunts.list(0, 200),
|
||||
]);
|
||||
setThemes(tRes.themes);
|
||||
setHuntList(hRes.hunts);
|
||||
if (!selectedHuntId && hRes.hunts.length > 0) {
|
||||
const best = hRes.hunts.find(h => h.dataset_count > 0) || hRes.hunts[0];
|
||||
setSelectedHuntId(best.id);
|
||||
}
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('loadData block not found')
|
||||
t=t.replace(old,new)
|
||||
|
||||
# Guard runScan
|
||||
old2=''' const runScan = useCallback(async () => {
|
||||
setScanning(true);
|
||||
setScanResult(null);
|
||||
try {
|
||||
'''
|
||||
new2=''' const runScan = useCallback(async () => {
|
||||
if (!selectedHuntId) {
|
||||
enqueueSnackbar('Please select a hunt before running AUP scan', { variant: 'warning' });
|
||||
return;
|
||||
}
|
||||
if (selectedDs.size === 0) {
|
||||
enqueueSnackbar('No datasets selected for this hunt', { variant: 'warning' });
|
||||
return;
|
||||
}
|
||||
|
||||
setScanning(true);
|
||||
setScanResult(null);
|
||||
try {
|
||||
'''
|
||||
if old2 not in t:
|
||||
raise SystemExit('runScan header not found')
|
||||
t=t.replace(old2,new2)
|
||||
|
||||
# update loadData deps
|
||||
old3=''' }, [enqueueSnackbar]);
|
||||
'''
|
||||
new3=''' }, [enqueueSnackbar, selectedHuntId]);
|
||||
'''
|
||||
if old3 not in t:
|
||||
raise SystemExit('loadData deps not found')
|
||||
t=t.replace(old3,new3,1)
|
||||
|
||||
# disable button if no hunt or no datasets
|
||||
old4=''' onClick={runScan} disabled={scanning}
|
||||
'''
|
||||
new4=''' onClick={runScan} disabled={scanning || !selectedHuntId || selectedDs.size === 0}
|
||||
'''
|
||||
if old4 not in t:
|
||||
raise SystemExit('scan button props not found')
|
||||
t=t.replace(old4,new4)
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('hardened AUPScanner to require explicit hunt/dataset scope')
|
||||
84
_optimize_keywords_partial_cache.py
Normal file
84
_optimize_keywords_partial_cache.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
old=''' if can_use_cache:
|
||||
themes = await scanner._load_themes(body.theme_ids)
|
||||
allowed_theme_names = {t.name for t in themes}
|
||||
keywords_scanned = sum(len(theme.keywords) for theme in themes)
|
||||
|
||||
cached_entries: list[dict] = []
|
||||
missing: list[str] = []
|
||||
for dataset_id in (body.dataset_ids or []):
|
||||
entry = keyword_scan_cache.get(dataset_id)
|
||||
if not entry:
|
||||
missing.append(dataset_id)
|
||||
continue
|
||||
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
|
||||
|
||||
if not missing and cached_entries:
|
||||
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
|
||||
return {
|
||||
"total_hits": merged["total_hits"],
|
||||
"hits": merged["hits"],
|
||||
"themes_scanned": len(themes),
|
||||
"keywords_scanned": keywords_scanned,
|
||||
"rows_scanned": merged["rows_scanned"],
|
||||
"cache_used": True,
|
||||
"cache_status": "hit",
|
||||
"cached_at": merged["cached_at"],
|
||||
}
|
||||
'''
|
||||
new=''' if can_use_cache:
|
||||
themes = await scanner._load_themes(body.theme_ids)
|
||||
allowed_theme_names = {t.name for t in themes}
|
||||
keywords_scanned = sum(len(theme.keywords) for theme in themes)
|
||||
|
||||
cached_entries: list[dict] = []
|
||||
missing: list[str] = []
|
||||
for dataset_id in (body.dataset_ids or []):
|
||||
entry = keyword_scan_cache.get(dataset_id)
|
||||
if not entry:
|
||||
missing.append(dataset_id)
|
||||
continue
|
||||
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
|
||||
|
||||
if not missing and cached_entries:
|
||||
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
|
||||
return {
|
||||
"total_hits": merged["total_hits"],
|
||||
"hits": merged["hits"],
|
||||
"themes_scanned": len(themes),
|
||||
"keywords_scanned": keywords_scanned,
|
||||
"rows_scanned": merged["rows_scanned"],
|
||||
"cache_used": True,
|
||||
"cache_status": "hit",
|
||||
"cached_at": merged["cached_at"],
|
||||
}
|
||||
|
||||
if missing:
|
||||
missing_entries: list[dict] = []
|
||||
for dataset_id in missing:
|
||||
partial = await scanner.scan(dataset_ids=[dataset_id], theme_ids=body.theme_ids)
|
||||
keyword_scan_cache.put(dataset_id, partial)
|
||||
missing_entries.append({"result": partial, "built_at": None})
|
||||
|
||||
merged = _merge_cached_results(
|
||||
cached_entries + missing_entries,
|
||||
allowed_theme_names if body.theme_ids else None,
|
||||
)
|
||||
return {
|
||||
"total_hits": merged["total_hits"],
|
||||
"hits": merged["hits"],
|
||||
"themes_scanned": len(themes),
|
||||
"keywords_scanned": keywords_scanned,
|
||||
"rows_scanned": merged["rows_scanned"],
|
||||
"cache_used": len(cached_entries) > 0,
|
||||
"cache_status": "partial" if cached_entries else "miss",
|
||||
"cached_at": merged["cached_at"],
|
||||
}
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('cache block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated keyword /scan to use partial cache + scan missing datasets only')
|
||||
61
_optimize_scanner_keyset.py
Normal file
61
_optimize_scanner_keyset.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
start=t.index(' async def _scan_datasets(')
|
||||
end=t.index(' async def _scan_hunts', start)
|
||||
new_func=''' async def _scan_datasets(
|
||||
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
|
||||
) -> None:
|
||||
"""Scan dataset rows in batches using keyset pagination (no OFFSET)."""
|
||||
ds_q = select(Dataset.id, Dataset.name)
|
||||
if dataset_ids:
|
||||
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
|
||||
ds_result = await self.db.execute(ds_q)
|
||||
ds_map = {r[0]: r[1] for r in ds_result.fetchall()}
|
||||
|
||||
if not ds_map:
|
||||
return
|
||||
|
||||
import asyncio
|
||||
|
||||
for ds_id, ds_name in ds_map.items():
|
||||
last_id = 0
|
||||
while True:
|
||||
rows_result = await self.db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == ds_id)
|
||||
.where(DatasetRow.id > last_id)
|
||||
.order_by(DatasetRow.id)
|
||||
.limit(BATCH_SIZE)
|
||||
)
|
||||
rows = rows_result.scalars().all()
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for row in rows:
|
||||
result.rows_scanned += 1
|
||||
data = row.data or {}
|
||||
for col_name, cell_value in data.items():
|
||||
if cell_value is None:
|
||||
continue
|
||||
text = str(cell_value)
|
||||
self._match_text(
|
||||
text,
|
||||
patterns,
|
||||
"dataset_row",
|
||||
row.id,
|
||||
col_name,
|
||||
result.hits,
|
||||
row_index=row.row_index,
|
||||
dataset_name=ds_name,
|
||||
)
|
||||
|
||||
last_id = rows[-1].id
|
||||
await asyncio.sleep(0)
|
||||
if len(rows) < BATCH_SIZE:
|
||||
break
|
||||
|
||||
'''
|
||||
out=t[:start]+new_func+t[end:]
|
||||
p.write_text(out,encoding='utf-8')
|
||||
print('optimized scanner _scan_datasets to keyset pagination')
|
||||
36
_patch_inventory_stats.py
Normal file
36
_patch_inventory_stats.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
old=''' return {
|
||||
"hosts": host_list,
|
||||
"connections": conn_list,
|
||||
"stats": {
|
||||
"total_hosts": len(host_list),
|
||||
"total_datasets_scanned": len(all_datasets),
|
||||
"datasets_with_hosts": ds_with_hosts,
|
||||
"total_rows_scanned": total_rows,
|
||||
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||
},
|
||||
}
|
||||
'''
|
||||
new=''' return {
|
||||
"hosts": host_list,
|
||||
"connections": conn_list,
|
||||
"stats": {
|
||||
"total_hosts": len(host_list),
|
||||
"total_datasets_scanned": len(all_datasets),
|
||||
"datasets_with_hosts": ds_with_hosts,
|
||||
"total_rows_scanned": total_rows,
|
||||
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0,
|
||||
},
|
||||
}
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('return block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('patched inventory stats metadata')
|
||||
10
_patch_inventory_stats2.py
Normal file
10
_patch_inventory_stats2.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
needle=' "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n'
|
||||
if '"row_budget_per_dataset"' not in t:
|
||||
if needle not in t:
|
||||
raise SystemExit('needle not found')
|
||||
t=t.replace(needle, needle + ' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,\n "sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0,\n')
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('inserted inventory budget stats lines')
|
||||
14
_patch_network_sleep.py
Normal file
14
_patch_network_sleep.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from pathlib import Path
|
||||
|
||||
p = Path(r"d:\Projects\Dev\ThreatHunt\frontend\src\components\NetworkMap.tsx")
|
||||
text = p.read_text(encoding="utf-8")
|
||||
|
||||
anchor = " useEffect(() => { canvasSizeRef.current = canvasSize; }, [canvasSize]);\n"
|
||||
insert = anchor + "\n const sleep = (ms: number) => new Promise<void>(resolve => setTimeout(resolve, ms));\n"
|
||||
if "const sleep = (ms: number)" not in text and anchor in text:
|
||||
text = text.replace(anchor, insert)
|
||||
|
||||
text = text.replace("await new Promise(r => setTimeout(r, delayMs + jitter));", "await sleep(delayMs + jitter);")
|
||||
|
||||
p.write_text(text, encoding="utf-8")
|
||||
print("Patched sleep helper + polling awaits")
|
||||
37
_patch_network_wait.py
Normal file
37
_patch_network_wait.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
p = Path(r"d:\Projects\Dev\ThreatHunt\frontend\src\components\NetworkMap.tsx")
|
||||
text = p.read_text(encoding="utf-8")
|
||||
pattern = re.compile(r"const waitUntilReady = async \(\): Promise<boolean> => \{[\s\S]*?\n\s*\};", re.M)
|
||||
replacement = '''const waitUntilReady = async (): Promise<boolean> => {
|
||||
// Poll inventory-status with exponential backoff until 'ready' (or cancelled)
|
||||
setProgress('Host inventory is being prepared in the background');
|
||||
setLoading(true);
|
||||
let delayMs = 1500;
|
||||
const startedAt = Date.now();
|
||||
for (;;) {
|
||||
const jitter = Math.floor(Math.random() * 250);
|
||||
await new Promise(r => setTimeout(r, delayMs + jitter));
|
||||
if (cancelled) return false;
|
||||
try {
|
||||
const st = await network.inventoryStatus(selectedHuntId);
|
||||
if (cancelled) return false;
|
||||
if (st.status === 'ready') return true;
|
||||
if (Date.now() - startedAt > 5 * 60 * 1000) {
|
||||
setError('Host inventory build timed out. Please retry.');
|
||||
return false;
|
||||
}
|
||||
delayMs = Math.min(10000, Math.floor(delayMs * 1.5));
|
||||
// still building or none (job may not have started yet) - keep polling
|
||||
} catch {
|
||||
if (cancelled) return false;
|
||||
delayMs = Math.min(10000, Math.floor(delayMs * 1.5));
|
||||
}
|
||||
}
|
||||
};'''
|
||||
new_text, n = pattern.subn(replacement, text, count=1)
|
||||
if n != 1:
|
||||
raise SystemExit(f"Failed to patch waitUntilReady, matches={n}")
|
||||
p.write_text(new_text, encoding="utf-8")
|
||||
print("Patched waitUntilReady")
|
||||
26
_perf_edit_config_inventory.py
Normal file
26
_perf_edit_config_inventory.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
old=''' NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
|
||||
default=25000,
|
||||
description="Row budget per dataset when building host inventory (0 = unlimited)",
|
||||
)
|
||||
'''
|
||||
new=''' NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
|
||||
default=5000,
|
||||
description="Row budget per dataset when building host inventory (0 = unlimited)",
|
||||
)
|
||||
NETWORK_INVENTORY_MAX_TOTAL_ROWS: int = Field(
|
||||
default=120000,
|
||||
description="Global row budget across all datasets for host inventory build (0 = unlimited)",
|
||||
)
|
||||
NETWORK_INVENTORY_MAX_CONNECTIONS: int = Field(
|
||||
default=120000,
|
||||
description="Max unique connection tuples retained during host inventory build",
|
||||
)
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('network inventory block not found')
|
||||
t=t.replace(old,new)
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated network inventory budgets in config')
|
||||
164
_perf_edit_host_inventory_budgets.py
Normal file
164
_perf_edit_host_inventory_budgets.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
|
||||
# insert budget vars near existing counters
|
||||
old=''' connections: dict[tuple, int] = defaultdict(int)
|
||||
total_rows = 0
|
||||
ds_with_hosts = 0
|
||||
'''
|
||||
new=''' connections: dict[tuple, int] = defaultdict(int)
|
||||
total_rows = 0
|
||||
ds_with_hosts = 0
|
||||
sampled_dataset_count = 0
|
||||
total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
|
||||
max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
|
||||
global_budget_reached = False
|
||||
dropped_connections = 0
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('counter block not found')
|
||||
t=t.replace(old,new)
|
||||
|
||||
# update batch size and sampled count increments + global budget checks
|
||||
old2=''' batch_size = 10000
|
||||
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
|
||||
rows_scanned_this_dataset = 0
|
||||
sampled_dataset = False
|
||||
last_row_index = -1
|
||||
while True:
|
||||
'''
|
||||
new2=''' batch_size = 5000
|
||||
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
|
||||
rows_scanned_this_dataset = 0
|
||||
sampled_dataset = False
|
||||
last_row_index = -1
|
||||
while True:
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
global_budget_reached = True
|
||||
break
|
||||
'''
|
||||
if old2 not in t:
|
||||
raise SystemExit('batch block not found')
|
||||
t=t.replace(old2,new2)
|
||||
|
||||
old3=''' if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
|
||||
sampled_dataset = True
|
||||
break
|
||||
|
||||
data = ro.data or {}
|
||||
total_rows += 1
|
||||
rows_scanned_this_dataset += 1
|
||||
'''
|
||||
new3=''' if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
|
||||
sampled_dataset = True
|
||||
break
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
sampled_dataset = True
|
||||
global_budget_reached = True
|
||||
break
|
||||
|
||||
data = ro.data or {}
|
||||
total_rows += 1
|
||||
rows_scanned_this_dataset += 1
|
||||
'''
|
||||
if old3 not in t:
|
||||
raise SystemExit('row scan block not found')
|
||||
t=t.replace(old3,new3)
|
||||
|
||||
# cap connection map growth
|
||||
old4=''' for c in cols['remote_ip']:
|
||||
rip = _clean(data.get(c))
|
||||
if _is_valid_ip(rip):
|
||||
rport = ''
|
||||
for pc in cols['remote_port']:
|
||||
rport = _clean(data.get(pc))
|
||||
if rport:
|
||||
break
|
||||
connections[(host_key, rip, rport)] += 1
|
||||
'''
|
||||
new4=''' for c in cols['remote_ip']:
|
||||
rip = _clean(data.get(c))
|
||||
if _is_valid_ip(rip):
|
||||
rport = ''
|
||||
for pc in cols['remote_port']:
|
||||
rport = _clean(data.get(pc))
|
||||
if rport:
|
||||
break
|
||||
conn_key = (host_key, rip, rport)
|
||||
if max_connections and len(connections) >= max_connections and conn_key not in connections:
|
||||
dropped_connections += 1
|
||||
continue
|
||||
connections[conn_key] += 1
|
||||
'''
|
||||
if old4 not in t:
|
||||
raise SystemExit('connection block not found')
|
||||
t=t.replace(old4,new4)
|
||||
|
||||
# sampled_dataset counter
|
||||
old5=''' if sampled_dataset:
|
||||
logger.info(
|
||||
"Host inventory row budget reached for dataset %s (%d rows)",
|
||||
ds.id,
|
||||
rows_scanned_this_dataset,
|
||||
)
|
||||
break
|
||||
'''
|
||||
new5=''' if sampled_dataset:
|
||||
sampled_dataset_count += 1
|
||||
logger.info(
|
||||
"Host inventory row budget reached for dataset %s (%d rows)",
|
||||
ds.id,
|
||||
rows_scanned_this_dataset,
|
||||
)
|
||||
break
|
||||
'''
|
||||
if old5 not in t:
|
||||
raise SystemExit('sampled block not found')
|
||||
t=t.replace(old5,new5)
|
||||
|
||||
# break dataset loop if global budget reached
|
||||
old6=''' if len(rows) < batch_size:
|
||||
break
|
||||
|
||||
# Post-process hosts
|
||||
'''
|
||||
new6=''' if len(rows) < batch_size:
|
||||
break
|
||||
|
||||
if global_budget_reached:
|
||||
logger.info(
|
||||
"Host inventory global row budget reached for hunt %s at %d rows",
|
||||
hunt_id,
|
||||
total_rows,
|
||||
)
|
||||
break
|
||||
|
||||
# Post-process hosts
|
||||
'''
|
||||
if old6 not in t:
|
||||
raise SystemExit('post-process boundary block not found')
|
||||
t=t.replace(old6,new6)
|
||||
|
||||
# add stats
|
||||
old7=''' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0,
|
||||
},
|
||||
}
|
||||
'''
|
||||
new7=''' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||
"row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
|
||||
"connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
|
||||
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
|
||||
"sampled_datasets": sampled_dataset_count,
|
||||
"global_budget_reached": global_budget_reached,
|
||||
"dropped_connections": dropped_connections,
|
||||
},
|
||||
}
|
||||
'''
|
||||
if old7 not in t:
|
||||
raise SystemExit('stats block not found')
|
||||
t=t.replace(old7,new7)
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('updated host inventory with global row and connection budgets')
|
||||
39
_perf_edit_networkmap_render.py
Normal file
39
_perf_edit_networkmap_render.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
repls={
|
||||
"const LARGE_HUNT_SUBGRAPH_HOSTS = 350;":"const LARGE_HUNT_SUBGRAPH_HOSTS = 220;",
|
||||
"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;":"const LARGE_HUNT_SUBGRAPH_EDGES = 1200;",
|
||||
"const RENDER_SIMPLIFY_NODE_THRESHOLD = 220;":"const RENDER_SIMPLIFY_NODE_THRESHOLD = 120;",
|
||||
"const RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;":"const RENDER_SIMPLIFY_EDGE_THRESHOLD = 500;",
|
||||
"const EDGE_DRAW_TARGET = 1000;":"const EDGE_DRAW_TARGET = 600;"
|
||||
}
|
||||
for a,b in repls.items():
|
||||
if a not in t:
|
||||
raise SystemExit(f'missing constant: {a}')
|
||||
t=t.replace(a,b)
|
||||
|
||||
old=''' // Then label hit (so clicking text works too)
|
||||
for (const n of graph.nodes) {
|
||||
if (isPointOnNodeLabel(n, wx, wy, vp)) return n;
|
||||
}
|
||||
'''
|
||||
new=''' // Then label hit (so clicking text works too on manageable graph sizes)
|
||||
if (graph.nodes.length <= 220) {
|
||||
for (const n of graph.nodes) {
|
||||
if (isPointOnNodeLabel(n, wx, wy, vp)) return n;
|
||||
}
|
||||
}
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('label hit block not found')
|
||||
t=t.replace(old,new)
|
||||
|
||||
old2='simulate(g, w / 2, h / 2, 60);'
|
||||
if t.count(old2) < 2:
|
||||
raise SystemExit('expected two simulate calls')
|
||||
t=t.replace(old2,'simulate(g, w / 2, h / 2, 20);',1)
|
||||
t=t.replace(old2,'simulate(g, w / 2, h / 2, 30);',1)
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('tightened network map rendering + load limits')
|
||||
107
_perf_patch_backend.py
Normal file
107
_perf_patch_backend.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from pathlib import Path
|
||||
# config updates
|
||||
cfg=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
|
||||
t=cfg.read_text(encoding='utf-8')
|
||||
anchor=''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
|
||||
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
|
||||
)
|
||||
'''
|
||||
ins=''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
|
||||
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
|
||||
)
|
||||
NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
|
||||
default=200000,
|
||||
description="Row budget per dataset when building host inventory (0 = unlimited)",
|
||||
)
|
||||
'''
|
||||
if 'NETWORK_INVENTORY_MAX_ROWS_PER_DATASET' not in t:
|
||||
if anchor not in t:
|
||||
raise SystemExit('config network anchor not found')
|
||||
t=t.replace(anchor,ins)
|
||||
cfg.write_text(t,encoding='utf-8')
|
||||
|
||||
# host inventory updates
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
if 'from app.config import settings' not in t:
|
||||
t=t.replace('from app.db.models import Dataset, DatasetRow\n', 'from app.db.models import Dataset, DatasetRow\nfrom app.config import settings\n')
|
||||
|
||||
t=t.replace(' batch_size = 5000\n last_row_index = -1\n while True:\n', ' batch_size = 10000\n max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))\n rows_scanned_this_dataset = 0\n sampled_dataset = False\n last_row_index = -1\n while True:\n')
|
||||
|
||||
old=''' for ro in rows:
|
||||
data = ro.data or {}
|
||||
total_rows += 1
|
||||
|
||||
fqdn = ''
|
||||
'''
|
||||
new=''' for ro in rows:
|
||||
if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
|
||||
sampled_dataset = True
|
||||
break
|
||||
|
||||
data = ro.data or {}
|
||||
total_rows += 1
|
||||
rows_scanned_this_dataset += 1
|
||||
|
||||
fqdn = ''
|
||||
'''
|
||||
if old not in t:
|
||||
raise SystemExit('row loop anchor not found')
|
||||
t=t.replace(old,new)
|
||||
|
||||
old2=''' last_row_index = rows[-1].row_index
|
||||
if len(rows) < batch_size:
|
||||
break
|
||||
'''
|
||||
new2=''' if sampled_dataset:
|
||||
logger.info(
|
||||
"Host inventory row budget reached for dataset %s (%d rows)",
|
||||
ds.id,
|
||||
rows_scanned_this_dataset,
|
||||
)
|
||||
break
|
||||
|
||||
last_row_index = rows[-1].row_index
|
||||
if len(rows) < batch_size:
|
||||
break
|
||||
'''
|
||||
if old2 not in t:
|
||||
raise SystemExit('batch loop end anchor not found')
|
||||
t=t.replace(old2,new2)
|
||||
|
||||
old3=''' return {
|
||||
"hosts": host_list,
|
||||
"connections": conn_list,
|
||||
"stats": {
|
||||
"total_hosts": len(host_list),
|
||||
"total_datasets_scanned": len(all_datasets),
|
||||
"datasets_with_hosts": ds_with_hosts,
|
||||
"total_rows_scanned": total_rows,
|
||||
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||
},
|
||||
}
|
||||
'''
|
||||
new3=''' sampled = settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0
|
||||
|
||||
return {
|
||||
"hosts": host_list,
|
||||
"connections": conn_list,
|
||||
"stats": {
|
||||
"total_hosts": len(host_list),
|
||||
"total_datasets_scanned": len(all_datasets),
|
||||
"datasets_with_hosts": ds_with_hosts,
|
||||
"total_rows_scanned": total_rows,
|
||||
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||
"sampled_mode": sampled,
|
||||
},
|
||||
}
|
||||
'''
|
||||
if old3 not in t:
|
||||
raise SystemExit('return stats anchor not found')
|
||||
t=t.replace(old3,new3)
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('patched config + host inventory row budget')
|
||||
38
_perf_patch_backend2.py
Normal file
38
_perf_patch_backend2.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from pathlib import Path
|
||||
cfg=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
|
||||
t=cfg.read_text(encoding='utf-8')
|
||||
if 'NETWORK_INVENTORY_MAX_ROWS_PER_DATASET' not in t:
|
||||
t=t.replace(
|
||||
''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
|
||||
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
|
||||
)
|
||||
''',
|
||||
''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
|
||||
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
|
||||
)
|
||||
NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
|
||||
default=200000,
|
||||
description="Row budget per dataset when building host inventory (0 = unlimited)",
|
||||
)
|
||||
''')
|
||||
cfg.write_text(t,encoding='utf-8')
|
||||
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
if 'from app.config import settings' not in t:
|
||||
t=t.replace('from app.db.models import Dataset, DatasetRow\n','from app.db.models import Dataset, DatasetRow\nfrom app.config import settings\n')
|
||||
|
||||
t=t.replace(' batch_size = 5000\n last_row_index = -1\n while True:\n',
|
||||
' batch_size = 10000\n max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))\n rows_scanned_this_dataset = 0\n sampled_dataset = False\n last_row_index = -1\n while True:\n')
|
||||
|
||||
t=t.replace(' for ro in rows:\n data = ro.data or {}\n total_rows += 1\n\n',
|
||||
' for ro in rows:\n if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:\n sampled_dataset = True\n break\n\n data = ro.data or {}\n total_rows += 1\n rows_scanned_this_dataset += 1\n\n')
|
||||
|
||||
t=t.replace(' last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n',
|
||||
' if sampled_dataset:\n logger.info(\n "Host inventory row budget reached for dataset %s (%d rows)",\n ds.id,\n rows_scanned_this_dataset,\n )\n break\n\n last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n')
|
||||
|
||||
t=t.replace(' return {\n "hosts": host_list,\n "connections": conn_list,\n "stats": {\n "total_hosts": len(host_list),\n "total_datasets_scanned": len(all_datasets),\n "datasets_with_hosts": ds_with_hosts,\n "total_rows_scanned": total_rows,\n "hosts_with_ips": sum(1 for h in host_list if h[\'ips\']),\n "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n },\n }\n',
|
||||
' sampled = settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0\n\n return {\n "hosts": host_list,\n "connections": conn_list,\n "stats": {\n "total_hosts": len(host_list),\n "total_datasets_scanned": len(all_datasets),\n "datasets_with_hosts": ds_with_hosts,\n "total_rows_scanned": total_rows,\n "hosts_with_ips": sum(1 for h in host_list if h[\'ips\']),\n "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,\n "sampled_mode": sampled,\n },\n }\n')
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('patched backend inventory performance settings')
|
||||
220
_perf_patch_networkmap.py
Normal file
220
_perf_patch_networkmap.py
Normal file
@@ -0,0 +1,220 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
|
||||
# constants
|
||||
if 'RENDER_SIMPLIFY_NODE_THRESHOLD' not in t:
|
||||
t=t.replace(
|
||||
"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n",
|
||||
"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\nconst RENDER_SIMPLIFY_NODE_THRESHOLD = 220;\nconst RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;\nconst EDGE_DRAW_TARGET = 1000;\n")
|
||||
|
||||
# drawBackground signature
|
||||
t_old='''function drawBackground(
|
||||
ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,
|
||||
) {
|
||||
'''
|
||||
if t_old in t:
|
||||
t=t.replace(t_old,
|
||||
'''function drawBackground(
|
||||
ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,
|
||||
simplify: boolean,
|
||||
) {
|
||||
''')
|
||||
|
||||
# skip grid when simplify
|
||||
if 'if (!simplify) {' not in t:
|
||||
t=t.replace(
|
||||
''' ctx.save();
|
||||
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||
const startX = -vp.x / vp.scale - GRID_SPACING;
|
||||
const startY = -vp.y / vp.scale - GRID_SPACING;
|
||||
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||
ctx.fillStyle = GRID_DOT_COLOR;
|
||||
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
|
||||
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
|
||||
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
|
||||
}
|
||||
}
|
||||
ctx.restore();
|
||||
''',
|
||||
''' if (!simplify) {
|
||||
ctx.save();
|
||||
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||
const startX = -vp.x / vp.scale - GRID_SPACING;
|
||||
const startY = -vp.y / vp.scale - GRID_SPACING;
|
||||
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||
ctx.fillStyle = GRID_DOT_COLOR;
|
||||
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
|
||||
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
|
||||
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
|
||||
}
|
||||
}
|
||||
ctx.restore();
|
||||
}
|
||||
''')
|
||||
|
||||
# drawEdges signature
|
||||
t=t.replace('''function drawEdges(
|
||||
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||
hovered: string | null, selected: string | null,
|
||||
nodeMap: Map<string, GNode>, animTime: number,
|
||||
) {
|
||||
for (const e of graph.edges) {
|
||||
''',
|
||||
'''function drawEdges(
|
||||
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||
hovered: string | null, selected: string | null,
|
||||
nodeMap: Map<string, GNode>, animTime: number,
|
||||
simplify: boolean,
|
||||
) {
|
||||
const edgeStep = simplify ? Math.max(1, Math.ceil(graph.edges.length / EDGE_DRAW_TARGET)) : 1;
|
||||
for (let ei = 0; ei < graph.edges.length; ei += edgeStep) {
|
||||
const e = graph.edges[ei];
|
||||
''')
|
||||
|
||||
# simplify edge path
|
||||
t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);',
|
||||
'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }')
|
||||
|
||||
t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);',
|
||||
'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }')
|
||||
|
||||
# reduce glow when simplify
|
||||
t=t.replace(''' ctx.save();
|
||||
ctx.shadowColor = 'rgba(96,165,250,0.5)'; ctx.shadowBlur = 8;
|
||||
ctx.strokeStyle = 'rgba(96,165,250,0.3)';
|
||||
ctx.lineWidth = Math.min(5, 2 + e.weight * 0.2);
|
||||
ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }
|
||||
ctx.stroke(); ctx.restore();
|
||||
''',
|
||||
''' if (!simplify) {
|
||||
ctx.save();
|
||||
ctx.shadowColor = 'rgba(96,165,250,0.5)'; ctx.shadowBlur = 8;
|
||||
ctx.strokeStyle = 'rgba(96,165,250,0.3)';
|
||||
ctx.lineWidth = Math.min(5, 2 + e.weight * 0.2);
|
||||
ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }
|
||||
ctx.stroke(); ctx.restore();
|
||||
}
|
||||
''')
|
||||
|
||||
# drawLabels signature and early return
|
||||
t=t.replace('''function drawLabels(
|
||||
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||
hovered: string | null, selected: string | null,
|
||||
search: string, matchSet: Set<string>, vp: Viewport,
|
||||
) {
|
||||
''',
|
||||
'''function drawLabels(
|
||||
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||
hovered: string | null, selected: string | null,
|
||||
search: string, matchSet: Set<string>, vp: Viewport,
|
||||
simplify: boolean,
|
||||
) {
|
||||
''')
|
||||
|
||||
if 'if (simplify && !search && !hovered && !selected) {' not in t:
|
||||
t=t.replace(' const dimmed = search.length > 0;\n',
|
||||
' const dimmed = search.length > 0;\n if (simplify && !search && !hovered && !selected) {\n return;\n }\n')
|
||||
|
||||
# drawGraph adapt
|
||||
t=t.replace(''' drawBackground(ctx, w, h, vp, dpr);
|
||||
ctx.save();
|
||||
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime);
|
||||
drawNodes(ctx, graph, hovered, selected, search, matchSet);
|
||||
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp);
|
||||
ctx.restore();
|
||||
''',
|
||||
''' const simplify = graph.nodes.length > RENDER_SIMPLIFY_NODE_THRESHOLD || graph.edges.length > RENDER_SIMPLIFY_EDGE_THRESHOLD;
|
||||
drawBackground(ctx, w, h, vp, dpr, simplify);
|
||||
ctx.save();
|
||||
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime, simplify);
|
||||
drawNodes(ctx, graph, hovered, selected, search, matchSet);
|
||||
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify);
|
||||
ctx.restore();
|
||||
''')
|
||||
|
||||
# hover RAF ref
|
||||
if 'const hoverRafRef = useRef<number>(0);' not in t:
|
||||
t=t.replace(' const graphRef = useRef<Graph | null>(null);\n', ' const graphRef = useRef<Graph | null>(null);\n const hoverRafRef = useRef<number>(0);\n')
|
||||
|
||||
# throttle hover hit test on mousemove
|
||||
old_mm=''' const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
|
||||
setHovered(node?.id ?? null);
|
||||
}, [graph, redraw, startAnimLoop]);
|
||||
'''
|
||||
new_mm=''' cancelAnimationFrame(hoverRafRef.current);
|
||||
const clientX = e.clientX;
|
||||
const clientY = e.clientY;
|
||||
hoverRafRef.current = requestAnimationFrame(() => {
|
||||
const node = hitTest(graph, canvasRef.current as HTMLCanvasElement, clientX, clientY, vpRef.current);
|
||||
setHovered(prev => (prev === (node?.id ?? null) ? prev : (node?.id ?? null)));
|
||||
});
|
||||
}, [graph, redraw, startAnimLoop]);
|
||||
'''
|
||||
if old_mm in t:
|
||||
t=t.replace(old_mm,new_mm)
|
||||
|
||||
# cleanup hover raf on unmount in existing animation cleanup effect
|
||||
if 'cancelAnimationFrame(hoverRafRef.current);' not in t:
|
||||
t=t.replace(''' useEffect(() => {
|
||||
if (graph) startAnimLoop();
|
||||
return () => { cancelAnimationFrame(animFrameRef.current); isAnimatingRef.current = false; };
|
||||
}, [graph, startAnimLoop]);
|
||||
''',
|
||||
''' useEffect(() => {
|
||||
if (graph) startAnimLoop();
|
||||
return () => {
|
||||
cancelAnimationFrame(animFrameRef.current);
|
||||
cancelAnimationFrame(hoverRafRef.current);
|
||||
isAnimatingRef.current = false;
|
||||
};
|
||||
}, [graph, startAnimLoop]);
|
||||
''')
|
||||
|
||||
# connectedNodes optimization map
|
||||
if 'const nodeById = useMemo(() => {' not in t:
|
||||
t=t.replace(''' const connectionCount = selectedNode && graph
|
||||
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
|
||||
: 0;
|
||||
|
||||
const connectedNodes = useMemo(() => {
|
||||
''',
|
||||
''' const connectionCount = selectedNode && graph
|
||||
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
|
||||
: 0;
|
||||
|
||||
const nodeById = useMemo(() => {
|
||||
const m = new Map<string, GNode>();
|
||||
if (!graph) return m;
|
||||
for (const n of graph.nodes) m.set(n.id, n);
|
||||
return m;
|
||||
}, [graph]);
|
||||
|
||||
const connectedNodes = useMemo(() => {
|
||||
''')
|
||||
|
||||
t=t.replace(''' const n = graph.nodes.find(x => x.id === e.target);
|
||||
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
|
||||
} else if (e.target === selectedNode.id) {
|
||||
const n = graph.nodes.find(x => x.id === e.source);
|
||||
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
|
||||
''',
|
||||
''' const n = nodeById.get(e.target);
|
||||
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
|
||||
} else if (e.target === selectedNode.id) {
|
||||
const n = nodeById.get(e.source);
|
||||
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
|
||||
''')
|
||||
|
||||
t=t.replace(' }, [selectedNode, graph]);\n', ' }, [selectedNode, graph, nodeById]);\n')
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('patched NetworkMap adaptive render + hover throttle')
|
||||
153
_perf_patch_networkmap2.py
Normal file
153
_perf_patch_networkmap2.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
|
||||
if 'RENDER_SIMPLIFY_NODE_THRESHOLD' not in t:
|
||||
t=t.replace('const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n', 'const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\nconst RENDER_SIMPLIFY_NODE_THRESHOLD = 220;\nconst RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;\nconst EDGE_DRAW_TARGET = 1000;\n')
|
||||
|
||||
t=t.replace('function drawBackground(\n ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,\n) {', 'function drawBackground(\n ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,\n simplify: boolean,\n) {')
|
||||
|
||||
t=t.replace(''' ctx.save();
|
||||
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||
const startX = -vp.x / vp.scale - GRID_SPACING;
|
||||
const startY = -vp.y / vp.scale - GRID_SPACING;
|
||||
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||
ctx.fillStyle = GRID_DOT_COLOR;
|
||||
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
|
||||
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
|
||||
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
|
||||
}
|
||||
}
|
||||
ctx.restore();
|
||||
''',''' if (!simplify) {
|
||||
ctx.save();
|
||||
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||
const startX = -vp.x / vp.scale - GRID_SPACING;
|
||||
const startY = -vp.y / vp.scale - GRID_SPACING;
|
||||
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
|
||||
ctx.fillStyle = GRID_DOT_COLOR;
|
||||
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
|
||||
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
|
||||
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
|
||||
}
|
||||
}
|
||||
ctx.restore();
|
||||
}
|
||||
''')
|
||||
|
||||
t=t.replace('''function drawEdges(
|
||||
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||
hovered: string | null, selected: string | null,
|
||||
nodeMap: Map<string, GNode>, animTime: number,
|
||||
) {
|
||||
for (const e of graph.edges) {
|
||||
''','''function drawEdges(
|
||||
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||
hovered: string | null, selected: string | null,
|
||||
nodeMap: Map<string, GNode>, animTime: number,
|
||||
simplify: boolean,
|
||||
) {
|
||||
const edgeStep = simplify ? Math.max(1, Math.ceil(graph.edges.length / EDGE_DRAW_TARGET)) : 1;
|
||||
for (let ei = 0; ei < graph.edges.length; ei += edgeStep) {
|
||||
const e = graph.edges[ei];
|
||||
''')
|
||||
|
||||
t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);', 'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }')
|
||||
|
||||
t=t.replace('''function drawLabels(
|
||||
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||
hovered: string | null, selected: string | null,
|
||||
search: string, matchSet: Set<string>, vp: Viewport,
|
||||
) {
|
||||
const dimmed = search.length > 0;
|
||||
''','''function drawLabels(
|
||||
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||
hovered: string | null, selected: string | null,
|
||||
search: string, matchSet: Set<string>, vp: Viewport,
|
||||
simplify: boolean,
|
||||
) {
|
||||
const dimmed = search.length > 0;
|
||||
if (simplify && !search && !hovered && !selected) {
|
||||
return;
|
||||
}
|
||||
''')
|
||||
|
||||
t=t.replace(''' drawBackground(ctx, w, h, vp, dpr);
|
||||
ctx.save();
|
||||
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime);
|
||||
drawNodes(ctx, graph, hovered, selected, search, matchSet);
|
||||
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp);
|
||||
ctx.restore();
|
||||
''',''' const simplify = graph.nodes.length > RENDER_SIMPLIFY_NODE_THRESHOLD || graph.edges.length > RENDER_SIMPLIFY_EDGE_THRESHOLD;
|
||||
drawBackground(ctx, w, h, vp, dpr, simplify);
|
||||
ctx.save();
|
||||
ctx.translate(vp.x * dpr, vp.y * dpr);
|
||||
ctx.scale(vp.scale * dpr, vp.scale * dpr);
|
||||
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime, simplify);
|
||||
drawNodes(ctx, graph, hovered, selected, search, matchSet);
|
||||
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify);
|
||||
ctx.restore();
|
||||
''')
|
||||
|
||||
if 'const hoverRafRef = useRef<number>(0);' not in t:
|
||||
t=t.replace('const graphRef = useRef<Graph | null>(null);\n', 'const graphRef = useRef<Graph | null>(null);\n const hoverRafRef = useRef<number>(0);\n')
|
||||
|
||||
t=t.replace(''' const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
|
||||
setHovered(node?.id ?? null);
|
||||
}, [graph, redraw, startAnimLoop]);
|
||||
''',''' cancelAnimationFrame(hoverRafRef.current);
|
||||
const clientX = e.clientX;
|
||||
const clientY = e.clientY;
|
||||
hoverRafRef.current = requestAnimationFrame(() => {
|
||||
const node = hitTest(graph, canvasRef.current as HTMLCanvasElement, clientX, clientY, vpRef.current);
|
||||
setHovered(prev => (prev === (node?.id ?? null) ? prev : (node?.id ?? null)));
|
||||
});
|
||||
}, [graph, redraw, startAnimLoop]);
|
||||
''')
|
||||
|
||||
t=t.replace(''' useEffect(() => {
|
||||
if (graph) startAnimLoop();
|
||||
return () => { cancelAnimationFrame(animFrameRef.current); isAnimatingRef.current = false; };
|
||||
}, [graph, startAnimLoop]);
|
||||
''',''' useEffect(() => {
|
||||
if (graph) startAnimLoop();
|
||||
return () => {
|
||||
cancelAnimationFrame(animFrameRef.current);
|
||||
cancelAnimationFrame(hoverRafRef.current);
|
||||
isAnimatingRef.current = false;
|
||||
};
|
||||
}, [graph, startAnimLoop]);
|
||||
''')
|
||||
|
||||
if 'const nodeById = useMemo(() => {' not in t:
|
||||
t=t.replace(''' const connectionCount = selectedNode && graph
|
||||
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
|
||||
: 0;
|
||||
|
||||
const connectedNodes = useMemo(() => {
|
||||
''',''' const connectionCount = selectedNode && graph
|
||||
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
|
||||
: 0;
|
||||
|
||||
const nodeById = useMemo(() => {
|
||||
const m = new Map<string, GNode>();
|
||||
if (!graph) return m;
|
||||
for (const n of graph.nodes) m.set(n.id, n);
|
||||
return m;
|
||||
}, [graph]);
|
||||
|
||||
const connectedNodes = useMemo(() => {
|
||||
''')
|
||||
|
||||
t=t.replace('const n = graph.nodes.find(x => x.id === e.target);','const n = nodeById.get(e.target);')
|
||||
t=t.replace('const n = graph.nodes.find(x => x.id === e.source);','const n = nodeById.get(e.source);')
|
||||
t=t.replace(' }, [selectedNode, graph]);',' }, [selectedNode, graph, nodeById]);')
|
||||
|
||||
p.write_text(t,encoding='utf-8')
|
||||
print('patched NetworkMap performance')
|
||||
227
_perf_replace_build_host_inventory.py
Normal file
227
_perf_replace_build_host_inventory.py
Normal file
@@ -0,0 +1,227 @@
|
||||
from pathlib import Path
|
||||
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
|
||||
t=p.read_text(encoding='utf-8')
|
||||
start=t.index('async def build_host_inventory(')
|
||||
# find end of function by locating '\n\n' before EOF after ' }\n'
|
||||
end=t.index('\n\n', start)
|
||||
# need proper end: first double newline after function may occur in docstring? compute by searching for '\n\n' after ' }\n' near end
|
||||
ret_idx=t.rfind(' }')
|
||||
# safer locate end as last occurrence of '\n }\n' after start, then function ends next newline
|
||||
end=t.find('\n\n', ret_idx)
|
||||
if end==-1:
|
||||
end=len(t)
|
||||
new_func='''async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
||||
"""Build a deduplicated host inventory from all datasets in a hunt.
|
||||
|
||||
Returns dict with 'hosts', 'connections', and 'stats'.
|
||||
Each host has: id, hostname, fqdn, client_id, ips, os, users, datasets, row_count.
|
||||
"""
|
||||
ds_result = await db.execute(
|
||||
select(Dataset).where(Dataset.hunt_id == hunt_id)
|
||||
)
|
||||
all_datasets = ds_result.scalars().all()
|
||||
|
||||
if not all_datasets:
|
||||
return {"hosts": [], "connections": [], "stats": {
|
||||
"total_hosts": 0, "total_datasets_scanned": 0,
|
||||
"total_rows_scanned": 0,
|
||||
}}
|
||||
|
||||
hosts: dict[str, dict] = {} # fqdn -> host record
|
||||
ip_to_host: dict[str, str] = {} # local-ip -> fqdn
|
||||
connections: dict[tuple, int] = defaultdict(int)
|
||||
total_rows = 0
|
||||
ds_with_hosts = 0
|
||||
sampled_dataset_count = 0
|
||||
total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
|
||||
max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
|
||||
global_budget_reached = False
|
||||
dropped_connections = 0
|
||||
|
||||
for ds in all_datasets:
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
global_budget_reached = True
|
||||
break
|
||||
|
||||
cols = _identify_columns(ds)
|
||||
if not cols['fqdn'] and not cols['host_id']:
|
||||
continue
|
||||
ds_with_hosts += 1
|
||||
|
||||
batch_size = 5000
|
||||
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
|
||||
rows_scanned_this_dataset = 0
|
||||
sampled_dataset = False
|
||||
last_row_index = -1
|
||||
|
||||
while True:
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
sampled_dataset = True
|
||||
global_budget_reached = True
|
||||
break
|
||||
|
||||
rr = await db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == ds.id)
|
||||
.where(DatasetRow.row_index > last_row_index)
|
||||
.order_by(DatasetRow.row_index)
|
||||
.limit(batch_size)
|
||||
)
|
||||
rows = rr.scalars().all()
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for ro in rows:
|
||||
if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
|
||||
sampled_dataset = True
|
||||
break
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
sampled_dataset = True
|
||||
global_budget_reached = True
|
||||
break
|
||||
|
||||
data = ro.data or {}
|
||||
total_rows += 1
|
||||
rows_scanned_this_dataset += 1
|
||||
|
||||
fqdn = ''
|
||||
for c in cols['fqdn']:
|
||||
fqdn = _clean(data.get(c))
|
||||
if fqdn:
|
||||
break
|
||||
client_id = ''
|
||||
for c in cols['host_id']:
|
||||
client_id = _clean(data.get(c))
|
||||
if client_id:
|
||||
break
|
||||
|
||||
if not fqdn and not client_id:
|
||||
continue
|
||||
|
||||
host_key = fqdn or client_id
|
||||
|
||||
if host_key not in hosts:
|
||||
short = fqdn.split('.')[0] if fqdn and '.' in fqdn else fqdn
|
||||
hosts[host_key] = {
|
||||
'id': host_key,
|
||||
'hostname': short or client_id,
|
||||
'fqdn': fqdn,
|
||||
'client_id': client_id,
|
||||
'ips': set(),
|
||||
'os': '',
|
||||
'users': set(),
|
||||
'datasets': set(),
|
||||
'row_count': 0,
|
||||
}
|
||||
|
||||
h = hosts[host_key]
|
||||
h['datasets'].add(ds.name)
|
||||
h['row_count'] += 1
|
||||
if client_id and not h['client_id']:
|
||||
h['client_id'] = client_id
|
||||
|
||||
for c in cols['username']:
|
||||
u = _extract_username(_clean(data.get(c)))
|
||||
if u:
|
||||
h['users'].add(u)
|
||||
|
||||
for c in cols['local_ip']:
|
||||
ip = _clean(data.get(c))
|
||||
if _is_valid_ip(ip):
|
||||
h['ips'].add(ip)
|
||||
ip_to_host[ip] = host_key
|
||||
|
||||
for c in cols['os']:
|
||||
ov = _clean(data.get(c))
|
||||
if ov and not h['os']:
|
||||
h['os'] = ov
|
||||
|
||||
for c in cols['remote_ip']:
|
||||
rip = _clean(data.get(c))
|
||||
if _is_valid_ip(rip):
|
||||
rport = ''
|
||||
for pc in cols['remote_port']:
|
||||
rport = _clean(data.get(pc))
|
||||
if rport:
|
||||
break
|
||||
conn_key = (host_key, rip, rport)
|
||||
if max_connections and len(connections) >= max_connections and conn_key not in connections:
|
||||
dropped_connections += 1
|
||||
continue
|
||||
connections[conn_key] += 1
|
||||
|
||||
if sampled_dataset:
|
||||
sampled_dataset_count += 1
|
||||
logger.info(
|
||||
"Host inventory sampling for dataset %s (%d rows scanned)",
|
||||
ds.id,
|
||||
rows_scanned_this_dataset,
|
||||
)
|
||||
break
|
||||
|
||||
last_row_index = rows[-1].row_index
|
||||
if len(rows) < batch_size:
|
||||
break
|
||||
|
||||
if global_budget_reached:
|
||||
logger.info(
|
||||
"Host inventory global row budget reached for hunt %s at %d rows",
|
||||
hunt_id,
|
||||
total_rows,
|
||||
)
|
||||
break
|
||||
|
||||
# Post-process hosts
|
||||
for h in hosts.values():
|
||||
if not h['os'] and h['fqdn']:
|
||||
h['os'] = _infer_os(h['fqdn'])
|
||||
h['ips'] = sorted(h['ips'])
|
||||
h['users'] = sorted(h['users'])
|
||||
h['datasets'] = sorted(h['datasets'])
|
||||
|
||||
# Build connections, resolving IPs to host keys
|
||||
conn_list = []
|
||||
seen = set()
|
||||
for (src, dst_ip, dst_port), cnt in connections.items():
|
||||
if dst_ip in _IGNORE_IPS:
|
||||
continue
|
||||
dst_host = ip_to_host.get(dst_ip, '')
|
||||
if dst_host == src:
|
||||
continue
|
||||
key = tuple(sorted([src, dst_host or dst_ip]))
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
conn_list.append({
|
||||
'source': src,
|
||||
'target': dst_host or dst_ip,
|
||||
'target_ip': dst_ip,
|
||||
'port': dst_port,
|
||||
'count': cnt,
|
||||
})
|
||||
|
||||
host_list = sorted(hosts.values(), key=lambda x: x['row_count'], reverse=True)
|
||||
|
||||
return {
|
||||
"hosts": host_list,
|
||||
"connections": conn_list,
|
||||
"stats": {
|
||||
"total_hosts": len(host_list),
|
||||
"total_datasets_scanned": len(all_datasets),
|
||||
"datasets_with_hosts": ds_with_hosts,
|
||||
"total_rows_scanned": total_rows,
|
||||
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||
"row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
|
||||
"connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
|
||||
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
|
||||
"sampled_datasets": sampled_dataset_count,
|
||||
"global_budget_reached": global_budget_reached,
|
||||
"dropped_connections": dropped_connections,
|
||||
},
|
||||
}
|
||||
'''
|
||||
out=t[:start]+new_func+t[end:]
|
||||
p.write_text(out,encoding='utf-8')
|
||||
print('replaced build_host_inventory with hard-budget fast mode')
|
||||
@@ -0,0 +1,78 @@
|
||||
"""add playbooks, playbook_steps, saved_searches tables
|
||||
|
||||
Revision ID: b2c3d4e5f6a7
|
||||
Revises: a1b2c3d4e5f6
|
||||
Create Date: 2026-02-21 10:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision: str = "b2c3d4e5f6a7"
|
||||
down_revision: Union[str, Sequence[str], None] = "a1b2c3d4e5f6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add display_name to users table
|
||||
with op.batch_alter_table("users") as batch_op:
|
||||
batch_op.add_column(sa.Column("display_name", sa.String(128), nullable=True))
|
||||
|
||||
# Create playbooks table
|
||||
op.create_table(
|
||||
"playbooks",
|
||||
sa.Column("id", sa.String(32), primary_key=True),
|
||||
sa.Column("name", sa.String(256), nullable=False, index=True),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
|
||||
sa.Column("is_template", sa.Boolean(), server_default="0"),
|
||||
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
|
||||
sa.Column("status", sa.String(20), server_default="active"),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
|
||||
# Create playbook_steps table
|
||||
op.create_table(
|
||||
"playbook_steps",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("playbook_id", sa.String(32), sa.ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("order_index", sa.Integer(), nullable=False),
|
||||
sa.Column("title", sa.String(256), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("step_type", sa.String(32), server_default="manual"),
|
||||
sa.Column("target_route", sa.String(256), nullable=True),
|
||||
sa.Column("is_completed", sa.Boolean(), server_default="0"),
|
||||
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("notes", sa.Text(), nullable=True),
|
||||
)
|
||||
op.create_index("ix_playbook_steps_playbook", "playbook_steps", ["playbook_id"])
|
||||
|
||||
# Create saved_searches table
|
||||
op.create_table(
|
||||
"saved_searches",
|
||||
sa.Column("id", sa.String(32), primary_key=True),
|
||||
sa.Column("name", sa.String(256), nullable=False, index=True),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("search_type", sa.String(32), nullable=False),
|
||||
sa.Column("query_params", sa.JSON(), nullable=False),
|
||||
sa.Column("threshold", sa.Float(), nullable=True),
|
||||
sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
|
||||
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
|
||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("last_result_count", sa.Integer(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
op.create_index("ix_saved_searches_type", "saved_searches", ["search_type"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("saved_searches")
|
||||
op.drop_table("playbook_steps")
|
||||
op.drop_table("playbooks")
|
||||
with op.batch_alter_table("users") as batch_op:
|
||||
batch_op.drop_column("display_name")
|
||||
@@ -0,0 +1,48 @@
|
||||
"""add processing_tasks table
|
||||
|
||||
Revision ID: c3d4e5f6a7b8
|
||||
Revises: b2c3d4e5f6a7
|
||||
Create Date: 2026-02-22 00:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision: str = "c3d4e5f6a7b8"
|
||||
down_revision: Union[str, Sequence[str], None] = "b2c3d4e5f6a7"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"processing_tasks",
|
||||
sa.Column("id", sa.String(32), primary_key=True),
|
||||
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True),
|
||||
sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True),
|
||||
sa.Column("job_id", sa.String(64), nullable=True),
|
||||
sa.Column("stage", sa.String(64), nullable=False),
|
||||
sa.Column("status", sa.String(20), nullable=False, server_default="queued"),
|
||||
sa.Column("progress", sa.Float(), nullable=False, server_default="0.0"),
|
||||
sa.Column("message", sa.Text(), nullable=True),
|
||||
sa.Column("error", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
op.create_index("ix_processing_tasks_hunt_stage", "processing_tasks", ["hunt_id", "stage"])
|
||||
op.create_index("ix_processing_tasks_dataset_stage", "processing_tasks", ["dataset_id", "stage"])
|
||||
op.create_index("ix_processing_tasks_job_id", "processing_tasks", ["job_id"])
|
||||
op.create_index("ix_processing_tasks_status", "processing_tasks", ["status"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_processing_tasks_status", table_name="processing_tasks")
|
||||
op.drop_index("ix_processing_tasks_job_id", table_name="processing_tasks")
|
||||
op.drop_index("ix_processing_tasks_dataset_stage", table_name="processing_tasks")
|
||||
op.drop_index("ix_processing_tasks_hunt_stage", table_name="processing_tasks")
|
||||
op.drop_table("processing_tasks")
|
||||
@@ -1,16 +1,18 @@
|
||||
"""Analyst-assist agent module for ThreatHunt.
|
||||
"""Analyst-assist agent module for ThreatHunt.
|
||||
|
||||
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
|
||||
Agents are advisory only and do not execute actions or modify data.
|
||||
"""
|
||||
|
||||
from .core import ThreatHuntAgent
|
||||
from .providers import LLMProvider, LocalProvider, NetworkedProvider, OnlineProvider
|
||||
from .core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
|
||||
from .providers_v2 import OllamaProvider, OpenWebUIProvider, EmbeddingProvider
|
||||
|
||||
__all__ = [
|
||||
"ThreatHuntAgent",
|
||||
"LLMProvider",
|
||||
"LocalProvider",
|
||||
"NetworkedProvider",
|
||||
"OnlineProvider",
|
||||
"AgentContext",
|
||||
"AgentResponse",
|
||||
"Perspective",
|
||||
"OllamaProvider",
|
||||
"OpenWebUIProvider",
|
||||
"EmbeddingProvider",
|
||||
]
|
||||
|
||||
@@ -1,208 +0,0 @@
|
||||
"""Core ThreatHunt analyst-assist agent.
|
||||
|
||||
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
|
||||
Agents are advisory only - no execution, no alerts, no data modifications.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .providers import LLMProvider, get_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentContext(BaseModel):
|
||||
"""Context for agent guidance requests."""
|
||||
|
||||
query: str = Field(
|
||||
..., description="Analyst question or request for guidance"
|
||||
)
|
||||
dataset_name: Optional[str] = Field(None, description="Name of CSV dataset")
|
||||
artifact_type: Optional[str] = Field(None, description="Artifact type (e.g., file, process, network)")
|
||||
host_identifier: Optional[str] = Field(
|
||||
None, description="Host name, IP, or identifier"
|
||||
)
|
||||
data_summary: Optional[str] = Field(
|
||||
None, description="Brief description of uploaded data"
|
||||
)
|
||||
conversation_history: Optional[list[dict]] = Field(
|
||||
default_factory=list, description="Previous messages in conversation"
|
||||
)
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
|
||||
"""Response from analyst-assist agent."""
|
||||
|
||||
guidance: str = Field(..., description="Advisory guidance for analyst")
|
||||
confidence: float = Field(
|
||||
..., ge=0.0, le=1.0, description="Confidence in guidance (0-1)"
|
||||
)
|
||||
suggested_pivots: list[str] = Field(
|
||||
default_factory=list, description="Suggested analytical directions"
|
||||
)
|
||||
suggested_filters: list[str] = Field(
|
||||
default_factory=list, description="Suggested data filters or queries"
|
||||
)
|
||||
caveats: Optional[str] = Field(
|
||||
None, description="Assumptions, limitations, or caveats"
|
||||
)
|
||||
reasoning: Optional[str] = Field(
|
||||
None, description="Explanation of how guidance was generated"
|
||||
)
|
||||
|
||||
|
||||
class ThreatHuntAgent:
|
||||
"""Analyst-assist agent for ThreatHunt.
|
||||
|
||||
Provides guidance on:
|
||||
- Interpreting CSV artifact data
|
||||
- Suggesting analytical pivots and filters
|
||||
- Forming and testing hypotheses
|
||||
|
||||
Policy:
|
||||
- Advisory guidance only (no execution)
|
||||
- No database or schema changes
|
||||
- No alert escalation
|
||||
- Transparent reasoning
|
||||
"""
|
||||
|
||||
def __init__(self, provider: Optional[LLMProvider] = None):
|
||||
"""Initialize agent with LLM provider.
|
||||
|
||||
Args:
|
||||
provider: LLM provider instance. If None, uses get_provider() with auto mode.
|
||||
"""
|
||||
if provider is None:
|
||||
try:
|
||||
provider = get_provider("auto")
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Could not initialize default provider: {e}")
|
||||
provider = None
|
||||
|
||||
self.provider = provider
|
||||
self.system_prompt = self._build_system_prompt()
|
||||
|
||||
def _build_system_prompt(self) -> str:
|
||||
"""Build the system prompt that governs agent behavior."""
|
||||
return """You are an analyst-assist agent for ThreatHunt, a threat hunting platform.
|
||||
|
||||
Your role:
|
||||
- Interpret and explain CSV artifact data from Velociraptor
|
||||
- Suggest analytical pivots, filters, and hypotheses
|
||||
- Highlight anomalies, patterns, or points of interest
|
||||
- Guide analysts without replacing their judgment
|
||||
|
||||
Your constraints:
|
||||
- You ONLY provide guidance and suggestions
|
||||
- You do NOT execute actions or tools
|
||||
- You do NOT modify data or escalate alerts
|
||||
- You do NOT make autonomous decisions
|
||||
- You ONLY analyze data presented to you
|
||||
- You explain your reasoning transparently
|
||||
- You acknowledge limitations and assumptions
|
||||
- You suggest next investigative steps
|
||||
|
||||
When responding:
|
||||
1. Start with a clear, direct answer to the query
|
||||
2. Explain your reasoning based on the data context provided
|
||||
3. Suggest 2-4 analytical pivots the analyst might explore
|
||||
4. Suggest 2-4 data filters or queries that might be useful
|
||||
5. Include relevant caveats or assumptions
|
||||
6. Be honest about what you cannot determine from the data
|
||||
|
||||
Remember: The analyst is the decision-maker. You are an assistant."""
|
||||
|
||||
async def assist(self, context: AgentContext) -> AgentResponse:
|
||||
"""Provide guidance on artifact data and analysis.
|
||||
|
||||
Args:
|
||||
context: Request context including query and data context.
|
||||
|
||||
Returns:
|
||||
Guidance response with suggestions and reasoning.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no provider is available.
|
||||
"""
|
||||
if not self.provider:
|
||||
raise RuntimeError(
|
||||
"No LLM provider available. Configure at least one of: "
|
||||
"THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
|
||||
"or THREAT_HUNT_ONLINE_API_KEY"
|
||||
)
|
||||
|
||||
# Build prompt with context
|
||||
prompt = self._build_prompt(context)
|
||||
|
||||
try:
|
||||
# Get guidance from LLM provider
|
||||
guidance = await self.provider.generate(prompt, max_tokens=1024)
|
||||
|
||||
# Parse response into structured format
|
||||
response = self._parse_response(guidance, context)
|
||||
|
||||
logger.info(
|
||||
f"Agent assisted with query: {context.query[:50]}... "
|
||||
f"(dataset: {context.dataset_name})"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating guidance: {e}")
|
||||
raise
|
||||
|
||||
def _build_prompt(self, context: AgentContext) -> str:
|
||||
"""Build the prompt for the LLM."""
|
||||
prompt_parts = [
|
||||
f"Analyst query: {context.query}",
|
||||
]
|
||||
|
||||
if context.dataset_name:
|
||||
prompt_parts.append(f"Dataset: {context.dataset_name}")
|
||||
|
||||
if context.artifact_type:
|
||||
prompt_parts.append(f"Artifact type: {context.artifact_type}")
|
||||
|
||||
if context.host_identifier:
|
||||
prompt_parts.append(f"Host: {context.host_identifier}")
|
||||
|
||||
if context.data_summary:
|
||||
prompt_parts.append(f"Data summary: {context.data_summary}")
|
||||
|
||||
if context.conversation_history:
|
||||
prompt_parts.append("\nConversation history:")
|
||||
for msg in context.conversation_history[-5:]: # Last 5 messages for context
|
||||
prompt_parts.append(f" {msg.get('role', 'unknown')}: {msg.get('content', '')}")
|
||||
|
||||
return "\n".join(prompt_parts)
|
||||
|
||||
def _parse_response(self, response_text: str, context: AgentContext) -> AgentResponse:
|
||||
"""Parse LLM response into structured format.
|
||||
|
||||
Note: This is a simplified parser. In production, use structured output
|
||||
from the LLM (JSON mode, function calling, etc.) for better reliability.
|
||||
"""
|
||||
# For now, return a structured response based on the raw guidance
|
||||
# In production, parse JSON or use structured output from LLM
|
||||
return AgentResponse(
|
||||
guidance=response_text,
|
||||
confidence=0.8, # Placeholder
|
||||
suggested_pivots=[
|
||||
"Analyze temporal patterns",
|
||||
"Cross-reference with known indicators",
|
||||
"Examine outliers in the dataset",
|
||||
"Compare with baseline behavior",
|
||||
],
|
||||
suggested_filters=[
|
||||
"Filter by high-risk indicators",
|
||||
"Sort by timestamp for timeline analysis",
|
||||
"Group by host or user",
|
||||
"Filter by anomaly score",
|
||||
],
|
||||
caveats="Guidance is based on available data context. "
|
||||
"Analysts should verify findings with additional sources.",
|
||||
reasoning="Analysis generated based on artifact data patterns and analyst query.",
|
||||
)
|
||||
@@ -1,190 +0,0 @@
|
||||
"""Pluggable LLM provider interface for analyst-assist agents.
|
||||
|
||||
Supports three provider types:
|
||||
- Local: On-device or on-prem models
|
||||
- Networked: Shared internal inference services
|
||||
- Online: External hosted APIs
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""Abstract base class for LLM providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
|
||||
"""Generate a response from the LLM.
|
||||
|
||||
Args:
|
||||
prompt: The input prompt
|
||||
max_tokens: Maximum tokens in response
|
||||
|
||||
Returns:
|
||||
Generated text response
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_available(self) -> bool:
|
||||
"""Check if provider backend is available."""
|
||||
pass
|
||||
|
||||
|
||||
class LocalProvider(LLMProvider):
|
||||
"""Local LLM provider (on-device or on-prem models)."""
|
||||
|
||||
def __init__(self, model_path: Optional[str] = None):
|
||||
"""Initialize local provider.
|
||||
|
||||
Args:
|
||||
model_path: Path to local model. If None, uses THREAT_HUNT_LOCAL_MODEL_PATH env var.
|
||||
"""
|
||||
self.model_path = model_path or os.getenv("THREAT_HUNT_LOCAL_MODEL_PATH")
|
||||
self.model = None
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if local model is available."""
|
||||
if not self.model_path:
|
||||
return False
|
||||
# In production, would verify model file exists and can be loaded
|
||||
return os.path.exists(str(self.model_path))
|
||||
|
||||
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
|
||||
"""Generate response using local model.
|
||||
|
||||
Note: This is a placeholder. In production, integrate with:
|
||||
- llama-cpp-python for GGML models
|
||||
- Ollama API
|
||||
- vLLM
|
||||
- Other local inference engines
|
||||
"""
|
||||
if not self.is_available():
|
||||
raise RuntimeError("Local model not available")
|
||||
|
||||
# Placeholder implementation
|
||||
return f"[Local model response to: {prompt[:50]}...]"
|
||||
|
||||
|
||||
class NetworkedProvider(LLMProvider):
|
||||
"""Networked LLM provider (shared internal inference services)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_endpoint: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
model_name: str = "default",
|
||||
):
|
||||
"""Initialize networked provider.
|
||||
|
||||
Args:
|
||||
api_endpoint: URL to inference service. Defaults to env var THREAT_HUNT_NETWORKED_ENDPOINT.
|
||||
api_key: API key for service. Defaults to env var THREAT_HUNT_NETWORKED_KEY.
|
||||
model_name: Model name/ID on the service.
|
||||
"""
|
||||
self.api_endpoint = api_endpoint or os.getenv("THREAT_HUNT_NETWORKED_ENDPOINT")
|
||||
self.api_key = api_key or os.getenv("THREAT_HUNT_NETWORKED_KEY")
|
||||
self.model_name = model_name
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if networked service is available."""
|
||||
return bool(self.api_endpoint)
|
||||
|
||||
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
|
||||
"""Generate response using networked service.
|
||||
|
||||
Note: This is a placeholder. In production, integrate with:
|
||||
- Internal inference service API
|
||||
- LLM inference container cluster
|
||||
- Enterprise inference gateway
|
||||
"""
|
||||
if not self.is_available():
|
||||
raise RuntimeError("Networked service not available")
|
||||
|
||||
# Placeholder implementation
|
||||
return f"[Networked response from {self.model_name}: {prompt[:50]}...]"
|
||||
|
||||
|
||||
class OnlineProvider(LLMProvider):
|
||||
"""Online LLM provider (external hosted APIs)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_provider: str = "openai",
|
||||
api_key: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
):
|
||||
"""Initialize online provider.
|
||||
|
||||
Args:
|
||||
api_provider: Provider name (openai, anthropic, google, etc.)
|
||||
api_key: API key. Defaults to env var THREAT_HUNT_ONLINE_API_KEY.
|
||||
model_name: Model name. Defaults to env var THREAT_HUNT_ONLINE_MODEL.
|
||||
"""
|
||||
self.api_provider = api_provider
|
||||
self.api_key = api_key or os.getenv("THREAT_HUNT_ONLINE_API_KEY")
|
||||
self.model_name = model_name or os.getenv(
|
||||
"THREAT_HUNT_ONLINE_MODEL", f"{api_provider}-default"
|
||||
)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if online API is available."""
|
||||
return bool(self.api_key)
|
||||
|
||||
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
|
||||
"""Generate response using online API.
|
||||
|
||||
Note: This is a placeholder. In production, integrate with:
|
||||
- OpenAI API (GPT-3.5, GPT-4, etc.)
|
||||
- Anthropic Claude API
|
||||
- Google Gemini API
|
||||
- Other hosted LLM services
|
||||
"""
|
||||
if not self.is_available():
|
||||
raise RuntimeError("Online API not available or API key not set")
|
||||
|
||||
# Placeholder implementation
|
||||
return f"[Online {self.api_provider} response: {prompt[:50]}...]"
|
||||
|
||||
|
||||
def get_provider(provider_type: str = "auto") -> LLMProvider:
|
||||
"""Get an LLM provider based on configuration.
|
||||
|
||||
Args:
|
||||
provider_type: Type of provider to use: 'local', 'networked', 'online', or 'auto'.
|
||||
'auto' attempts to use the first available provider in order:
|
||||
local -> networked -> online.
|
||||
|
||||
Returns:
|
||||
Configured LLM provider instance.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no provider is available.
|
||||
"""
|
||||
# Explicit provider selection
|
||||
if provider_type == "local":
|
||||
provider = LocalProvider()
|
||||
elif provider_type == "networked":
|
||||
provider = NetworkedProvider()
|
||||
elif provider_type == "online":
|
||||
provider = OnlineProvider()
|
||||
elif provider_type == "auto":
|
||||
# Try providers in order of preference
|
||||
for Provider in [LocalProvider, NetworkedProvider, OnlineProvider]:
|
||||
provider = Provider()
|
||||
if provider.is_available():
|
||||
return provider
|
||||
raise RuntimeError(
|
||||
"No LLM provider available. Configure at least one of: "
|
||||
"THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
|
||||
"or THREAT_HUNT_ONLINE_API_KEY"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
|
||||
if not provider.is_available():
|
||||
raise RuntimeError(f"{provider_type} provider not available")
|
||||
|
||||
return provider
|
||||
@@ -1,170 +0,0 @@
|
||||
"""API routes for analyst-assist agent."""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agents.core import ThreatHuntAgent, AgentContext, AgentResponse
|
||||
from app.agents.config import AgentConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
||||
|
||||
# Global agent instance (lazy-loaded)
|
||||
_agent: ThreatHuntAgent | None = None
|
||||
|
||||
|
||||
def get_agent() -> ThreatHuntAgent:
|
||||
"""Get or create the agent instance."""
|
||||
global _agent
|
||||
if _agent is None:
|
||||
if not AgentConfig.is_agent_enabled():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Analyst-assist agent is not configured. "
|
||||
"Please configure an LLM provider.",
|
||||
)
|
||||
_agent = ThreatHuntAgent()
|
||||
return _agent
|
||||
|
||||
|
||||
class AssistRequest(BaseModel):
|
||||
"""Request for agent assistance."""
|
||||
|
||||
query: str = Field(
|
||||
..., description="Analyst question or request for guidance"
|
||||
)
|
||||
dataset_name: str | None = Field(
|
||||
None, description="Name of CSV dataset being analyzed"
|
||||
)
|
||||
artifact_type: str | None = Field(
|
||||
None, description="Type of artifact (e.g., FileList, ProcessList, NetworkConnections)"
|
||||
)
|
||||
host_identifier: str | None = Field(
|
||||
None, description="Host name, IP address, or identifier"
|
||||
)
|
||||
data_summary: str | None = Field(
|
||||
None, description="Brief summary or context about the uploaded data"
|
||||
)
|
||||
conversation_history: list[dict] | None = Field(
|
||||
None, description="Previous messages for context"
|
||||
)
|
||||
|
||||
|
||||
class AssistResponse(BaseModel):
|
||||
"""Response with agent guidance."""
|
||||
|
||||
guidance: str
|
||||
confidence: float
|
||||
suggested_pivots: list[str]
|
||||
suggested_filters: list[str]
|
||||
caveats: str | None = None
|
||||
reasoning: str | None = None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/assist",
|
||||
response_model=AssistResponse,
|
||||
summary="Get analyst-assist guidance",
|
||||
description="Request guidance on CSV artifact data, analytical pivots, and hypotheses. "
|
||||
"Agent provides advisory guidance only - no execution.",
|
||||
)
|
||||
async def agent_assist(request: AssistRequest) -> AssistResponse:
|
||||
"""Provide analyst-assist guidance on artifact data.
|
||||
|
||||
The agent will:
|
||||
- Explain and interpret the provided data context
|
||||
- Suggest analytical pivots the analyst might explore
|
||||
- Suggest data filters or queries that might be useful
|
||||
- Highlight assumptions, limitations, and caveats
|
||||
|
||||
The agent will NOT:
|
||||
- Execute any tools or actions
|
||||
- Escalate findings to alerts
|
||||
- Modify any data or schema
|
||||
- Make autonomous decisions
|
||||
|
||||
Args:
|
||||
request: Assistance request with query and context
|
||||
|
||||
Returns:
|
||||
Guidance response with suggestions and reasoning
|
||||
|
||||
Raises:
|
||||
HTTPException: If agent is not configured (503) or request fails
|
||||
"""
|
||||
try:
|
||||
agent = get_agent()
|
||||
|
||||
# Build context
|
||||
context = AgentContext(
|
||||
query=request.query,
|
||||
dataset_name=request.dataset_name,
|
||||
artifact_type=request.artifact_type,
|
||||
host_identifier=request.host_identifier,
|
||||
data_summary=request.data_summary,
|
||||
conversation_history=request.conversation_history or [],
|
||||
)
|
||||
|
||||
# Get guidance
|
||||
response = await agent.assist(context)
|
||||
|
||||
logger.info(
|
||||
f"Agent assisted analyst with query: {request.query[:50]}... "
|
||||
f"(host: {request.host_identifier}, artifact: {request.artifact_type})"
|
||||
)
|
||||
|
||||
return AssistResponse(
|
||||
guidance=response.guidance,
|
||||
confidence=response.confidence,
|
||||
suggested_pivots=response.suggested_pivots,
|
||||
suggested_filters=response.suggested_filters,
|
||||
caveats=response.caveats,
|
||||
reasoning=response.reasoning,
|
||||
)
|
||||
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Agent error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"Agent unavailable: {str(e)}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error in agent_assist: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Error generating guidance. Please try again.",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/health",
|
||||
summary="Check agent health",
|
||||
description="Check if agent is configured and ready to assist.",
|
||||
)
|
||||
async def agent_health() -> dict:
|
||||
"""Check agent availability and configuration.
|
||||
|
||||
Returns:
|
||||
Health status with configuration details
|
||||
"""
|
||||
try:
|
||||
agent = get_agent()
|
||||
provider_type = agent.provider.__class__.__name__ if agent.provider else "None"
|
||||
return {
|
||||
"status": "healthy",
|
||||
"provider": provider_type,
|
||||
"max_tokens": AgentConfig.MAX_RESPONSE_TOKENS,
|
||||
"reasoning_enabled": AgentConfig.ENABLE_REASONING,
|
||||
}
|
||||
except HTTPException:
|
||||
return {
|
||||
"status": "unavailable",
|
||||
"reason": "No LLM provider configured",
|
||||
"configured_providers": {
|
||||
"local": bool(AgentConfig.LOCAL_MODEL_PATH),
|
||||
"networked": bool(AgentConfig.NETWORKED_ENDPOINT),
|
||||
"online": bool(AgentConfig.ONLINE_API_KEY),
|
||||
},
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
"""API routes for analyst-assist agent — v2.
|
||||
"""API routes for analyst-assist agent v2.
|
||||
|
||||
Supports quick, deep, and debate modes with streaming.
|
||||
Conversations are persisted to the database.
|
||||
@@ -6,19 +6,25 @@ Conversations are persisted to the database.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections import Counter
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
from app.db.models import Conversation, Message
|
||||
from app.db.models import Conversation, Message, Dataset, KeywordTheme
|
||||
from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
|
||||
from app.agents.providers_v2 import check_all_nodes
|
||||
from app.agents.registry import registry
|
||||
from app.services.sans_rag import sans_rag
|
||||
from app.services.scanner import KeywordScanner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,7 +41,7 @@ def get_agent() -> ThreatHuntAgent:
|
||||
return _agent
|
||||
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
# Request / Response models
|
||||
|
||||
|
||||
class AssistRequest(BaseModel):
|
||||
@@ -52,6 +58,8 @@ class AssistRequest(BaseModel):
|
||||
model_override: str | None = None
|
||||
conversation_id: str | None = Field(None, description="Persist messages to this conversation")
|
||||
hunt_id: str | None = None
|
||||
execution_preference: str = Field(default="auto", description="auto | force | off")
|
||||
learning_mode: bool = False
|
||||
|
||||
|
||||
class AssistResponseModel(BaseModel):
|
||||
@@ -66,10 +74,170 @@ class AssistResponseModel(BaseModel):
|
||||
node_used: str = ""
|
||||
latency_ms: int = 0
|
||||
perspectives: list[dict] | None = None
|
||||
execution: dict | None = None
|
||||
conversation_id: str | None = None
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
POLICY_THEME_NAMES = {"Adult Content", "Gambling", "Downloads / Piracy"}
|
||||
POLICY_QUERY_TERMS = {
|
||||
"policy", "violating", "violation", "browser history", "web history",
|
||||
"domain", "domains", "adult", "gambling", "piracy", "aup",
|
||||
}
|
||||
WEB_DATASET_HINTS = {
|
||||
"web", "history", "browser", "url", "visited_url", "domain", "title",
|
||||
}
|
||||
|
||||
|
||||
def _is_policy_domain_query(query: str) -> bool:
|
||||
q = (query or "").lower()
|
||||
if not q:
|
||||
return False
|
||||
score = sum(1 for t in POLICY_QUERY_TERMS if t in q)
|
||||
return score >= 2 and ("domain" in q or "history" in q or "policy" in q)
|
||||
|
||||
def _should_execute_policy_scan(request: AssistRequest) -> bool:
|
||||
pref = (request.execution_preference or "auto").strip().lower()
|
||||
if pref == "off":
|
||||
return False
|
||||
if pref == "force":
|
||||
return True
|
||||
return _is_policy_domain_query(request.query)
|
||||
|
||||
|
||||
def _extract_domain(value: str | None) -> str | None:
|
||||
if not value:
|
||||
return None
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = urlparse(text)
|
||||
if parsed.netloc:
|
||||
return parsed.netloc.lower()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
m = re.search(r"([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}", text)
|
||||
return m.group(0).lower() if m else None
|
||||
|
||||
|
||||
def _dataset_score(ds: Dataset) -> int:
|
||||
score = 0
|
||||
name = (ds.name or "").lower()
|
||||
cols_l = {c.lower() for c in (ds.column_schema or {}).keys()}
|
||||
norm_vals_l = {str(v).lower() for v in (ds.normalized_columns or {}).values()}
|
||||
|
||||
for h in WEB_DATASET_HINTS:
|
||||
if h in name:
|
||||
score += 2
|
||||
if h in cols_l:
|
||||
score += 3
|
||||
if h in norm_vals_l:
|
||||
score += 3
|
||||
|
||||
if "visited_url" in cols_l or "url" in cols_l:
|
||||
score += 8
|
||||
if "user" in cols_l or "username" in cols_l:
|
||||
score += 2
|
||||
if "clientid" in cols_l or "fqdn" in cols_l:
|
||||
score += 2
|
||||
if (ds.row_count or 0) > 0:
|
||||
score += 1
|
||||
|
||||
return score
|
||||
|
||||
|
||||
async def _run_policy_domain_execution(request: AssistRequest, db: AsyncSession) -> dict:
|
||||
scanner = KeywordScanner(db)
|
||||
|
||||
theme_result = await db.execute(
|
||||
select(KeywordTheme).where(
|
||||
KeywordTheme.enabled == True, # noqa: E712
|
||||
KeywordTheme.name.in_(list(POLICY_THEME_NAMES)),
|
||||
)
|
||||
)
|
||||
themes = list(theme_result.scalars().all())
|
||||
theme_ids = [t.id for t in themes]
|
||||
theme_names = [t.name for t in themes] or sorted(POLICY_THEME_NAMES)
|
||||
|
||||
ds_query = select(Dataset).where(Dataset.processing_status.in_(["completed", "ready", "processing"]))
|
||||
if request.hunt_id:
|
||||
ds_query = ds_query.where(Dataset.hunt_id == request.hunt_id)
|
||||
ds_result = await db.execute(ds_query)
|
||||
candidates = list(ds_result.scalars().all())
|
||||
|
||||
if request.dataset_name:
|
||||
needle = request.dataset_name.lower().strip()
|
||||
candidates = [d for d in candidates if needle in (d.name or "").lower()]
|
||||
|
||||
scored = sorted(
|
||||
((d, _dataset_score(d)) for d in candidates),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)
|
||||
selected = [d for d, s in scored if s > 0][:8]
|
||||
dataset_ids = [d.id for d in selected]
|
||||
|
||||
if not dataset_ids:
|
||||
return {
|
||||
"mode": "policy_scan",
|
||||
"themes": theme_names,
|
||||
"datasets_scanned": 0,
|
||||
"dataset_names": [],
|
||||
"total_hits": 0,
|
||||
"policy_hits": 0,
|
||||
"top_user_hosts": [],
|
||||
"top_domains": [],
|
||||
"sample_hits": [],
|
||||
"note": "No suitable browser/web-history datasets found in current scope.",
|
||||
}
|
||||
|
||||
result = await scanner.scan(
|
||||
dataset_ids=dataset_ids,
|
||||
theme_ids=theme_ids or None,
|
||||
scan_hunts=False,
|
||||
scan_annotations=False,
|
||||
scan_messages=False,
|
||||
)
|
||||
hits = result.get("hits", [])
|
||||
|
||||
user_host_counter = Counter()
|
||||
domain_counter = Counter()
|
||||
|
||||
for h in hits:
|
||||
user = h.get("username") or "(unknown-user)"
|
||||
host = h.get("hostname") or "(unknown-host)"
|
||||
user_host_counter[f"{user}|{host}"] += 1
|
||||
|
||||
dom = _extract_domain(h.get("matched_value"))
|
||||
if dom:
|
||||
domain_counter[dom] += 1
|
||||
|
||||
top_user_hosts = [
|
||||
{"user_host": k, "count": v}
|
||||
for k, v in user_host_counter.most_common(10)
|
||||
]
|
||||
top_domains = [
|
||||
{"domain": k, "count": v}
|
||||
for k, v in domain_counter.most_common(10)
|
||||
]
|
||||
|
||||
return {
|
||||
"mode": "policy_scan",
|
||||
"themes": theme_names,
|
||||
"datasets_scanned": len(dataset_ids),
|
||||
"dataset_names": [d.name for d in selected],
|
||||
"total_hits": int(result.get("total_hits", 0)),
|
||||
"policy_hits": int(result.get("total_hits", 0)),
|
||||
"rows_scanned": int(result.get("rows_scanned", 0)),
|
||||
"top_user_hosts": top_user_hosts,
|
||||
"top_domains": top_domains,
|
||||
"sample_hits": hits[:20],
|
||||
}
|
||||
|
||||
|
||||
# Routes
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -84,6 +252,76 @@ async def agent_assist(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> AssistResponseModel:
|
||||
try:
|
||||
# Deterministic execution mode for policy-domain investigations.
|
||||
if _should_execute_policy_scan(request):
|
||||
t0 = time.monotonic()
|
||||
exec_payload = await _run_policy_domain_execution(request, db)
|
||||
latency_ms = int((time.monotonic() - t0) * 1000)
|
||||
|
||||
policy_hits = exec_payload.get("policy_hits", 0)
|
||||
datasets_scanned = exec_payload.get("datasets_scanned", 0)
|
||||
|
||||
if policy_hits > 0:
|
||||
guidance = (
|
||||
f"Policy-violation scan complete: {policy_hits} hits across "
|
||||
f"{datasets_scanned} dataset(s). Top user/host pairs and domains are included "
|
||||
f"in execution results for triage."
|
||||
)
|
||||
confidence = 0.95
|
||||
caveats = "Keyword-based matching can include false positives; validate with full URL context."
|
||||
else:
|
||||
guidance = (
|
||||
f"No policy-violation hits found in current scope "
|
||||
f"({datasets_scanned} dataset(s) scanned)."
|
||||
)
|
||||
confidence = 0.9
|
||||
caveats = exec_payload.get("note") or "Try expanding scope to additional hunts/datasets."
|
||||
|
||||
response = AssistResponseModel(
|
||||
guidance=guidance,
|
||||
confidence=confidence,
|
||||
suggested_pivots=["username", "hostname", "domain", "dataset_name"],
|
||||
suggested_filters=[
|
||||
"theme_name in ['Adult Content','Gambling','Downloads / Piracy']",
|
||||
"username != null",
|
||||
"hostname != null",
|
||||
],
|
||||
caveats=caveats,
|
||||
reasoning=(
|
||||
"Intent matched policy-domain investigation; executed local keyword scan pipeline."
|
||||
if _is_policy_domain_query(request.query)
|
||||
else "Execution mode was forced by user preference; ran policy-domain scan pipeline."
|
||||
),
|
||||
sans_references=["SANS FOR508", "SANS SEC504"],
|
||||
model_used="execution:keyword_scanner",
|
||||
node_used="local",
|
||||
latency_ms=latency_ms,
|
||||
execution=exec_payload,
|
||||
)
|
||||
|
||||
conv_id = request.conversation_id
|
||||
if conv_id or request.hunt_id:
|
||||
conv_id = await _persist_conversation(
|
||||
db,
|
||||
conv_id,
|
||||
request,
|
||||
AgentResponse(
|
||||
guidance=response.guidance,
|
||||
confidence=response.confidence,
|
||||
suggested_pivots=response.suggested_pivots,
|
||||
suggested_filters=response.suggested_filters,
|
||||
caveats=response.caveats,
|
||||
reasoning=response.reasoning,
|
||||
sans_references=response.sans_references,
|
||||
model_used=response.model_used,
|
||||
node_used=response.node_used,
|
||||
latency_ms=response.latency_ms,
|
||||
),
|
||||
)
|
||||
response.conversation_id = conv_id
|
||||
|
||||
return response
|
||||
|
||||
agent = get_agent()
|
||||
context = AgentContext(
|
||||
query=request.query,
|
||||
@@ -97,6 +335,7 @@ async def agent_assist(
|
||||
enrichment_summary=request.enrichment_summary,
|
||||
mode=request.mode,
|
||||
model_override=request.model_override,
|
||||
learning_mode=request.learning_mode,
|
||||
)
|
||||
|
||||
response = await agent.assist(context)
|
||||
@@ -129,6 +368,7 @@ async def agent_assist(
|
||||
}
|
||||
for p in response.perspectives
|
||||
] if response.perspectives else None,
|
||||
execution=None,
|
||||
conversation_id=conv_id,
|
||||
)
|
||||
|
||||
@@ -208,7 +448,7 @@ async def list_models():
|
||||
}
|
||||
|
||||
|
||||
# ── Conversation persistence ──────────────────────────────────────────
|
||||
# Conversation persistence
|
||||
|
||||
|
||||
async def _persist_conversation(
|
||||
@@ -263,3 +503,4 @@ async def _persist_conversation(
|
||||
await db.flush()
|
||||
|
||||
return conv.id
|
||||
|
||||
|
||||
@@ -290,6 +290,47 @@ async def get_knowledge_graph(
|
||||
hunt_id: str | None = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
<<<<<<< HEAD
|
||||
"""Submit a new job to the queue.
|
||||
|
||||
Job types: triage, host_profile, report, anomaly, query
|
||||
Params vary by type (e.g., dataset_id, hunt_id, question, mode).
|
||||
"""
|
||||
from app.services.job_queue import job_queue, JobType
|
||||
|
||||
try:
|
||||
jt = JobType(job_type)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid job_type: {job_type}. Valid: {[t.value for t in JobType]}",
|
||||
)
|
||||
|
||||
if not job_queue.can_accept():
|
||||
raise HTTPException(status_code=429, detail="Job queue is busy. Retry shortly.")
|
||||
if not job_queue.can_accept():
|
||||
raise HTTPException(status_code=429, detail="Job queue is busy. Retry shortly.")
|
||||
job = job_queue.submit(jt, **params)
|
||||
return {"job_id": job.id, "status": job.status.value, "job_type": job_type}
|
||||
|
||||
|
||||
# --- Load balancer status ---
|
||||
|
||||
@router.get("/lb/status")
|
||||
async def lb_status():
|
||||
"""Get load balancer status for both nodes."""
|
||||
from app.services.load_balancer import lb
|
||||
return lb.get_status()
|
||||
|
||||
|
||||
@router.post("/lb/check")
|
||||
async def lb_health_check():
|
||||
"""Force a health check of both nodes."""
|
||||
from app.services.load_balancer import lb
|
||||
await lb.check_health()
|
||||
return lb.get_status()
|
||||
=======
|
||||
if not dataset_id and not hunt_id:
|
||||
raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
|
||||
return await build_knowledge_graph(db, dataset_id=dataset_id, hunt_id=hunt_id)
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""API routes for authentication — register, login, refresh, profile."""
|
||||
"""API routes for authentication — register, login, refresh, profile."""
|
||||
|
||||
import logging
|
||||
|
||||
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
@@ -57,7 +57,7 @@ class AuthResponse(BaseModel):
|
||||
tokens: TokenPair
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -86,7 +86,7 @@ async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
user = User(
|
||||
username=body.username,
|
||||
email=body.email,
|
||||
password_hash=hash_password(body.password),
|
||||
hashed_password=hash_password(body.password),
|
||||
display_name=body.display_name or body.username,
|
||||
role="analyst", # Default role
|
||||
)
|
||||
@@ -120,13 +120,13 @@ async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(User).where(User.username == body.username))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not user.password_hash:
|
||||
if not user or not user.hashed_password:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
)
|
||||
|
||||
if not verify_password(body.password, user.password_hash):
|
||||
if not verify_password(body.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
@@ -165,7 +165,7 @@ async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db)
|
||||
if token_data.type != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type — use refresh token",
|
||||
detail="Invalid token type — use refresh token",
|
||||
)
|
||||
|
||||
result = await db.execute(select(User).where(User.id == token_data.sub))
|
||||
@@ -195,3 +195,4 @@ async def get_profile(user: User = Depends(get_current_user)):
|
||||
is_active=user.is_active,
|
||||
created_at=user.created_at.isoformat() if hasattr(user.created_at, 'isoformat') else str(user.created_at),
|
||||
)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
from app.db.models import ProcessingTask
|
||||
from app.db.repositories.datasets import DatasetRepository
|
||||
from app.services.csv_parser import parse_csv_bytes, infer_column_types
|
||||
from app.services.normalizer import (
|
||||
@@ -18,15 +19,20 @@ from app.services.normalizer import (
|
||||
detect_ioc_columns,
|
||||
detect_time_range,
|
||||
)
|
||||
from app.services.artifact_classifier import classify_artifact, get_artifact_category
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from app.services.job_queue import job_queue, JobType
|
||||
from app.services.host_inventory import inventory_cache
|
||||
from app.services.scanner import keyword_scan_cache
|
||||
|
||||
router = APIRouter(prefix="/api/datasets", tags=["datasets"])
|
||||
|
||||
ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"}
|
||||
|
||||
|
||||
# ── Response models ───────────────────────────────────────────────────
|
||||
# -- Response models --
|
||||
|
||||
|
||||
class DatasetSummary(BaseModel):
|
||||
@@ -43,6 +49,8 @@ class DatasetSummary(BaseModel):
|
||||
delimiter: str | None = None
|
||||
time_range_start: str | None = None
|
||||
time_range_end: str | None = None
|
||||
artifact_type: str | None = None
|
||||
processing_status: str | None = None
|
||||
hunt_id: str | None = None
|
||||
created_at: str
|
||||
|
||||
@@ -67,10 +75,13 @@ class UploadResponse(BaseModel):
|
||||
column_types: dict
|
||||
normalized_columns: dict
|
||||
ioc_columns: dict
|
||||
artifact_type: str | None = None
|
||||
processing_status: str
|
||||
jobs_queued: list[str]
|
||||
message: str
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
# -- Routes --
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -78,7 +89,7 @@ class UploadResponse(BaseModel):
|
||||
response_model=UploadResponse,
|
||||
summary="Upload a CSV dataset",
|
||||
description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, "
|
||||
"IOCs auto-detected, and rows stored in the database.",
|
||||
"IOCs auto-detected, artifact type classified, and all processing jobs queued automatically.",
|
||||
)
|
||||
async def upload_dataset(
|
||||
file: UploadFile = File(...),
|
||||
@@ -87,7 +98,7 @@ async def upload_dataset(
|
||||
hunt_id: str | None = Query(None, description="Hunt ID to associate with"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Upload and parse a CSV dataset."""
|
||||
"""Upload and parse a CSV dataset, then trigger full processing pipeline."""
|
||||
# Validate file
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No filename provided")
|
||||
@@ -136,7 +147,12 @@ async def upload_dataset(
|
||||
# Detect time range
|
||||
time_start, time_end = detect_time_range(rows, column_mapping)
|
||||
|
||||
# Store in DB
|
||||
# Classify artifact type from column headers
|
||||
artifact_type = classify_artifact(columns)
|
||||
artifact_category = get_artifact_category(artifact_type)
|
||||
logger.info(f"Artifact classification: {artifact_type} (category: {artifact_category})")
|
||||
|
||||
# Store in DB with processing_status = "processing"
|
||||
repo = DatasetRepository(db)
|
||||
dataset = await repo.create_dataset(
|
||||
name=name or Path(file.filename).stem,
|
||||
@@ -152,6 +168,8 @@ async def upload_dataset(
|
||||
time_range_start=time_start,
|
||||
time_range_end=time_end,
|
||||
hunt_id=hunt_id,
|
||||
artifact_type=artifact_type,
|
||||
processing_status="processing",
|
||||
)
|
||||
|
||||
await repo.bulk_insert_rows(
|
||||
@@ -162,9 +180,88 @@ async def upload_dataset(
|
||||
|
||||
logger.info(
|
||||
f"Uploaded dataset '{dataset.name}': {len(rows)} rows, "
|
||||
f"{len(columns)} columns, {len(ioc_columns)} IOC columns detected"
|
||||
f"{len(columns)} columns, {len(ioc_columns)} IOC columns, "
|
||||
f"artifact={artifact_type}"
|
||||
)
|
||||
|
||||
# -- Queue full processing pipeline --
|
||||
jobs_queued = []
|
||||
|
||||
task_rows: list[ProcessingTask] = []
|
||||
|
||||
# 1. AI Triage (chains to HOST_PROFILE automatically on completion)
|
||||
triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
|
||||
jobs_queued.append("triage")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=triage_job.id,
|
||||
stage="triage",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 2. Anomaly detection (embedding-based outlier detection)
|
||||
anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
|
||||
jobs_queued.append("anomaly")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=anomaly_job.id,
|
||||
stage="anomaly",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 3. AUP keyword scan
|
||||
kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
|
||||
jobs_queued.append("keyword_scan")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=kw_job.id,
|
||||
stage="keyword_scan",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 4. IOC extraction
|
||||
ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
|
||||
jobs_queued.append("ioc_extract")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=ioc_job.id,
|
||||
stage="ioc_extract",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 5. Host inventory (network map) - requires hunt_id
|
||||
if hunt_id:
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
jobs_queued.append("host_inventory")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=inv_job.id,
|
||||
stage="host_inventory",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
if task_rows:
|
||||
db.add_all(task_rows)
|
||||
await db.flush()
|
||||
|
||||
logger.info(f"Queued {len(jobs_queued)} processing jobs for dataset {dataset.id}: {jobs_queued}")
|
||||
|
||||
return UploadResponse(
|
||||
id=dataset.id,
|
||||
name=dataset.name,
|
||||
@@ -173,7 +270,10 @@ async def upload_dataset(
|
||||
column_types=column_types,
|
||||
normalized_columns=column_mapping,
|
||||
ioc_columns=ioc_columns,
|
||||
message=f"Successfully uploaded {len(rows)} rows with {len(ioc_columns)} IOC columns detected",
|
||||
artifact_type=artifact_type,
|
||||
processing_status="processing",
|
||||
jobs_queued=jobs_queued,
|
||||
message=f"Successfully uploaded {len(rows)} rows. {len(jobs_queued)} processing jobs queued.",
|
||||
)
|
||||
|
||||
|
||||
@@ -208,6 +308,8 @@ async def list_datasets(
|
||||
delimiter=ds.delimiter,
|
||||
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
|
||||
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
|
||||
artifact_type=ds.artifact_type,
|
||||
processing_status=ds.processing_status,
|
||||
hunt_id=ds.hunt_id,
|
||||
created_at=ds.created_at.isoformat(),
|
||||
)
|
||||
@@ -244,6 +346,8 @@ async def get_dataset(
|
||||
delimiter=ds.delimiter,
|
||||
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
|
||||
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
|
||||
artifact_type=ds.artifact_type,
|
||||
processing_status=ds.processing_status,
|
||||
hunt_id=ds.hunt_id,
|
||||
created_at=ds.created_at.isoformat(),
|
||||
)
|
||||
@@ -292,6 +396,7 @@ async def delete_dataset(
|
||||
deleted = await repo.delete_dataset(dataset_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
keyword_scan_cache.invalidate_dataset(dataset_id)
|
||||
return {"message": "Dataset deleted", "id": dataset_id}
|
||||
|
||||
|
||||
|
||||
@@ -8,16 +8,15 @@ from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Hunt, Conversation, Message
|
||||
from app.db.models import Hunt, Dataset, ProcessingTask
|
||||
from app.services.job_queue import job_queue
|
||||
from app.services.host_inventory import inventory_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/hunts", tags=["hunts"])
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class HuntCreate(BaseModel):
|
||||
name: str = Field(..., max_length=256)
|
||||
description: str | None = None
|
||||
@@ -26,7 +25,7 @@ class HuntCreate(BaseModel):
|
||||
class HuntUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
status: str | None = None # active | closed | archived
|
||||
status: str | None = None
|
||||
|
||||
|
||||
class HuntResponse(BaseModel):
|
||||
@@ -46,7 +45,18 @@ class HuntListResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
class HuntProgressResponse(BaseModel):
|
||||
hunt_id: str
|
||||
status: str
|
||||
progress_percent: float
|
||||
dataset_total: int
|
||||
dataset_completed: int
|
||||
dataset_processing: int
|
||||
dataset_errors: int
|
||||
active_jobs: int
|
||||
queued_jobs: int
|
||||
network_status: str
|
||||
stages: dict
|
||||
|
||||
|
||||
@router.post("", response_model=HuntResponse, summary="Create a new hunt")
|
||||
@@ -122,6 +132,125 @@ async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{hunt_id}/progress", response_model=HuntProgressResponse, summary="Get hunt processing progress")
|
||||
async def get_hunt_progress(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
hunt = await db.get(Hunt, hunt_id)
|
||||
if not hunt:
|
||||
raise HTTPException(status_code=404, detail="Hunt not found")
|
||||
|
||||
ds_rows = await db.execute(
|
||||
select(Dataset.id, Dataset.processing_status)
|
||||
.where(Dataset.hunt_id == hunt_id)
|
||||
)
|
||||
datasets = ds_rows.all()
|
||||
dataset_ids = {row[0] for row in datasets}
|
||||
|
||||
dataset_total = len(datasets)
|
||||
dataset_completed = sum(1 for _, st in datasets if st == "completed")
|
||||
dataset_errors = sum(1 for _, st in datasets if st == "completed_with_errors")
|
||||
dataset_processing = max(0, dataset_total - dataset_completed - dataset_errors)
|
||||
|
||||
jobs = job_queue.list_jobs(limit=5000)
|
||||
relevant_jobs = [
|
||||
j for j in jobs
|
||||
if j.get("params", {}).get("hunt_id") == hunt_id
|
||||
or j.get("params", {}).get("dataset_id") in dataset_ids
|
||||
]
|
||||
active_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "running")
|
||||
queued_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "queued")
|
||||
|
||||
task_rows = await db.execute(
|
||||
select(ProcessingTask.stage, ProcessingTask.status, ProcessingTask.progress)
|
||||
.where(ProcessingTask.hunt_id == hunt_id)
|
||||
)
|
||||
tasks = task_rows.all()
|
||||
|
||||
task_total = len(tasks)
|
||||
task_done = sum(1 for _, st, _ in tasks if st in ("completed", "failed", "cancelled"))
|
||||
task_running = sum(1 for _, st, _ in tasks if st == "running")
|
||||
task_queued = sum(1 for _, st, _ in tasks if st == "queued")
|
||||
task_ratio = (task_done / task_total) if task_total > 0 else None
|
||||
|
||||
active_jobs = max(active_jobs_mem, task_running)
|
||||
queued_jobs = max(queued_jobs_mem, task_queued)
|
||||
|
||||
stage_rollup: dict[str, dict] = {}
|
||||
for stage, status, progress in tasks:
|
||||
bucket = stage_rollup.setdefault(stage, {"total": 0, "done": 0, "running": 0, "queued": 0, "progress_sum": 0.0})
|
||||
bucket["total"] += 1
|
||||
if status in ("completed", "failed", "cancelled"):
|
||||
bucket["done"] += 1
|
||||
elif status == "running":
|
||||
bucket["running"] += 1
|
||||
elif status == "queued":
|
||||
bucket["queued"] += 1
|
||||
bucket["progress_sum"] += float(progress or 0.0)
|
||||
|
||||
for stage_name, bucket in stage_rollup.items():
|
||||
total = max(1, bucket["total"])
|
||||
bucket["percent"] = round(bucket["progress_sum"] / total, 1)
|
||||
|
||||
if inventory_cache.get(hunt_id) is not None:
|
||||
network_status = "ready"
|
||||
network_ratio = 1.0
|
||||
elif inventory_cache.is_building(hunt_id):
|
||||
network_status = "building"
|
||||
network_ratio = 0.5
|
||||
else:
|
||||
network_status = "none"
|
||||
network_ratio = 0.0
|
||||
|
||||
dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
|
||||
if task_ratio is None:
|
||||
overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
|
||||
else:
|
||||
overall_ratio = min(1.0, (dataset_ratio * 0.50) + (task_ratio * 0.35) + (network_ratio * 0.15))
|
||||
progress_percent = round(overall_ratio * 100.0, 1)
|
||||
|
||||
status = "ready"
|
||||
if dataset_total == 0:
|
||||
status = "idle"
|
||||
elif progress_percent < 100:
|
||||
status = "processing"
|
||||
|
||||
stages = {
|
||||
"datasets": {
|
||||
"total": dataset_total,
|
||||
"completed": dataset_completed,
|
||||
"processing": dataset_processing,
|
||||
"errors": dataset_errors,
|
||||
"percent": round(dataset_ratio * 100.0, 1),
|
||||
},
|
||||
"network": {
|
||||
"status": network_status,
|
||||
"percent": round(network_ratio * 100.0, 1),
|
||||
},
|
||||
"jobs": {
|
||||
"active": active_jobs,
|
||||
"queued": queued_jobs,
|
||||
"total_seen": len(relevant_jobs),
|
||||
"task_total": task_total,
|
||||
"task_done": task_done,
|
||||
"task_percent": round((task_ratio or 0.0) * 100.0, 1) if task_total else None,
|
||||
},
|
||||
"task_stages": stage_rollup,
|
||||
}
|
||||
|
||||
return HuntProgressResponse(
|
||||
hunt_id=hunt_id,
|
||||
status=status,
|
||||
progress_percent=progress_percent,
|
||||
dataset_total=dataset_total,
|
||||
dataset_completed=dataset_completed,
|
||||
dataset_processing=dataset_processing,
|
||||
dataset_errors=dataset_errors,
|
||||
active_jobs=active_jobs,
|
||||
queued_jobs=queued_jobs,
|
||||
network_status=network_status,
|
||||
stages=stages,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
|
||||
async def update_hunt(
|
||||
hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)
|
||||
|
||||
@@ -1,25 +1,21 @@
|
||||
"""API routes for AUP keyword themes, keyword CRUD, and scanning."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import KeywordTheme, Keyword
|
||||
from app.services.scanner import KeywordScanner
|
||||
from app.services.scanner import KeywordScanner, keyword_scan_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/keywords", tags=["keywords"])
|
||||
|
||||
|
||||
# ── Pydantic schemas ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ThemeCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=128)
|
||||
color: str = Field(default="#9e9e9e", max_length=16)
|
||||
@@ -67,23 +63,27 @@ class KeywordBulkCreate(BaseModel):
|
||||
|
||||
|
||||
class ScanRequest(BaseModel):
|
||||
dataset_ids: list[str] | None = None # None → all datasets
|
||||
theme_ids: list[str] | None = None # None → all enabled themes
|
||||
scan_hunts: bool = True
|
||||
scan_annotations: bool = True
|
||||
scan_messages: bool = True
|
||||
dataset_ids: list[str] | None = None
|
||||
theme_ids: list[str] | None = None
|
||||
scan_hunts: bool = False
|
||||
scan_annotations: bool = False
|
||||
scan_messages: bool = False
|
||||
prefer_cache: bool = True
|
||||
force_rescan: bool = False
|
||||
|
||||
|
||||
class ScanHit(BaseModel):
|
||||
theme_name: str
|
||||
theme_color: str
|
||||
keyword: str
|
||||
source_type: str # dataset_row | hunt | annotation | message
|
||||
source_type: str
|
||||
source_id: str | int
|
||||
field: str
|
||||
matched_value: str
|
||||
row_index: int | None = None
|
||||
dataset_name: str | None = None
|
||||
hostname: str | None = None
|
||||
username: str | None = None
|
||||
|
||||
|
||||
class ScanResponse(BaseModel):
|
||||
@@ -92,9 +92,9 @@ class ScanResponse(BaseModel):
|
||||
themes_scanned: int
|
||||
keywords_scanned: int
|
||||
rows_scanned: int
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
cache_used: bool = False
|
||||
cache_status: str = "miss"
|
||||
cached_at: str | None = None
|
||||
|
||||
|
||||
def _theme_to_out(t: KeywordTheme) -> ThemeOut:
|
||||
@@ -119,49 +119,58 @@ def _theme_to_out(t: KeywordTheme) -> ThemeOut:
|
||||
)
|
||||
|
||||
|
||||
# ── Theme CRUD ────────────────────────────────────────────────────────
|
||||
def _merge_cached_results(entries: list[dict], allowed_theme_names: set[str] | None = None) -> dict:
|
||||
hits: list[dict] = []
|
||||
total_rows = 0
|
||||
cached_at: str | None = None
|
||||
|
||||
for entry in entries:
|
||||
result = entry["result"]
|
||||
total_rows += int(result.get("rows_scanned", 0) or 0)
|
||||
if entry.get("built_at"):
|
||||
if not cached_at or entry["built_at"] > cached_at:
|
||||
cached_at = entry["built_at"]
|
||||
for h in result.get("hits", []):
|
||||
if allowed_theme_names is not None and h.get("theme_name") not in allowed_theme_names:
|
||||
continue
|
||||
hits.append(h)
|
||||
|
||||
return {
|
||||
"total_hits": len(hits),
|
||||
"hits": hits,
|
||||
"rows_scanned": total_rows,
|
||||
"cached_at": cached_at,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/themes", response_model=ThemeListResponse)
|
||||
async def list_themes(db: AsyncSession = Depends(get_db)):
|
||||
"""List all keyword themes with their keywords."""
|
||||
result = await db.execute(
|
||||
select(KeywordTheme).order_by(KeywordTheme.name)
|
||||
)
|
||||
result = await db.execute(select(KeywordTheme).order_by(KeywordTheme.name))
|
||||
themes = result.scalars().all()
|
||||
return ThemeListResponse(
|
||||
themes=[_theme_to_out(t) for t in themes],
|
||||
total=len(themes),
|
||||
)
|
||||
return ThemeListResponse(themes=[_theme_to_out(t) for t in themes], total=len(themes))
|
||||
|
||||
|
||||
@router.post("/themes", response_model=ThemeOut, status_code=201)
|
||||
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""Create a new keyword theme."""
|
||||
exists = await db.scalar(
|
||||
select(KeywordTheme.id).where(KeywordTheme.name == body.name)
|
||||
)
|
||||
exists = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == body.name))
|
||||
if exists:
|
||||
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
||||
theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
|
||||
db.add(theme)
|
||||
await db.flush()
|
||||
await db.refresh(theme)
|
||||
keyword_scan_cache.clear()
|
||||
return _theme_to_out(theme)
|
||||
|
||||
|
||||
@router.put("/themes/{theme_id}", response_model=ThemeOut)
|
||||
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
|
||||
"""Update theme name, color, or enabled status."""
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
if body.name is not None:
|
||||
# check uniqueness
|
||||
dup = await db.scalar(
|
||||
select(KeywordTheme.id).where(
|
||||
KeywordTheme.name == body.name, KeywordTheme.id != theme_id
|
||||
)
|
||||
select(KeywordTheme.id).where(KeywordTheme.name == body.name, KeywordTheme.id != theme_id)
|
||||
)
|
||||
if dup:
|
||||
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
||||
@@ -172,24 +181,21 @@ async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depe
|
||||
theme.enabled = body.enabled
|
||||
await db.flush()
|
||||
await db.refresh(theme)
|
||||
keyword_scan_cache.clear()
|
||||
return _theme_to_out(theme)
|
||||
|
||||
|
||||
@router.delete("/themes/{theme_id}", status_code=204)
|
||||
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete a theme and all its keywords."""
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
await db.delete(theme)
|
||||
|
||||
|
||||
# ── Keyword CRUD ──────────────────────────────────────────────────────
|
||||
keyword_scan_cache.clear()
|
||||
|
||||
|
||||
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
|
||||
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""Add a single keyword to a theme."""
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
@@ -197,6 +203,7 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
|
||||
db.add(kw)
|
||||
await db.flush()
|
||||
await db.refresh(kw)
|
||||
keyword_scan_cache.clear()
|
||||
return KeywordOut(
|
||||
id=kw.id, theme_id=kw.theme_id, value=kw.value,
|
||||
is_regex=kw.is_regex, created_at=kw.created_at.isoformat(),
|
||||
@@ -205,7 +212,6 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
|
||||
|
||||
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
|
||||
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""Add multiple keywords to a theme at once."""
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
@@ -217,25 +223,88 @@ async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSes
|
||||
db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex))
|
||||
added += 1
|
||||
await db.flush()
|
||||
keyword_scan_cache.clear()
|
||||
return {"added": added, "theme_id": theme_id}
|
||||
|
||||
|
||||
@router.delete("/keywords/{keyword_id}", status_code=204)
|
||||
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete a single keyword."""
|
||||
kw = await db.get(Keyword, keyword_id)
|
||||
if not kw:
|
||||
raise HTTPException(404, "Keyword not found")
|
||||
await db.delete(kw)
|
||||
|
||||
|
||||
# ── Scan endpoints ────────────────────────────────────────────────────
|
||||
keyword_scan_cache.clear()
|
||||
|
||||
|
||||
@router.post("/scan", response_model=ScanResponse)
|
||||
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""Run AUP keyword scan across selected data sources."""
|
||||
scanner = KeywordScanner(db)
|
||||
|
||||
if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
|
||||
return {
|
||||
"total_hits": 0,
|
||||
"hits": [],
|
||||
"themes_scanned": 0,
|
||||
"keywords_scanned": 0,
|
||||
"rows_scanned": 0,
|
||||
"cache_used": False,
|
||||
"cache_status": "miss",
|
||||
"cached_at": None,
|
||||
}
|
||||
|
||||
can_use_cache = (
|
||||
body.prefer_cache
|
||||
and not body.force_rescan
|
||||
and bool(body.dataset_ids)
|
||||
and not body.scan_hunts
|
||||
and not body.scan_annotations
|
||||
and not body.scan_messages
|
||||
)
|
||||
|
||||
if can_use_cache:
|
||||
themes = await scanner._load_themes(body.theme_ids)
|
||||
allowed_theme_names = {t.name for t in themes}
|
||||
keywords_scanned = sum(len(theme.keywords) for theme in themes)
|
||||
|
||||
cached_entries: list[dict] = []
|
||||
missing: list[str] = []
|
||||
for dataset_id in (body.dataset_ids or []):
|
||||
entry = keyword_scan_cache.get(dataset_id)
|
||||
if not entry:
|
||||
missing.append(dataset_id)
|
||||
continue
|
||||
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
|
||||
|
||||
if not missing and cached_entries:
|
||||
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
|
||||
return {
|
||||
"total_hits": merged["total_hits"],
|
||||
"hits": merged["hits"],
|
||||
"themes_scanned": len(themes),
|
||||
"keywords_scanned": keywords_scanned,
|
||||
"rows_scanned": merged["rows_scanned"],
|
||||
"cache_used": True,
|
||||
"cache_status": "hit",
|
||||
"cached_at": merged["cached_at"],
|
||||
}
|
||||
|
||||
if missing:
|
||||
partial = await scanner.scan(dataset_ids=missing, theme_ids=body.theme_ids)
|
||||
merged = _merge_cached_results(
|
||||
cached_entries + [{"result": partial, "built_at": None}],
|
||||
allowed_theme_names if body.theme_ids else None,
|
||||
)
|
||||
return {
|
||||
"total_hits": merged["total_hits"],
|
||||
"hits": merged["hits"],
|
||||
"themes_scanned": len(themes),
|
||||
"keywords_scanned": keywords_scanned,
|
||||
"rows_scanned": merged["rows_scanned"],
|
||||
"cache_used": len(cached_entries) > 0,
|
||||
"cache_status": "partial" if cached_entries else "miss",
|
||||
"cached_at": merged["cached_at"],
|
||||
}
|
||||
|
||||
result = await scanner.scan(
|
||||
dataset_ids=body.dataset_ids,
|
||||
theme_ids=body.theme_ids,
|
||||
@@ -243,7 +312,13 @@ async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
||||
scan_annotations=body.scan_annotations,
|
||||
scan_messages=body.scan_messages,
|
||||
)
|
||||
return result
|
||||
|
||||
return {
|
||||
**result,
|
||||
"cache_used": False,
|
||||
"cache_status": "miss",
|
||||
"cached_at": None,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/scan/quick", response_model=ScanResponse)
|
||||
@@ -251,7 +326,22 @@ async def quick_scan(
|
||||
dataset_id: str = Query(..., description="Dataset to scan"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Quick scan a single dataset with all enabled themes."""
|
||||
entry = keyword_scan_cache.get(dataset_id)
|
||||
if entry is not None:
|
||||
result = entry.result
|
||||
return {
|
||||
**result,
|
||||
"cache_used": True,
|
||||
"cache_status": "hit",
|
||||
"cached_at": entry.built_at,
|
||||
}
|
||||
|
||||
scanner = KeywordScanner(db)
|
||||
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||
return result
|
||||
keyword_scan_cache.put(dataset_id, result)
|
||||
return {
|
||||
**result,
|
||||
"cache_used": False,
|
||||
"cache_status": "miss",
|
||||
"cached_at": None,
|
||||
}
|
||||
|
||||
146
backend/app/api/routes/mitre.py
Normal file
146
backend/app/api/routes/mitre.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""API routes for MITRE ATT&CK coverage visualization."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import (
|
||||
TriageResult, HostProfile, Hypothesis, HuntReport, Dataset, Hunt
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/mitre", tags=["mitre"])
|
||||
|
||||
# Canonical MITRE ATT&CK tactics in kill-chain order
|
||||
TACTICS = [
|
||||
"Reconnaissance", "Resource Development", "Initial Access",
|
||||
"Execution", "Persistence", "Privilege Escalation",
|
||||
"Defense Evasion", "Credential Access", "Discovery",
|
||||
"Lateral Movement", "Collection", "Command and Control",
|
||||
"Exfiltration", "Impact",
|
||||
]
|
||||
|
||||
# Simplified technique-to-tactic mapping (top techniques)
|
||||
TECHNIQUE_TACTIC: dict[str, str] = {
|
||||
"T1059": "Execution", "T1059.001": "Execution", "T1059.003": "Execution",
|
||||
"T1059.005": "Execution", "T1059.006": "Execution", "T1059.007": "Execution",
|
||||
"T1053": "Persistence", "T1053.005": "Persistence",
|
||||
"T1547": "Persistence", "T1547.001": "Persistence",
|
||||
"T1543": "Persistence", "T1543.003": "Persistence",
|
||||
"T1078": "Privilege Escalation", "T1078.001": "Privilege Escalation",
|
||||
"T1078.002": "Privilege Escalation", "T1078.003": "Privilege Escalation",
|
||||
"T1055": "Privilege Escalation", "T1055.001": "Privilege Escalation",
|
||||
"T1548": "Privilege Escalation", "T1548.002": "Privilege Escalation",
|
||||
"T1070": "Defense Evasion", "T1070.001": "Defense Evasion",
|
||||
"T1070.004": "Defense Evasion",
|
||||
"T1036": "Defense Evasion", "T1036.005": "Defense Evasion",
|
||||
"T1027": "Defense Evasion", "T1140": "Defense Evasion",
|
||||
"T1218": "Defense Evasion", "T1218.011": "Defense Evasion",
|
||||
"T1003": "Credential Access", "T1003.001": "Credential Access",
|
||||
"T1110": "Credential Access", "T1558": "Credential Access",
|
||||
"T1087": "Discovery", "T1087.001": "Discovery", "T1087.002": "Discovery",
|
||||
"T1082": "Discovery", "T1083": "Discovery", "T1057": "Discovery",
|
||||
"T1018": "Discovery", "T1049": "Discovery", "T1016": "Discovery",
|
||||
"T1021": "Lateral Movement", "T1021.001": "Lateral Movement",
|
||||
"T1021.002": "Lateral Movement", "T1021.006": "Lateral Movement",
|
||||
"T1570": "Lateral Movement",
|
||||
"T1560": "Collection", "T1074": "Collection", "T1005": "Collection",
|
||||
"T1071": "Command and Control", "T1071.001": "Command and Control",
|
||||
"T1105": "Command and Control", "T1572": "Command and Control",
|
||||
"T1095": "Command and Control",
|
||||
"T1048": "Exfiltration", "T1041": "Exfiltration",
|
||||
"T1486": "Impact", "T1490": "Impact", "T1489": "Impact",
|
||||
"T1566": "Initial Access", "T1566.001": "Initial Access",
|
||||
"T1566.002": "Initial Access",
|
||||
"T1190": "Initial Access", "T1133": "Initial Access",
|
||||
"T1195": "Initial Access", "T1195.002": "Initial Access",
|
||||
}
|
||||
|
||||
|
||||
def _get_tactic(technique_id: str) -> str:
|
||||
"""Map a technique ID to its tactic."""
|
||||
tech = technique_id.strip().upper()
|
||||
if tech in TECHNIQUE_TACTIC:
|
||||
return TECHNIQUE_TACTIC[tech]
|
||||
# Try parent technique
|
||||
if "." in tech:
|
||||
parent = tech.split(".")[0]
|
||||
if parent in TECHNIQUE_TACTIC:
|
||||
return TECHNIQUE_TACTIC[parent]
|
||||
return "Unknown"
|
||||
|
||||
|
||||
@router.get("/coverage")
|
||||
async def get_mitre_coverage(
|
||||
hunt_id: str | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Aggregate all MITRE techniques from triage, host profiles, hypotheses, and reports."""
|
||||
|
||||
techniques: dict[str, dict] = {}
|
||||
|
||||
# Collect from triage results
|
||||
triage_q = select(TriageResult)
|
||||
if hunt_id:
|
||||
triage_q = triage_q.join(Dataset).where(Dataset.hunt_id == hunt_id)
|
||||
result = await db.execute(triage_q.limit(500))
|
||||
for t in result.scalars().all():
|
||||
for tech in (t.mitre_techniques or []):
|
||||
if tech not in techniques:
|
||||
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
|
||||
techniques[tech]["count"] += 1
|
||||
techniques[tech]["sources"].append({"type": "triage", "risk_score": t.risk_score})
|
||||
|
||||
# Collect from host profiles
|
||||
profile_q = select(HostProfile)
|
||||
if hunt_id:
|
||||
profile_q = profile_q.where(HostProfile.hunt_id == hunt_id)
|
||||
result = await db.execute(profile_q.limit(200))
|
||||
for p in result.scalars().all():
|
||||
for tech in (p.mitre_techniques or []):
|
||||
if tech not in techniques:
|
||||
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
|
||||
techniques[tech]["count"] += 1
|
||||
techniques[tech]["sources"].append({"type": "host_profile", "hostname": p.hostname})
|
||||
|
||||
# Collect from hypotheses
|
||||
hyp_q = select(Hypothesis)
|
||||
if hunt_id:
|
||||
hyp_q = hyp_q.where(Hypothesis.hunt_id == hunt_id)
|
||||
result = await db.execute(hyp_q.limit(200))
|
||||
for h in result.scalars().all():
|
||||
tech = h.mitre_technique
|
||||
if tech:
|
||||
if tech not in techniques:
|
||||
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
|
||||
techniques[tech]["count"] += 1
|
||||
techniques[tech]["sources"].append({"type": "hypothesis", "title": h.title})
|
||||
|
||||
# Build tactic-grouped response
|
||||
tactic_groups: dict[str, list] = {t: [] for t in TACTICS}
|
||||
tactic_groups["Unknown"] = []
|
||||
for tech in techniques.values():
|
||||
tactic = tech["tactic"]
|
||||
if tactic not in tactic_groups:
|
||||
tactic_groups[tactic] = []
|
||||
tactic_groups[tactic].append(tech)
|
||||
|
||||
total_techniques = len(techniques)
|
||||
total_detections = sum(t["count"] for t in techniques.values())
|
||||
|
||||
return {
|
||||
"tactics": TACTICS,
|
||||
"technique_count": total_techniques,
|
||||
"detection_count": total_detections,
|
||||
"tactic_coverage": {
|
||||
t: {"techniques": techs, "count": len(techs)}
|
||||
for t, techs in tactic_groups.items()
|
||||
if techs
|
||||
},
|
||||
"all_techniques": list(techniques.values()),
|
||||
}
|
||||
@@ -1,19 +1,193 @@
|
||||
<<<<<<< HEAD
|
||||
"""Network topology API - host inventory endpoint with background caching."""
|
||||
=======
|
||||
"""API routes for Network Picture — deduplicated host inventory."""
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
<<<<<<< HEAD
|
||||
from fastapi.responses import JSONResponse
|
||||
=======
|
||||
from pydantic import BaseModel, Field
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
<<<<<<< HEAD
|
||||
from app.services.host_inventory import build_host_inventory, inventory_cache
|
||||
from app.services.job_queue import job_queue, JobType
|
||||
=======
|
||||
from app.services.network_inventory import build_network_picture
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/network", tags=["network"])
|
||||
|
||||
|
||||
<<<<<<< HEAD
|
||||
@router.get("/host-inventory")
|
||||
async def get_host_inventory(
|
||||
hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
|
||||
force: bool = Query(False, description="Force rebuild, ignoring cache"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Return a deduplicated host inventory for the hunt.
|
||||
|
||||
Returns instantly from cache if available (pre-built after upload or on startup).
|
||||
If cache is cold, triggers a background build and returns 202 so the
|
||||
frontend can poll /inventory-status and re-request when ready.
|
||||
"""
|
||||
# Force rebuild: invalidate cache, queue background job, return 202
|
||||
if force:
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "building", "message": "Rebuild queued"},
|
||||
)
|
||||
|
||||
# Try cache first
|
||||
cached = inventory_cache.get(hunt_id)
|
||||
if cached is not None:
|
||||
logger.info(f"Serving cached host inventory for {hunt_id}")
|
||||
return cached
|
||||
|
||||
# Cache miss: trigger background build instead of blocking for 90+ seconds
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
logger.info(f"Cache miss for {hunt_id}, triggering background build")
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "building", "message": "Inventory is being built in the background"},
|
||||
)
|
||||
|
||||
|
||||
def _build_summary(inv: dict, top_n: int = 20) -> dict:
|
||||
hosts = inv.get("hosts", [])
|
||||
conns = inv.get("connections", [])
|
||||
top_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:top_n]
|
||||
top_edges = sorted(conns, key=lambda c: c.get("count", 0), reverse=True)[:top_n]
|
||||
return {
|
||||
"stats": inv.get("stats", {}),
|
||||
"top_hosts": [
|
||||
{
|
||||
"id": h.get("id"),
|
||||
"hostname": h.get("hostname"),
|
||||
"row_count": h.get("row_count", 0),
|
||||
"ip_count": len(h.get("ips", [])),
|
||||
"user_count": len(h.get("users", [])),
|
||||
}
|
||||
for h in top_hosts
|
||||
],
|
||||
"top_edges": top_edges,
|
||||
}
|
||||
|
||||
|
||||
def _build_subgraph(inv: dict, node_id: str | None, max_hosts: int, max_edges: int) -> dict:
|
||||
hosts = inv.get("hosts", [])
|
||||
conns = inv.get("connections", [])
|
||||
|
||||
max_hosts = max(1, min(max_hosts, settings.NETWORK_SUBGRAPH_MAX_HOSTS))
|
||||
max_edges = max(1, min(max_edges, settings.NETWORK_SUBGRAPH_MAX_EDGES))
|
||||
|
||||
if node_id:
|
||||
rel_edges = [c for c in conns if c.get("source") == node_id or c.get("target") == node_id]
|
||||
rel_edges = sorted(rel_edges, key=lambda c: c.get("count", 0), reverse=True)[:max_edges]
|
||||
ids = {node_id}
|
||||
for c in rel_edges:
|
||||
ids.add(c.get("source"))
|
||||
ids.add(c.get("target"))
|
||||
rel_hosts = [h for h in hosts if h.get("id") in ids][:max_hosts]
|
||||
else:
|
||||
rel_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:max_hosts]
|
||||
allowed = {h.get("id") for h in rel_hosts}
|
||||
rel_edges = [
|
||||
c for c in sorted(conns, key=lambda c: c.get("count", 0), reverse=True)
|
||||
if c.get("source") in allowed and c.get("target") in allowed
|
||||
][:max_edges]
|
||||
|
||||
return {
|
||||
"hosts": rel_hosts,
|
||||
"connections": rel_edges,
|
||||
"stats": {
|
||||
**inv.get("stats", {}),
|
||||
"subgraph_hosts": len(rel_hosts),
|
||||
"subgraph_connections": len(rel_edges),
|
||||
"truncated": len(rel_hosts) < len(hosts) or len(rel_edges) < len(conns),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_inventory_summary(
|
||||
hunt_id: str = Query(..., description="Hunt ID"),
|
||||
top_n: int = Query(20, ge=1, le=200),
|
||||
):
|
||||
"""Return a lightweight summary view for large hunts."""
|
||||
cached = inventory_cache.get(hunt_id)
|
||||
if cached is None:
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "deferred", "message": "Queue busy, retry shortly"},
|
||||
)
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return JSONResponse(status_code=202, content={"status": "building"})
|
||||
return _build_summary(cached, top_n=top_n)
|
||||
|
||||
|
||||
@router.get("/subgraph")
|
||||
async def get_inventory_subgraph(
|
||||
hunt_id: str = Query(..., description="Hunt ID"),
|
||||
node_id: str | None = Query(None, description="Optional focal node"),
|
||||
max_hosts: int = Query(200, ge=1, le=5000),
|
||||
max_edges: int = Query(1500, ge=1, le=20000),
|
||||
):
|
||||
"""Return a bounded subgraph for scale-safe rendering."""
|
||||
cached = inventory_cache.get(hunt_id)
|
||||
if cached is None:
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "deferred", "message": "Queue busy, retry shortly"},
|
||||
)
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return JSONResponse(status_code=202, content={"status": "building"})
|
||||
return _build_subgraph(cached, node_id=node_id, max_hosts=max_hosts, max_edges=max_edges)
|
||||
|
||||
|
||||
@router.get("/inventory-status")
|
||||
async def get_inventory_status(
|
||||
hunt_id: str = Query(..., description="Hunt ID to check"),
|
||||
):
|
||||
"""Check whether pre-computed host inventory is ready for a hunt.
|
||||
|
||||
Returns: { status: "ready" | "building" | "none" }
|
||||
"""
|
||||
return {"hunt_id": hunt_id, "status": inventory_cache.status(hunt_id)}
|
||||
|
||||
|
||||
@router.post("/rebuild-inventory")
|
||||
async def trigger_rebuild(
|
||||
hunt_id: str = Query(..., description="Hunt to rebuild inventory for"),
|
||||
):
|
||||
"""Trigger a background rebuild of the host inventory cache."""
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return {"job_id": job.id, "status": "queued"}
|
||||
=======
|
||||
# ── Response models ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -67,3 +241,4 @@ async def get_network_picture(
|
||||
|
||||
result = await build_network_picture(db, hunt_id)
|
||||
return result
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
217
backend/app/api/routes/playbooks.py
Normal file
217
backend/app/api/routes/playbooks.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""API routes for investigation playbooks."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Playbook, PlaybookStep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/playbooks", tags=["playbooks"])
|
||||
|
||||
|
||||
# -- Request / Response schemas ---
|
||||
|
||||
class StepCreate(BaseModel):
|
||||
title: str
|
||||
description: str | None = None
|
||||
step_type: str = "manual"
|
||||
target_route: str | None = None
|
||||
|
||||
class PlaybookCreate(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
hunt_id: str | None = None
|
||||
is_template: bool = False
|
||||
steps: list[StepCreate] = []
|
||||
|
||||
class PlaybookUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
status: str | None = None
|
||||
|
||||
class StepUpdate(BaseModel):
|
||||
is_completed: bool | None = None
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
# -- Default investigation templates ---
|
||||
|
||||
DEFAULT_TEMPLATES = [
|
||||
{
|
||||
"name": "Standard Threat Hunt",
|
||||
"description": "Step-by-step investigation workflow for a typical threat hunting engagement.",
|
||||
"steps": [
|
||||
{"title": "Upload Artifacts", "description": "Import CSV exports from Velociraptor or other tools", "step_type": "upload", "target_route": "/upload"},
|
||||
{"title": "Create Hunt", "description": "Create a new hunt and associate uploaded datasets", "step_type": "action", "target_route": "/hunts"},
|
||||
{"title": "AUP Keyword Scan", "description": "Run AUP keyword scanner for policy violations", "step_type": "analysis", "target_route": "/aup"},
|
||||
{"title": "Auto-Triage", "description": "Trigger AI triage on all datasets", "step_type": "analysis", "target_route": "/analysis"},
|
||||
{"title": "Review Triage Results", "description": "Review flagged rows and risk scores", "step_type": "review", "target_route": "/analysis"},
|
||||
{"title": "Enrich IOCs", "description": "Enrich flagged IPs, hashes, and domains via external sources", "step_type": "analysis", "target_route": "/enrichment"},
|
||||
{"title": "Host Profiling", "description": "Generate deep host profiles for suspicious hosts", "step_type": "analysis", "target_route": "/analysis"},
|
||||
{"title": "Cross-Hunt Correlation", "description": "Identify shared IOCs and patterns across hunts", "step_type": "analysis", "target_route": "/correlation"},
|
||||
{"title": "Document Hypotheses", "description": "Record investigation hypotheses with MITRE mappings", "step_type": "action", "target_route": "/hypotheses"},
|
||||
{"title": "Generate Report", "description": "Generate final AI-assisted hunt report", "step_type": "action", "target_route": "/analysis"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Incident Response Triage",
|
||||
"description": "Fast-track workflow for active incident response.",
|
||||
"steps": [
|
||||
{"title": "Upload Artifacts", "description": "Import forensic data from affected hosts", "step_type": "upload", "target_route": "/upload"},
|
||||
{"title": "Auto-Triage", "description": "Immediate AI triage for threat indicators", "step_type": "analysis", "target_route": "/analysis"},
|
||||
{"title": "IOC Extraction", "description": "Extract all IOCs from flagged data", "step_type": "analysis", "target_route": "/analysis"},
|
||||
{"title": "Enrich Critical IOCs", "description": "Priority enrichment of high-risk indicators", "step_type": "analysis", "target_route": "/enrichment"},
|
||||
{"title": "Network Map", "description": "Visualize host connections and lateral movement", "step_type": "review", "target_route": "/network"},
|
||||
{"title": "Generate Situation Report", "description": "Create executive summary for incident command", "step_type": "action", "target_route": "/analysis"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# -- Routes ---
|
||||
|
||||
@router.get("")
|
||||
async def list_playbooks(
|
||||
include_templates: bool = True,
|
||||
hunt_id: str | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
q = select(Playbook)
|
||||
if hunt_id:
|
||||
q = q.where(Playbook.hunt_id == hunt_id)
|
||||
if not include_templates:
|
||||
q = q.where(Playbook.is_template == False)
|
||||
q = q.order_by(Playbook.created_at.desc())
|
||||
result = await db.execute(q.limit(100))
|
||||
playbooks = result.scalars().all()
|
||||
|
||||
return {"playbooks": [
|
||||
{
|
||||
"id": p.id, "name": p.name, "description": p.description,
|
||||
"is_template": p.is_template, "hunt_id": p.hunt_id,
|
||||
"status": p.status,
|
||||
"total_steps": len(p.steps),
|
||||
"completed_steps": sum(1 for s in p.steps if s.is_completed),
|
||||
"created_at": p.created_at.isoformat() if p.created_at else None,
|
||||
}
|
||||
for p in playbooks
|
||||
]}
|
||||
|
||||
|
||||
@router.get("/templates")
|
||||
async def get_templates():
|
||||
"""Return built-in investigation templates."""
|
||||
return {"templates": DEFAULT_TEMPLATES}
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
async def create_playbook(body: PlaybookCreate, db: AsyncSession = Depends(get_db)):
|
||||
pb = Playbook(
|
||||
name=body.name,
|
||||
description=body.description,
|
||||
hunt_id=body.hunt_id,
|
||||
is_template=body.is_template,
|
||||
)
|
||||
db.add(pb)
|
||||
await db.flush()
|
||||
|
||||
created_steps = []
|
||||
for i, step in enumerate(body.steps):
|
||||
s = PlaybookStep(
|
||||
playbook_id=pb.id,
|
||||
order_index=i,
|
||||
title=step.title,
|
||||
description=step.description,
|
||||
step_type=step.step_type,
|
||||
target_route=step.target_route,
|
||||
)
|
||||
db.add(s)
|
||||
created_steps.append(s)
|
||||
|
||||
await db.flush()
|
||||
|
||||
return {
|
||||
"id": pb.id, "name": pb.name, "description": pb.description,
|
||||
"hunt_id": pb.hunt_id, "is_template": pb.is_template,
|
||||
"steps": [
|
||||
{"id": s.id, "order_index": s.order_index, "title": s.title,
|
||||
"description": s.description, "step_type": s.step_type,
|
||||
"target_route": s.target_route, "is_completed": False}
|
||||
for s in created_steps
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{playbook_id}")
|
||||
async def get_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
|
||||
pb = result.scalar_one_or_none()
|
||||
if not pb:
|
||||
raise HTTPException(status_code=404, detail="Playbook not found")
|
||||
|
||||
return {
|
||||
"id": pb.id, "name": pb.name, "description": pb.description,
|
||||
"is_template": pb.is_template, "hunt_id": pb.hunt_id,
|
||||
"status": pb.status,
|
||||
"created_at": pb.created_at.isoformat() if pb.created_at else None,
|
||||
"steps": [
|
||||
{
|
||||
"id": s.id, "order_index": s.order_index, "title": s.title,
|
||||
"description": s.description, "step_type": s.step_type,
|
||||
"target_route": s.target_route,
|
||||
"is_completed": s.is_completed,
|
||||
"completed_at": s.completed_at.isoformat() if s.completed_at else None,
|
||||
"notes": s.notes,
|
||||
}
|
||||
for s in sorted(pb.steps, key=lambda x: x.order_index)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{playbook_id}")
|
||||
async def update_playbook(playbook_id: str, body: PlaybookUpdate, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
|
||||
pb = result.scalar_one_or_none()
|
||||
if not pb:
|
||||
raise HTTPException(status_code=404, detail="Playbook not found")
|
||||
|
||||
if body.name is not None:
|
||||
pb.name = body.name
|
||||
if body.description is not None:
|
||||
pb.description = body.description
|
||||
if body.status is not None:
|
||||
pb.status = body.status
|
||||
return {"status": "updated"}
|
||||
|
||||
|
||||
@router.delete("/{playbook_id}")
|
||||
async def delete_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
|
||||
pb = result.scalar_one_or_none()
|
||||
if not pb:
|
||||
raise HTTPException(status_code=404, detail="Playbook not found")
|
||||
await db.delete(pb)
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.put("/steps/{step_id}")
|
||||
async def update_step(step_id: int, body: StepUpdate, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(PlaybookStep).where(PlaybookStep.id == step_id))
|
||||
step = result.scalar_one_or_none()
|
||||
if not step:
|
||||
raise HTTPException(status_code=404, detail="Step not found")
|
||||
|
||||
if body.is_completed is not None:
|
||||
step.is_completed = body.is_completed
|
||||
step.completed_at = datetime.now(timezone.utc) if body.is_completed else None
|
||||
if body.notes is not None:
|
||||
step.notes = body.notes
|
||||
return {"status": "updated", "is_completed": step.is_completed}
|
||||
|
||||
164
backend/app/api/routes/saved_searches.py
Normal file
164
backend/app/api/routes/saved_searches.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""API routes for saved searches and bookmarked queries."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import SavedSearch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/searches", tags=["saved-searches"])
|
||||
|
||||
|
||||
class SearchCreate(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
search_type: str # "nlp_query", "ioc_search", "keyword_scan", "correlation"
|
||||
query_params: dict
|
||||
threshold: float | None = None
|
||||
|
||||
|
||||
class SearchUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
query_params: dict | None = None
|
||||
threshold: float | None = None
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_searches(
|
||||
search_type: str | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
q = select(SavedSearch).order_by(SavedSearch.created_at.desc())
|
||||
if search_type:
|
||||
q = q.where(SavedSearch.search_type == search_type)
|
||||
result = await db.execute(q.limit(100))
|
||||
searches = result.scalars().all()
|
||||
return {"searches": [
|
||||
{
|
||||
"id": s.id, "name": s.name, "description": s.description,
|
||||
"search_type": s.search_type, "query_params": s.query_params,
|
||||
"threshold": s.threshold,
|
||||
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
|
||||
"last_result_count": s.last_result_count,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
for s in searches
|
||||
]}
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
async def create_search(body: SearchCreate, db: AsyncSession = Depends(get_db)):
|
||||
s = SavedSearch(
|
||||
name=body.name,
|
||||
description=body.description,
|
||||
search_type=body.search_type,
|
||||
query_params=body.query_params,
|
||||
threshold=body.threshold,
|
||||
)
|
||||
db.add(s)
|
||||
await db.flush()
|
||||
return {
|
||||
"id": s.id, "name": s.name, "search_type": s.search_type,
|
||||
"query_params": s.query_params, "threshold": s.threshold,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{search_id}")
|
||||
async def get_search(search_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
|
||||
s = result.scalar_one_or_none()
|
||||
if not s:
|
||||
raise HTTPException(status_code=404, detail="Saved search not found")
|
||||
return {
|
||||
"id": s.id, "name": s.name, "description": s.description,
|
||||
"search_type": s.search_type, "query_params": s.query_params,
|
||||
"threshold": s.threshold,
|
||||
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
|
||||
"last_result_count": s.last_result_count,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{search_id}")
|
||||
async def update_search(search_id: str, body: SearchUpdate, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
|
||||
s = result.scalar_one_or_none()
|
||||
if not s:
|
||||
raise HTTPException(status_code=404, detail="Saved search not found")
|
||||
if body.name is not None:
|
||||
s.name = body.name
|
||||
if body.description is not None:
|
||||
s.description = body.description
|
||||
if body.query_params is not None:
|
||||
s.query_params = body.query_params
|
||||
if body.threshold is not None:
|
||||
s.threshold = body.threshold
|
||||
return {"status": "updated"}
|
||||
|
||||
|
||||
@router.delete("/{search_id}")
|
||||
async def delete_search(search_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
|
||||
s = result.scalar_one_or_none()
|
||||
if not s:
|
||||
raise HTTPException(status_code=404, detail="Saved search not found")
|
||||
await db.delete(s)
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.post("/{search_id}/run")
|
||||
async def run_saved_search(search_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Execute a saved search and return results with delta from last run."""
|
||||
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
|
||||
s = result.scalar_one_or_none()
|
||||
if not s:
|
||||
raise HTTPException(status_code=404, detail="Saved search not found")
|
||||
|
||||
previous_count = s.last_result_count or 0
|
||||
results = []
|
||||
count = 0
|
||||
|
||||
if s.search_type == "ioc_search":
|
||||
from app.db.models import EnrichmentResult
|
||||
ioc_value = s.query_params.get("ioc_value", "")
|
||||
if ioc_value:
|
||||
q = select(EnrichmentResult).where(
|
||||
EnrichmentResult.ioc_value.contains(ioc_value)
|
||||
)
|
||||
res = await db.execute(q.limit(100))
|
||||
for er in res.scalars().all():
|
||||
results.append({
|
||||
"ioc_value": er.ioc_value, "ioc_type": er.ioc_type,
|
||||
"source": er.source, "verdict": er.verdict,
|
||||
})
|
||||
count = len(results)
|
||||
|
||||
elif s.search_type == "keyword_scan":
|
||||
from app.db.models import KeywordTheme
|
||||
res = await db.execute(select(KeywordTheme).where(KeywordTheme.enabled == True))
|
||||
themes = res.scalars().all()
|
||||
count = sum(len(t.keywords) for t in themes)
|
||||
results = [{"theme": t.name, "keyword_count": len(t.keywords)} for t in themes]
|
||||
|
||||
# Update last run metadata
|
||||
s.last_run_at = datetime.now(timezone.utc)
|
||||
s.last_result_count = count
|
||||
|
||||
delta = count - previous_count
|
||||
|
||||
return {
|
||||
"search_id": s.id, "search_name": s.name,
|
||||
"search_type": s.search_type,
|
||||
"result_count": count,
|
||||
"previous_count": previous_count,
|
||||
"delta": delta,
|
||||
"results": results[:50],
|
||||
}
|
||||
184
backend/app/api/routes/stix_export.py
Normal file
184
backend/app/api/routes/stix_export.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""STIX 2.1 export endpoint.
|
||||
|
||||
Aggregates hunt data (IOCs, techniques, host profiles, hypotheses) into a
|
||||
STIX 2.1 Bundle JSON download. No external dependencies required we
|
||||
build the JSON directly following the OASIS STIX 2.1 spec.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import Response
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import (
|
||||
Hunt, Dataset, Hypothesis, TriageResult, HostProfile,
|
||||
EnrichmentResult, HuntReport,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/export", tags=["export"])
|
||||
|
||||
STIX_SPEC_VERSION = "2.1"
|
||||
|
||||
|
||||
def _stix_id(stype: str) -> str:
|
||||
return f"{stype}--{uuid.uuid4()}"
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
|
||||
|
||||
def _build_identity(hunt_name: str) -> dict:
|
||||
return {
|
||||
"type": "identity",
|
||||
"spec_version": STIX_SPEC_VERSION,
|
||||
"id": _stix_id("identity"),
|
||||
"created": _now_iso(),
|
||||
"modified": _now_iso(),
|
||||
"name": f"ThreatHunt - {hunt_name}",
|
||||
"identity_class": "system",
|
||||
}
|
||||
|
||||
|
||||
def _ioc_to_indicator(ioc_value: str, ioc_type: str, identity_id: str, verdict: str = None) -> dict:
|
||||
pattern_map = {
|
||||
"ipv4": f"[ipv4-addr:value = '{ioc_value}']",
|
||||
"ipv6": f"[ipv6-addr:value = '{ioc_value}']",
|
||||
"domain": f"[domain-name:value = '{ioc_value}']",
|
||||
"url": f"[url:value = '{ioc_value}']",
|
||||
"hash_md5": f"[file:hashes.'MD5' = '{ioc_value}']",
|
||||
"hash_sha1": f"[file:hashes.'SHA-1' = '{ioc_value}']",
|
||||
"hash_sha256": f"[file:hashes.'SHA-256' = '{ioc_value}']",
|
||||
"email": f"[email-addr:value = '{ioc_value}']",
|
||||
}
|
||||
pattern = pattern_map.get(ioc_type, f"[artifact:payload_bin = '{ioc_value}']")
|
||||
now = _now_iso()
|
||||
return {
|
||||
"type": "indicator",
|
||||
"spec_version": STIX_SPEC_VERSION,
|
||||
"id": _stix_id("indicator"),
|
||||
"created": now,
|
||||
"modified": now,
|
||||
"name": f"{ioc_type}: {ioc_value}",
|
||||
"pattern": pattern,
|
||||
"pattern_type": "stix",
|
||||
"valid_from": now,
|
||||
"created_by_ref": identity_id,
|
||||
"labels": [verdict or "suspicious"],
|
||||
}
|
||||
|
||||
|
||||
def _technique_to_attack_pattern(technique_id: str, identity_id: str) -> dict:
|
||||
now = _now_iso()
|
||||
return {
|
||||
"type": "attack-pattern",
|
||||
"spec_version": STIX_SPEC_VERSION,
|
||||
"id": _stix_id("attack-pattern"),
|
||||
"created": now,
|
||||
"modified": now,
|
||||
"name": technique_id,
|
||||
"created_by_ref": identity_id,
|
||||
"external_references": [
|
||||
{
|
||||
"source_name": "mitre-attack",
|
||||
"external_id": technique_id,
|
||||
"url": f"https://attack.mitre.org/techniques/{technique_id.replace('.', '/')}/",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _hypothesis_to_report(hyp, identity_id: str) -> dict:
|
||||
now = _now_iso()
|
||||
return {
|
||||
"type": "report",
|
||||
"spec_version": STIX_SPEC_VERSION,
|
||||
"id": _stix_id("report"),
|
||||
"created": now,
|
||||
"modified": now,
|
||||
"name": hyp.title,
|
||||
"description": hyp.description or "",
|
||||
"published": now,
|
||||
"created_by_ref": identity_id,
|
||||
"labels": ["threat-hunt-hypothesis"],
|
||||
"object_refs": [],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/stix/{hunt_id}")
|
||||
async def export_stix(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Export hunt data as a STIX 2.1 Bundle JSON file."""
|
||||
# Fetch hunt
|
||||
hunt = (await db.execute(select(Hunt).where(Hunt.id == hunt_id))).scalar_one_or_none()
|
||||
if not hunt:
|
||||
raise HTTPException(404, "Hunt not found")
|
||||
|
||||
identity = _build_identity(hunt.name)
|
||||
objects: list[dict] = [identity]
|
||||
seen_techniques: set[str] = set()
|
||||
seen_iocs: set[str] = set()
|
||||
|
||||
# Gather IOCs from enrichment results for hunt's datasets
|
||||
datasets_q = await db.execute(select(Dataset.id).where(Dataset.hunt_id == hunt_id))
|
||||
ds_ids = [r[0] for r in datasets_q.all()]
|
||||
|
||||
if ds_ids:
|
||||
enrichments = (await db.execute(
|
||||
select(EnrichmentResult).where(EnrichmentResult.dataset_id.in_(ds_ids))
|
||||
)).scalars().all()
|
||||
for e in enrichments:
|
||||
key = f"{e.ioc_type}:{e.ioc_value}"
|
||||
if key not in seen_iocs:
|
||||
seen_iocs.add(key)
|
||||
objects.append(_ioc_to_indicator(e.ioc_value, e.ioc_type, identity["id"], e.verdict))
|
||||
|
||||
# Gather techniques from triage results
|
||||
triages = (await db.execute(
|
||||
select(TriageResult).where(TriageResult.dataset_id.in_(ds_ids))
|
||||
)).scalars().all()
|
||||
for t in triages:
|
||||
for tech in (t.mitre_techniques or []):
|
||||
tid = tech if isinstance(tech, str) else tech.get("technique_id", str(tech))
|
||||
if tid not in seen_techniques:
|
||||
seen_techniques.add(tid)
|
||||
objects.append(_technique_to_attack_pattern(tid, identity["id"]))
|
||||
|
||||
# Gather techniques from host profiles
|
||||
profiles = (await db.execute(
|
||||
select(HostProfile).where(HostProfile.hunt_id == hunt_id)
|
||||
)).scalars().all()
|
||||
for p in profiles:
|
||||
for tech in (p.mitre_techniques or []):
|
||||
tid = tech if isinstance(tech, str) else tech.get("technique_id", str(tech))
|
||||
if tid not in seen_techniques:
|
||||
seen_techniques.add(tid)
|
||||
objects.append(_technique_to_attack_pattern(tid, identity["id"]))
|
||||
|
||||
# Gather hypotheses
|
||||
hypos = (await db.execute(
|
||||
select(Hypothesis).where(Hypothesis.hunt_id == hunt_id)
|
||||
)).scalars().all()
|
||||
for h in hypos:
|
||||
objects.append(_hypothesis_to_report(h, identity["id"]))
|
||||
if h.mitre_technique and h.mitre_technique not in seen_techniques:
|
||||
seen_techniques.add(h.mitre_technique)
|
||||
objects.append(_technique_to_attack_pattern(h.mitre_technique, identity["id"]))
|
||||
|
||||
bundle = {
|
||||
"type": "bundle",
|
||||
"id": _stix_id("bundle"),
|
||||
"objects": objects,
|
||||
}
|
||||
|
||||
filename = f"threathunt-{hunt.name.replace(' ', '_')}-stix.json"
|
||||
return Response(
|
||||
content=json.dumps(bundle, indent=2),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
128
backend/app/api/routes/timeline.py
Normal file
128
backend/app/api/routes/timeline.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""API routes for forensic timeline visualization."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Dataset, DatasetRow, Hunt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/timeline", tags=["timeline"])
|
||||
|
||||
|
||||
def _parse_timestamp(val: str | None) -> str | None:
|
||||
"""Try to parse a timestamp string, return ISO format or None."""
|
||||
if not val:
|
||||
return None
|
||||
val = str(val).strip()
|
||||
if not val:
|
||||
return None
|
||||
# Try common formats
|
||||
for fmt in [
|
||||
"%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ",
|
||||
"%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M:%S", "%m/%d/%Y %H:%M:%S",
|
||||
]:
|
||||
try:
|
||||
return datetime.strptime(val, fmt).isoformat() + "Z"
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
# Columns likely to contain timestamps
|
||||
TIME_COLUMNS = {
|
||||
"timestamp", "time", "datetime", "date", "created", "modified",
|
||||
"eventtime", "event_time", "start_time", "end_time",
|
||||
"lastmodified", "last_modified", "created_at", "updated_at",
|
||||
"mtime", "atime", "ctime", "btime",
|
||||
"timecreated", "timegenerated", "sourcetime",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/hunt/{hunt_id}")
|
||||
async def get_hunt_timeline(
|
||||
hunt_id: str,
|
||||
limit: int = 2000,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Build a timeline of events across all datasets in a hunt."""
|
||||
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
|
||||
hunt = result.scalar_one_or_none()
|
||||
if not hunt:
|
||||
raise HTTPException(status_code=404, detail="Hunt not found")
|
||||
|
||||
result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
|
||||
datasets = result.scalars().all()
|
||||
if not datasets:
|
||||
return {"hunt_id": hunt_id, "events": [], "datasets": []}
|
||||
|
||||
events = []
|
||||
dataset_info = []
|
||||
|
||||
for ds in datasets:
|
||||
artifact_type = getattr(ds, "artifact_type", None) or "Unknown"
|
||||
dataset_info.append({
|
||||
"id": ds.id, "name": ds.name, "artifact_type": artifact_type,
|
||||
"row_count": ds.row_count,
|
||||
})
|
||||
|
||||
# Find time columns for this dataset
|
||||
schema = ds.column_schema or {}
|
||||
time_cols = []
|
||||
for col in (ds.normalized_columns or {}).values():
|
||||
if col.lower() in TIME_COLUMNS:
|
||||
time_cols.append(col)
|
||||
if not time_cols:
|
||||
for col in schema:
|
||||
if col.lower() in TIME_COLUMNS or "time" in col.lower() or "date" in col.lower():
|
||||
time_cols.append(col)
|
||||
if not time_cols:
|
||||
continue
|
||||
|
||||
# Fetch rows
|
||||
rows_result = await db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == ds.id)
|
||||
.order_by(DatasetRow.row_index)
|
||||
.limit(limit // max(len(datasets), 1))
|
||||
)
|
||||
for r in rows_result.scalars().all():
|
||||
data = r.normalized_data or r.data
|
||||
ts = None
|
||||
for tc in time_cols:
|
||||
ts = _parse_timestamp(data.get(tc))
|
||||
if ts:
|
||||
break
|
||||
if ts:
|
||||
hostname = data.get("hostname") or data.get("Hostname") or data.get("Fqdn") or ""
|
||||
process = data.get("process_name") or data.get("Name") or data.get("ProcessName") or ""
|
||||
summary = data.get("command_line") or data.get("CommandLine") or data.get("Details") or ""
|
||||
events.append({
|
||||
"timestamp": ts,
|
||||
"dataset_id": ds.id,
|
||||
"dataset_name": ds.name,
|
||||
"artifact_type": artifact_type,
|
||||
"row_index": r.row_index,
|
||||
"hostname": str(hostname)[:128],
|
||||
"process": str(process)[:128],
|
||||
"summary": str(summary)[:256],
|
||||
"data": {k: str(v)[:100] for k, v in list(data.items())[:8]},
|
||||
})
|
||||
|
||||
# Sort by timestamp
|
||||
events.sort(key=lambda e: e["timestamp"])
|
||||
|
||||
return {
|
||||
"hunt_id": hunt_id,
|
||||
"hunt_name": hunt.name,
|
||||
"event_count": len(events),
|
||||
"datasets": dataset_info,
|
||||
"events": events[:limit],
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Application configuration — single source of truth for all settings.
|
||||
"""Application configuration - single source of truth for all settings.
|
||||
|
||||
Loads from environment variables with sensible defaults for local dev.
|
||||
"""
|
||||
@@ -13,12 +13,12 @@ from pydantic import Field
|
||||
class AppConfig(BaseSettings):
|
||||
"""Central configuration for the entire ThreatHunt application."""
|
||||
|
||||
# ── General ────────────────────────────────────────────────────────
|
||||
# -- General --------------------------------------------------------
|
||||
APP_NAME: str = "ThreatHunt"
|
||||
APP_VERSION: str = "0.4.0"
|
||||
DEBUG: bool = Field(default=False, description="Enable debug mode")
|
||||
|
||||
# ── Database ───────────────────────────────────────────────────────
|
||||
# -- Database -------------------------------------------------------
|
||||
DATABASE_URL: str = Field(
|
||||
default="sqlite+aiosqlite:///./threathunt.db",
|
||||
description="Async SQLAlchemy database URL. "
|
||||
@@ -26,17 +26,17 @@ class AppConfig(BaseSettings):
|
||||
"postgresql+asyncpg://user:pass@host/db for production.",
|
||||
)
|
||||
|
||||
# ── CORS ───────────────────────────────────────────────────────────
|
||||
# -- CORS -----------------------------------------------------------
|
||||
ALLOWED_ORIGINS: str = Field(
|
||||
default="http://localhost:3000,http://localhost:8000",
|
||||
description="Comma-separated list of allowed CORS origins",
|
||||
)
|
||||
|
||||
# ── File uploads ───────────────────────────────────────────────────
|
||||
# -- File uploads ---------------------------------------------------
|
||||
MAX_UPLOAD_SIZE_MB: int = Field(default=500, description="Max CSV upload in MB")
|
||||
UPLOAD_DIR: str = Field(default="./uploads", description="Directory for uploaded files")
|
||||
|
||||
# ── LLM Cluster — Wile & Roadrunner ────────────────────────────────
|
||||
# -- LLM Cluster - Wile & Roadrunner --------------------------------
|
||||
OPENWEBUI_URL: str = Field(
|
||||
default="https://ai.guapo613.beer",
|
||||
description="Open WebUI cluster endpoint (OpenAI-compatible API)",
|
||||
@@ -58,7 +58,7 @@ class AppConfig(BaseSettings):
|
||||
default=11434, description="Ollama port on Roadrunner"
|
||||
)
|
||||
|
||||
# ── LLM Routing defaults ──────────────────────────────────────────
|
||||
# -- LLM Routing defaults ------------------------------------------
|
||||
DEFAULT_FAST_MODEL: str = Field(
|
||||
default="llama3.1:latest",
|
||||
description="Default model for quick chat / simple queries",
|
||||
@@ -80,18 +80,18 @@ class AppConfig(BaseSettings):
|
||||
description="Default embedding model",
|
||||
)
|
||||
|
||||
# ── Agent behaviour ───────────────────────────────────────────────
|
||||
# -- Agent behaviour ------------------------------------------------
|
||||
AGENT_MAX_TOKENS: int = Field(default=2048, description="Max tokens per agent response")
|
||||
AGENT_TEMPERATURE: float = Field(default=0.3, description="LLM temperature for guidance")
|
||||
AGENT_HISTORY_LENGTH: int = Field(default=10, description="Messages to keep in context")
|
||||
FILTER_SENSITIVE_DATA: bool = Field(default=True, description="Redact sensitive patterns")
|
||||
|
||||
# ── Enrichment API keys ───────────────────────────────────────────
|
||||
# -- Enrichment API keys --------------------------------------------
|
||||
VIRUSTOTAL_API_KEY: str = Field(default="", description="VirusTotal API key")
|
||||
ABUSEIPDB_API_KEY: str = Field(default="", description="AbuseIPDB API key")
|
||||
SHODAN_API_KEY: str = Field(default="", description="Shodan API key")
|
||||
|
||||
# ── Auth ──────────────────────────────────────────────────────────
|
||||
# -- Auth -----------------------------------------------------------
|
||||
JWT_SECRET: str = Field(
|
||||
default="CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET",
|
||||
description="Secret for JWT signing",
|
||||
@@ -99,6 +99,73 @@ class AppConfig(BaseSettings):
|
||||
JWT_ACCESS_TOKEN_MINUTES: int = Field(default=60, description="Access token lifetime")
|
||||
JWT_REFRESH_TOKEN_DAYS: int = Field(default=7, description="Refresh token lifetime")
|
||||
|
||||
# -- Triage settings ------------------------------------------------
|
||||
TRIAGE_BATCH_SIZE: int = Field(default=25, description="Rows per triage LLM batch")
|
||||
TRIAGE_MAX_SUSPICIOUS_ROWS: int = Field(
|
||||
default=200, description="Stop triage after this many suspicious rows"
|
||||
)
|
||||
TRIAGE_ESCALATION_THRESHOLD: float = Field(
|
||||
default=5.0, description="Risk score threshold for escalation counting"
|
||||
)
|
||||
|
||||
# -- Host profiler settings -----------------------------------------
|
||||
HOST_PROFILE_CONCURRENCY: int = Field(
|
||||
default=3, description="Max concurrent host profile LLM calls"
|
||||
)
|
||||
|
||||
# -- Scanner settings -----------------------------------------------
|
||||
SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch")
|
||||
SCANNER_MAX_ROWS_PER_SCAN: int = Field(
|
||||
default=120000,
|
||||
description="Global row budget for a single AUP scan request (0 = unlimited)",
|
||||
)
|
||||
|
||||
# -- Job queue settings ----------------------------------------------
|
||||
JOB_QUEUE_MAX_BACKLOG: int = Field(
|
||||
default=2000, description="Soft cap for queued background jobs"
|
||||
)
|
||||
JOB_QUEUE_RETAIN_COMPLETED: int = Field(
|
||||
default=3000, description="Maximum completed/failed jobs to retain in memory"
|
||||
)
|
||||
JOB_QUEUE_CLEANUP_INTERVAL_SECONDS: int = Field(
|
||||
default=60, description="How often to run in-memory job cleanup"
|
||||
)
|
||||
JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(
|
||||
default=3600, description="Age threshold for in-memory completed job cleanup"
|
||||
)
|
||||
|
||||
# -- Startup throttling ------------------------------------------------
|
||||
STARTUP_WARMUP_MAX_HUNTS: int = Field(
|
||||
default=5, description="Max hunts to warm inventory cache for at startup"
|
||||
)
|
||||
STARTUP_REPROCESS_MAX_DATASETS: int = Field(
|
||||
default=25, description="Max unprocessed datasets to enqueue at startup"
|
||||
)
|
||||
STARTUP_RECONCILE_STALE_TASKS: bool = Field(
|
||||
default=True,
|
||||
description="Mark stale queued/running processing tasks as failed on startup",
|
||||
)
|
||||
|
||||
# -- Network API scale guards -----------------------------------------
|
||||
NETWORK_SUBGRAPH_MAX_HOSTS: int = Field(
|
||||
default=400, description="Hard cap for hosts returned by network subgraph endpoint"
|
||||
)
|
||||
NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
|
||||
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
|
||||
)
|
||||
NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
|
||||
default=5000,
|
||||
description="Row budget per dataset when building host inventory (0 = unlimited)",
|
||||
)
|
||||
NETWORK_INVENTORY_MAX_TOTAL_ROWS: int = Field(
|
||||
default=120000,
|
||||
description="Global row budget across all datasets for host inventory build (0 = unlimited)",
|
||||
)
|
||||
NETWORK_INVENTORY_MAX_CONNECTIONS: int = Field(
|
||||
default=120000,
|
||||
description="Max unique connection tuples retained during host inventory build",
|
||||
)
|
||||
|
||||
model_config = {"env_prefix": "TH_", "env_file": ".env", "extra": "ignore"}
|
||||
|
||||
@property
|
||||
@@ -119,3 +186,4 @@ class AppConfig(BaseSettings):
|
||||
|
||||
|
||||
settings = AppConfig()
|
||||
|
||||
|
||||
@@ -21,9 +21,14 @@ _engine_kwargs: dict = dict(
|
||||
)
|
||||
|
||||
if _is_sqlite:
|
||||
_engine_kwargs["connect_args"] = {"timeout": 30}
|
||||
_engine_kwargs["pool_size"] = 1
|
||||
_engine_kwargs["max_overflow"] = 0
|
||||
_engine_kwargs["connect_args"] = {"timeout": 60, "check_same_thread": False}
|
||||
# NullPool: each session gets its own connection.
|
||||
# Combined with WAL mode, this allows concurrent reads while a write is in progress.
|
||||
from sqlalchemy.pool import NullPool
|
||||
_engine_kwargs["poolclass"] = NullPool
|
||||
else:
|
||||
_engine_kwargs["pool_size"] = 5
|
||||
_engine_kwargs["max_overflow"] = 10
|
||||
|
||||
engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)
|
||||
|
||||
@@ -34,7 +39,7 @@ def _set_sqlite_pragmas(dbapi_conn, connection_record):
|
||||
if _is_sqlite:
|
||||
cursor = dbapi_conn.cursor()
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute("PRAGMA busy_timeout=5000")
|
||||
cursor.execute("PRAGMA busy_timeout=30000")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.close()
|
||||
|
||||
@@ -46,6 +51,10 @@ async_session_factory = async_sessionmaker(
|
||||
)
|
||||
|
||||
|
||||
# Alias expected by other modules
|
||||
async_session = async_session_factory
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all ORM models."""
|
||||
pass
|
||||
@@ -83,5 +92,5 @@ async def init_db() -> None:
|
||||
|
||||
|
||||
async def dispose_db() -> None:
|
||||
"""Dispose of the engine connection pool."""
|
||||
"""Dispose of the engine on shutdown."""
|
||||
await engine.dispose()
|
||||
@@ -1,4 +1,4 @@
|
||||
"""SQLAlchemy ORM models for ThreatHunt.
|
||||
"""SQLAlchemy ORM models for ThreatHunt.
|
||||
|
||||
All persistent entities: datasets, hunts, conversations, annotations,
|
||||
hypotheses, enrichment results, and users.
|
||||
@@ -44,6 +44,7 @@ class User(Base):
|
||||
hashed_password: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
role: Mapped[str] = mapped_column(String(16), default="analyst") # analyst | admin | viewer
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
display_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
# relationships
|
||||
@@ -544,3 +545,116 @@ class PlaybookRun(Base):
|
||||
Index("ix_playbook_runs_hunt", "hunt_id"),
|
||||
Index("ix_playbook_runs_status", "status"),
|
||||
)
|
||||
<<<<<<< HEAD
|
||||
anomaly_score: Mapped[float] = mapped_column(Float, default=0.0)
|
||||
distance_from_centroid: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
cluster_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
is_outlier: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
explanation: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
|
||||
# -- Persistent Processing Tasks (Phase 2) ---
|
||||
|
||||
class ProcessingTask(Base):
|
||||
__tablename__ = "processing_tasks"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
hunt_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True, index=True
|
||||
)
|
||||
dataset_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True, index=True
|
||||
)
|
||||
job_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
|
||||
stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="queued", index=True)
|
||||
progress: Mapped[float] = mapped_column(Float, default=0.0)
|
||||
message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_processing_tasks_hunt_stage", "hunt_id", "stage"),
|
||||
Index("ix_processing_tasks_dataset_stage", "dataset_id", "stage"),
|
||||
)
|
||||
|
||||
|
||||
# -- Playbook / Investigation Templates (Feature 3) ---
|
||||
|
||||
class Playbook(Base):
|
||||
__tablename__ = "playbooks"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_by: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("users.id"), nullable=True
|
||||
)
|
||||
is_template: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
hunt_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("hunts.id"), nullable=True
|
||||
)
|
||||
status: Mapped[str] = mapped_column(String(20), default="active")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||
)
|
||||
|
||||
steps: Mapped[list["PlaybookStep"]] = relationship(
|
||||
back_populates="playbook", lazy="selectin", cascade="all, delete-orphan",
|
||||
order_by="PlaybookStep.order_index",
|
||||
)
|
||||
|
||||
|
||||
class PlaybookStep(Base):
|
||||
__tablename__ = "playbook_steps"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
playbook_id: Mapped[str] = mapped_column(
|
||||
String(32), ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
order_index: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
title: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
step_type: Mapped[str] = mapped_column(String(32), default="manual")
|
||||
target_route: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
|
||||
is_completed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
playbook: Mapped["Playbook"] = relationship(back_populates="steps")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_playbook_steps_playbook", "playbook_id"),
|
||||
)
|
||||
|
||||
|
||||
# -- Saved Searches (Feature 5) ---
|
||||
|
||||
class SavedSearch(Base):
|
||||
__tablename__ = "saved_searches"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
search_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
query_params: Mapped[dict] = mapped_column(JSON, nullable=False)
|
||||
threshold: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
created_by: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("users.id"), nullable=True
|
||||
)
|
||||
last_run_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
last_result_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_saved_searches_type", "search_type"),
|
||||
)
|
||||
=======
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
@@ -22,10 +22,18 @@ from app.api.routes.reports import router as reports_router
|
||||
from app.api.routes.auth import router as auth_router
|
||||
from app.api.routes.keywords import router as keywords_router
|
||||
from app.api.routes.network import router as network_router
|
||||
<<<<<<< HEAD
|
||||
from app.api.routes.mitre import router as mitre_router
|
||||
from app.api.routes.timeline import router as timeline_router
|
||||
from app.api.routes.playbooks import router as playbooks_router
|
||||
from app.api.routes.saved_searches import router as searches_router
|
||||
from app.api.routes.stix_export import router as stix_router
|
||||
=======
|
||||
from app.api.routes.analysis import router as analysis_router
|
||||
from app.api.routes.cases import router as cases_router
|
||||
from app.api.routes.alerts import router as alerts_router
|
||||
from app.api.routes.notebooks import router as notebooks_router
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,8 +50,101 @@ async def lifespan(app: FastAPI):
|
||||
async with async_session_factory() as seed_db:
|
||||
await seed_defaults(seed_db)
|
||||
logger.info("AUP keyword defaults checked")
|
||||
<<<<<<< HEAD
|
||||
|
||||
# Start job queue
|
||||
from app.services.job_queue import (
|
||||
job_queue,
|
||||
register_all_handlers,
|
||||
reconcile_stale_processing_tasks,
|
||||
JobType,
|
||||
)
|
||||
|
||||
if settings.STARTUP_RECONCILE_STALE_TASKS:
|
||||
reconciled = await reconcile_stale_processing_tasks()
|
||||
if reconciled:
|
||||
logger.info("Startup reconciliation marked %d stale tasks", reconciled)
|
||||
|
||||
register_all_handlers()
|
||||
await job_queue.start()
|
||||
logger.info("Job queue started (%d workers)", job_queue._max_workers)
|
||||
|
||||
# Pre-warm host inventory cache for existing hunts
|
||||
from app.services.host_inventory import inventory_cache
|
||||
async with async_session_factory() as warm_db:
|
||||
from sqlalchemy import select, func
|
||||
from app.db.models import Hunt, Dataset
|
||||
stmt = (
|
||||
select(Hunt.id)
|
||||
.join(Dataset, Dataset.hunt_id == Hunt.id)
|
||||
.group_by(Hunt.id)
|
||||
.having(func.count(Dataset.id) > 0)
|
||||
)
|
||||
result = await warm_db.execute(stmt)
|
||||
hunt_ids = [row[0] for row in result.all()]
|
||||
warm_hunts = hunt_ids[: settings.STARTUP_WARMUP_MAX_HUNTS]
|
||||
for hid in warm_hunts:
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)
|
||||
if warm_hunts:
|
||||
logger.info(f"Queued host inventory warm-up for {len(warm_hunts)} hunts (total hunts with data: {len(hunt_ids)})")
|
||||
|
||||
# Check which datasets still need processing
|
||||
# (no anomaly results = never fully processed)
|
||||
async with async_session_factory() as reprocess_db:
|
||||
from sqlalchemy import select, exists
|
||||
from app.db.models import Dataset, AnomalyResult
|
||||
# Find datasets that have zero anomaly results (pipeline never ran or failed)
|
||||
has_anomaly = (
|
||||
select(AnomalyResult.id)
|
||||
.where(AnomalyResult.dataset_id == Dataset.id)
|
||||
.limit(1)
|
||||
.correlate(Dataset)
|
||||
.exists()
|
||||
)
|
||||
stmt = select(Dataset.id).where(~has_anomaly)
|
||||
result = await reprocess_db.execute(stmt)
|
||||
unprocessed_ids = [row[0] for row in result.all()]
|
||||
|
||||
if unprocessed_ids:
|
||||
to_reprocess = unprocessed_ids[: settings.STARTUP_REPROCESS_MAX_DATASETS]
|
||||
for ds_id in to_reprocess:
|
||||
job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)
|
||||
logger.info(f"Queued processing pipeline for {len(to_reprocess)} datasets at startup (unprocessed total: {len(unprocessed_ids)})")
|
||||
async with async_session_factory() as update_db:
|
||||
from sqlalchemy import update
|
||||
from app.db.models import Dataset
|
||||
await update_db.execute(
|
||||
update(Dataset)
|
||||
.where(Dataset.id.in_(to_reprocess))
|
||||
.values(processing_status="processing")
|
||||
)
|
||||
await update_db.commit()
|
||||
else:
|
||||
logger.info("All datasets already processed - skipping startup pipeline")
|
||||
|
||||
# Start load balancer health loop
|
||||
from app.services.load_balancer import lb
|
||||
await lb.start_health_loop(interval=30.0)
|
||||
logger.info("Load balancer health loop started")
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Shutting down ...")
|
||||
from app.services.job_queue import job_queue as jq
|
||||
await jq.stop()
|
||||
logger.info("Job queue stopped")
|
||||
|
||||
from app.services.load_balancer import lb as _lb
|
||||
await _lb.stop_health_loop()
|
||||
logger.info("Load balancer stopped")
|
||||
|
||||
=======
|
||||
yield
|
||||
logger.info("Shutting down …")
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
from app.agents.providers_v2 import cleanup_client
|
||||
from app.services.enrichment import enrichment_engine
|
||||
await cleanup_client()
|
||||
@@ -80,10 +181,18 @@ app.include_router(correlation_router)
|
||||
app.include_router(reports_router)
|
||||
app.include_router(keywords_router)
|
||||
app.include_router(network_router)
|
||||
<<<<<<< HEAD
|
||||
app.include_router(mitre_router)
|
||||
app.include_router(timeline_router)
|
||||
app.include_router(playbooks_router)
|
||||
app.include_router(searches_router)
|
||||
app.include_router(stix_router)
|
||||
=======
|
||||
app.include_router(analysis_router)
|
||||
app.include_router(cases_router)
|
||||
app.include_router(alerts_router)
|
||||
app.include_router(notebooks_router)
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
|
||||
@app.get("/", tags=["health"])
|
||||
@@ -100,3 +209,15 @@ async def root():
|
||||
"openwebui": settings.OPENWEBUI_URL,
|
||||
},
|
||||
}
|
||||
<<<<<<< HEAD
|
||||
|
||||
|
||||
@app.get("/health", tags=["health"])
|
||||
async def health():
|
||||
return {
|
||||
"service": "ThreatHunt API",
|
||||
"version": settings.APP_VERSION,
|
||||
"status": "ok",
|
||||
}
|
||||
=======
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -79,6 +80,55 @@ def _extract_username(raw: str) -> str:
|
||||
return name or ''
|
||||
|
||||
|
||||
|
||||
|
||||
# In-memory host inventory cache
|
||||
# Pre-computed results stored per hunt_id, built in background after upload.
|
||||
|
||||
import time as _time
|
||||
|
||||
class _InventoryCache:
|
||||
"""Simple in-memory cache for pre-computed host inventories."""
|
||||
|
||||
def __init__(self):
|
||||
self._data: dict[str, dict] = {} # hunt_id -> result dict
|
||||
self._timestamps: dict[str, float] = {} # hunt_id -> epoch
|
||||
self._building: set[str] = set() # hunt_ids currently being built
|
||||
|
||||
def get(self, hunt_id: str) -> dict | None:
|
||||
"""Return cached result if present. Never expires; only invalidated on new upload."""
|
||||
return self._data.get(hunt_id)
|
||||
|
||||
def put(self, hunt_id: str, result: dict):
|
||||
self._data[hunt_id] = result
|
||||
self._timestamps[hunt_id] = _time.time()
|
||||
self._building.discard(hunt_id)
|
||||
logger.info(f"Cached host inventory for hunt {hunt_id} "
|
||||
f"({result['stats']['total_hosts']} hosts)")
|
||||
|
||||
def invalidate(self, hunt_id: str):
|
||||
self._data.pop(hunt_id, None)
|
||||
self._timestamps.pop(hunt_id, None)
|
||||
|
||||
def is_building(self, hunt_id: str) -> bool:
|
||||
return hunt_id in self._building
|
||||
|
||||
def set_building(self, hunt_id: str):
|
||||
self._building.add(hunt_id)
|
||||
|
||||
def clear_building(self, hunt_id: str):
|
||||
self._building.discard(hunt_id)
|
||||
|
||||
def status(self, hunt_id: str) -> str:
|
||||
if hunt_id in self._building:
|
||||
return "building"
|
||||
if hunt_id in self._data:
|
||||
return "ready"
|
||||
return "none"
|
||||
|
||||
|
||||
inventory_cache = _InventoryCache()
|
||||
|
||||
def _infer_os(fqdn: str) -> str:
|
||||
u = fqdn.upper()
|
||||
if 'W10-' in u or 'WIN10' in u:
|
||||
@@ -155,29 +205,57 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
||||
connections: dict[tuple, int] = defaultdict(int)
|
||||
total_rows = 0
|
||||
ds_with_hosts = 0
|
||||
sampled_dataset_count = 0
|
||||
total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
|
||||
max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
|
||||
global_budget_reached = False
|
||||
dropped_connections = 0
|
||||
|
||||
for ds in all_datasets:
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
global_budget_reached = True
|
||||
break
|
||||
|
||||
cols = _identify_columns(ds)
|
||||
if not cols['fqdn'] and not cols['host_id']:
|
||||
continue
|
||||
ds_with_hosts += 1
|
||||
|
||||
batch_size = 5000
|
||||
offset = 0
|
||||
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
|
||||
rows_scanned_this_dataset = 0
|
||||
sampled_dataset = False
|
||||
last_row_index = -1
|
||||
|
||||
while True:
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
sampled_dataset = True
|
||||
global_budget_reached = True
|
||||
break
|
||||
|
||||
rr = await db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == ds.id)
|
||||
.where(DatasetRow.row_index > last_row_index)
|
||||
.order_by(DatasetRow.row_index)
|
||||
.offset(offset).limit(batch_size)
|
||||
.limit(batch_size)
|
||||
)
|
||||
rows = rr.scalars().all()
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for ro in rows:
|
||||
if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
|
||||
sampled_dataset = True
|
||||
break
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
sampled_dataset = True
|
||||
global_budget_reached = True
|
||||
break
|
||||
|
||||
data = ro.data or {}
|
||||
total_rows += 1
|
||||
rows_scanned_this_dataset += 1
|
||||
|
||||
fqdn = ''
|
||||
for c in cols['fqdn']:
|
||||
@@ -239,12 +317,33 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
||||
rport = _clean(data.get(pc))
|
||||
if rport:
|
||||
break
|
||||
connections[(host_key, rip, rport)] += 1
|
||||
conn_key = (host_key, rip, rport)
|
||||
if max_connections and len(connections) >= max_connections and conn_key not in connections:
|
||||
dropped_connections += 1
|
||||
continue
|
||||
connections[conn_key] += 1
|
||||
|
||||
offset += batch_size
|
||||
if sampled_dataset:
|
||||
sampled_dataset_count += 1
|
||||
logger.info(
|
||||
"Host inventory sampling for dataset %s (%d rows scanned)",
|
||||
ds.id,
|
||||
rows_scanned_this_dataset,
|
||||
)
|
||||
break
|
||||
|
||||
last_row_index = rows[-1].row_index
|
||||
if len(rows) < batch_size:
|
||||
break
|
||||
|
||||
if global_budget_reached:
|
||||
logger.info(
|
||||
"Host inventory global row budget reached for hunt %s at %d rows",
|
||||
hunt_id,
|
||||
total_rows,
|
||||
)
|
||||
break
|
||||
|
||||
# Post-process hosts
|
||||
for h in hosts.values():
|
||||
if not h['os'] and h['fqdn']:
|
||||
@@ -286,5 +385,12 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
||||
"total_rows_scanned": total_rows,
|
||||
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||
"row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
|
||||
"connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
|
||||
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
|
||||
"sampled_datasets": sampled_dataset_count,
|
||||
"global_budget_reached": global_budget_reached,
|
||||
"dropped_connections": dropped_connections,
|
||||
},
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
|
||||
@@ -18,6 +19,9 @@ logger = logging.getLogger(__name__)
|
||||
HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
|
||||
WILE_URL = f"{settings.wile_url}/api/generate"
|
||||
|
||||
# Velociraptor client IDs (C.hex) are not real hostnames
|
||||
CLIENTID_RE = re.compile(r"^C\.[0-9a-fA-F]{8,}$")
|
||||
|
||||
|
||||
async def _get_triage_summary(db, dataset_id: str) -> str:
|
||||
result = await db.execute(
|
||||
@@ -154,7 +158,7 @@ async def profile_host(
|
||||
logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to profile host %s: %s", hostname, e)
|
||||
logger.error("Failed to profile host %s: %r", hostname, e)
|
||||
profile = HostProfile(
|
||||
hunt_id=hunt_id, hostname=hostname, fqdn=fqdn,
|
||||
risk_score=0.0, risk_level="unknown",
|
||||
@@ -185,6 +189,13 @@ async def profile_all_hosts(hunt_id: str) -> None:
|
||||
if h not in hostnames:
|
||||
hostnames[h] = data.get("fqdn") or data.get("Fqdn")
|
||||
|
||||
# Filter out Velociraptor client IDs - not real hostnames
|
||||
real_hosts = {h: f for h, f in hostnames.items() if not CLIENTID_RE.match(h)}
|
||||
skipped = len(hostnames) - len(real_hosts)
|
||||
if skipped:
|
||||
logger.info("Skipped %d Velociraptor client IDs", skipped)
|
||||
hostnames = real_hosts
|
||||
|
||||
logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id)
|
||||
|
||||
semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Async job queue for background AI tasks.
|
||||
|
||||
Manages triage, profiling, report generation, anomaly detection,
|
||||
and data queries as trackable jobs with status, progress, and
|
||||
cancellation support.
|
||||
keyword scanning, IOC extraction, and data queries as trackable
|
||||
jobs with status, progress, and cancellation support.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -15,6 +15,8 @@ from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Coroutine, Optional
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -32,6 +34,18 @@ class JobType(str, Enum):
|
||||
REPORT = "report"
|
||||
ANOMALY = "anomaly"
|
||||
QUERY = "query"
|
||||
HOST_INVENTORY = "host_inventory"
|
||||
KEYWORD_SCAN = "keyword_scan"
|
||||
IOC_EXTRACT = "ioc_extract"
|
||||
|
||||
|
||||
# Job types that form the automatic upload pipeline
|
||||
PIPELINE_JOB_TYPES = frozenset({
|
||||
JobType.TRIAGE,
|
||||
JobType.ANOMALY,
|
||||
JobType.KEYWORD_SCAN,
|
||||
JobType.IOC_EXTRACT,
|
||||
})
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -82,11 +96,7 @@ class Job:
|
||||
|
||||
|
||||
class JobQueue:
|
||||
"""In-memory async job queue with concurrency control.
|
||||
|
||||
Jobs are tracked by ID and can be listed, polled, or cancelled.
|
||||
A configurable number of workers process jobs from the queue.
|
||||
"""
|
||||
"""In-memory async job queue with concurrency control."""
|
||||
|
||||
def __init__(self, max_workers: int = 3):
|
||||
self._jobs: dict[str, Job] = {}
|
||||
@@ -95,47 +105,56 @@ class JobQueue:
|
||||
self._workers: list[asyncio.Task] = []
|
||||
self._handlers: dict[JobType, Callable] = {}
|
||||
self._started = False
|
||||
self._completion_callbacks: list[Callable[[Job], Coroutine]] = []
|
||||
self._cleanup_task: asyncio.Task | None = None
|
||||
|
||||
def register_handler(
|
||||
self,
|
||||
job_type: JobType,
|
||||
handler: Callable[[Job], Coroutine],
|
||||
):
|
||||
"""Register an async handler for a job type.
|
||||
|
||||
Handler signature: async def handler(job: Job) -> Any
|
||||
The handler can update job.progress and job.message during execution.
|
||||
It should check job.is_cancelled periodically and return early.
|
||||
"""
|
||||
def register_handler(self, job_type: JobType, handler: Callable[[Job], Coroutine]):
|
||||
self._handlers[job_type] = handler
|
||||
logger.info(f"Registered handler for {job_type.value}")
|
||||
|
||||
def on_completion(self, callback: Callable[[Job], Coroutine]):
|
||||
"""Register a callback invoked after any job completes or fails."""
|
||||
self._completion_callbacks.append(callback)
|
||||
|
||||
async def start(self):
|
||||
"""Start worker tasks."""
|
||||
if self._started:
|
||||
return
|
||||
self._started = True
|
||||
for i in range(self._max_workers):
|
||||
task = asyncio.create_task(self._worker(i))
|
||||
self._workers.append(task)
|
||||
if not self._cleanup_task or self._cleanup_task.done():
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info(f"Job queue started with {self._max_workers} workers")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop all workers."""
|
||||
self._started = False
|
||||
for w in self._workers:
|
||||
w.cancel()
|
||||
await asyncio.gather(*self._workers, return_exceptions=True)
|
||||
self._workers.clear()
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
await asyncio.gather(self._cleanup_task, return_exceptions=True)
|
||||
self._cleanup_task = None
|
||||
logger.info("Job queue stopped")
|
||||
|
||||
def submit(self, job_type: JobType, **params) -> Job:
|
||||
"""Submit a new job. Returns the Job object immediately."""
|
||||
job = Job(
|
||||
id=str(uuid.uuid4()),
|
||||
job_type=job_type,
|
||||
params=params,
|
||||
# Soft backpressure: prefer dedupe over queue amplification
|
||||
dedupe_job = self._find_active_duplicate(job_type, params)
|
||||
if dedupe_job is not None:
|
||||
logger.info(
|
||||
f"Job deduped: reusing {dedupe_job.id} ({job_type.value}) params={params}"
|
||||
)
|
||||
return dedupe_job
|
||||
|
||||
if self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG:
|
||||
logger.warning(
|
||||
"Job queue backlog high (%d >= %d). Accepting job but system may be degraded.",
|
||||
self._queue.qsize(), settings.JOB_QUEUE_MAX_BACKLOG,
|
||||
)
|
||||
|
||||
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put_nowait(job.id)
|
||||
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
|
||||
@@ -144,6 +163,22 @@ class JobQueue:
|
||||
def get_job(self, job_id: str) -> Job | None:
|
||||
return self._jobs.get(job_id)
|
||||
|
||||
def _find_active_duplicate(self, job_type: JobType, params: dict) -> Job | None:
|
||||
"""Return queued/running job with same key workload to prevent duplicate storms."""
|
||||
key_fields = ["dataset_id", "hunt_id", "hostname", "question", "mode"]
|
||||
sig = tuple((k, params.get(k)) for k in key_fields if params.get(k) is not None)
|
||||
if not sig:
|
||||
return None
|
||||
for j in self._jobs.values():
|
||||
if j.job_type != job_type:
|
||||
continue
|
||||
if j.status not in (JobStatus.QUEUED, JobStatus.RUNNING):
|
||||
continue
|
||||
other_sig = tuple((k, j.params.get(k)) for k in key_fields if j.params.get(k) is not None)
|
||||
if sig == other_sig:
|
||||
return j
|
||||
return None
|
||||
|
||||
def cancel_job(self, job_id: str) -> bool:
|
||||
job = self._jobs.get(job_id)
|
||||
if not job:
|
||||
@@ -153,13 +188,7 @@ class JobQueue:
|
||||
job.cancel()
|
||||
return True
|
||||
|
||||
def list_jobs(
|
||||
self,
|
||||
status: JobStatus | None = None,
|
||||
job_type: JobType | None = None,
|
||||
limit: int = 50,
|
||||
) -> list[dict]:
|
||||
"""List jobs, newest first."""
|
||||
def list_jobs(self, status=None, job_type=None, limit=50) -> list[dict]:
|
||||
jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
|
||||
if status:
|
||||
jobs = [j for j in jobs if j.status == status]
|
||||
@@ -168,7 +197,6 @@ class JobQueue:
|
||||
return [j.to_dict() for j in jobs[:limit]]
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get queue statistics."""
|
||||
by_status = {}
|
||||
for j in self._jobs.values():
|
||||
by_status[j.status.value] = by_status.get(j.status.value, 0) + 1
|
||||
@@ -177,26 +205,58 @@ class JobQueue:
|
||||
"queued": self._queue.qsize(),
|
||||
"by_status": by_status,
|
||||
"workers": self._max_workers,
|
||||
"active_workers": sum(
|
||||
1 for j in self._jobs.values() if j.status == JobStatus.RUNNING
|
||||
),
|
||||
"active_workers": sum(1 for j in self._jobs.values() if j.status == JobStatus.RUNNING),
|
||||
}
|
||||
|
||||
def is_backlogged(self) -> bool:
|
||||
return self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG
|
||||
|
||||
def can_accept(self, reserve: int = 0) -> bool:
|
||||
return (self._queue.qsize() + max(0, reserve)) < settings.JOB_QUEUE_MAX_BACKLOG
|
||||
|
||||
def cleanup(self, max_age_seconds: float = 3600):
|
||||
"""Remove old completed/failed/cancelled jobs."""
|
||||
now = time.time()
|
||||
terminal_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||
to_remove = [
|
||||
jid for jid, j in self._jobs.items()
|
||||
if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||
and (now - j.created_at) > max_age_seconds
|
||||
if j.status in terminal_states and (now - j.created_at) > max_age_seconds
|
||||
]
|
||||
for jid in to_remove:
|
||||
|
||||
# Also cap retained terminal jobs to avoid unbounded memory growth
|
||||
terminal_jobs = sorted(
|
||||
[j for j in self._jobs.values() if j.status in terminal_states],
|
||||
key=lambda j: j.created_at,
|
||||
reverse=True,
|
||||
)
|
||||
overflow = terminal_jobs[settings.JOB_QUEUE_RETAIN_COMPLETED :]
|
||||
to_remove.extend([j.id for j in overflow])
|
||||
|
||||
removed = 0
|
||||
for jid in set(to_remove):
|
||||
if jid in self._jobs:
|
||||
del self._jobs[jid]
|
||||
if to_remove:
|
||||
logger.info(f"Cleaned up {len(to_remove)} old jobs")
|
||||
removed += 1
|
||||
if removed:
|
||||
logger.info(f"Cleaned up {removed} old jobs")
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
interval = max(10, settings.JOB_QUEUE_CLEANUP_INTERVAL_SECONDS)
|
||||
while self._started:
|
||||
try:
|
||||
self.cleanup(max_age_seconds=settings.JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS)
|
||||
except Exception as e:
|
||||
logger.warning(f"Job queue cleanup loop error: {e}")
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
def find_pipeline_jobs(self, dataset_id: str) -> list[Job]:
|
||||
"""Find all pipeline jobs for a given dataset_id."""
|
||||
return [
|
||||
j for j in self._jobs.values()
|
||||
if j.job_type in PIPELINE_JOB_TYPES
|
||||
and j.params.get("dataset_id") == dataset_id
|
||||
]
|
||||
|
||||
async def _worker(self, worker_id: int):
|
||||
"""Worker loop: pull jobs from queue and execute handlers."""
|
||||
logger.info(f"Worker {worker_id} started")
|
||||
while self._started:
|
||||
try:
|
||||
@@ -220,7 +280,10 @@ class JobQueue:
|
||||
|
||||
job.status = JobStatus.RUNNING
|
||||
job.started_at = time.time()
|
||||
if job.progress <= 0:
|
||||
job.progress = 5.0
|
||||
job.message = "Running..."
|
||||
await _sync_processing_task(job)
|
||||
logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
|
||||
|
||||
try:
|
||||
@@ -231,38 +294,111 @@ class JobQueue:
|
||||
job.result = result
|
||||
job.message = "Completed"
|
||||
job.completed_at = time.time()
|
||||
logger.info(
|
||||
f"Worker {worker_id}: completed {job.id} "
|
||||
f"in {job.elapsed_ms}ms"
|
||||
)
|
||||
logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
|
||||
except Exception as e:
|
||||
if not job.is_cancelled:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error = str(e)
|
||||
job.message = f"Failed: {e}"
|
||||
job.completed_at = time.time()
|
||||
logger.error(
|
||||
f"Worker {worker_id}: failed {job.id}: {e}",
|
||||
exc_info=True,
|
||||
logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
|
||||
|
||||
if job.is_cancelled and not job.completed_at:
|
||||
job.completed_at = time.time()
|
||||
|
||||
await _sync_processing_task(job)
|
||||
|
||||
# Fire completion callbacks
|
||||
for cb in self._completion_callbacks:
|
||||
try:
|
||||
await cb(job)
|
||||
except Exception as cb_err:
|
||||
logger.error(f"Completion callback error: {cb_err}", exc_info=True)
|
||||
|
||||
|
||||
async def _sync_processing_task(job: Job):
|
||||
"""Persist latest job state into processing_tasks (if linked by job_id)."""
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import update
|
||||
|
||||
try:
|
||||
from app.db import async_session_factory
|
||||
from app.db.models import ProcessingTask
|
||||
|
||||
values = {
|
||||
"status": job.status.value,
|
||||
"progress": float(job.progress),
|
||||
"message": job.message,
|
||||
"error": job.error,
|
||||
}
|
||||
if job.started_at:
|
||||
values["started_at"] = datetime.fromtimestamp(job.started_at, tz=timezone.utc)
|
||||
if job.completed_at:
|
||||
values["completed_at"] = datetime.fromtimestamp(job.completed_at, tz=timezone.utc)
|
||||
|
||||
async with async_session_factory() as db:
|
||||
await db.execute(
|
||||
update(ProcessingTask)
|
||||
.where(ProcessingTask.job_id == job.id)
|
||||
.values(**values)
|
||||
)
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to sync processing task for job {job.id}: {e}")
|
||||
|
||||
|
||||
# Singleton + job handlers
|
||||
# -- Singleton + job handlers --
|
||||
|
||||
job_queue = JobQueue(max_workers=3)
|
||||
job_queue = JobQueue(max_workers=5)
|
||||
|
||||
|
||||
async def _handle_triage(job: Job):
|
||||
"""Triage handler."""
|
||||
"""Triage handler - chains HOST_PROFILE after completion."""
|
||||
from app.services.triage import triage_dataset
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
job.message = f"Triaging dataset {dataset_id}"
|
||||
results = await triage_dataset(dataset_id)
|
||||
return {"count": len(results) if results else 0}
|
||||
await triage_dataset(dataset_id)
|
||||
|
||||
# Chain: trigger host profiling now that triage results exist
|
||||
from app.db import async_session_factory
|
||||
from app.db.models import Dataset
|
||||
from sqlalchemy import select
|
||||
try:
|
||||
async with async_session_factory() as db:
|
||||
ds = await db.execute(select(Dataset.hunt_id).where(Dataset.id == dataset_id))
|
||||
row = ds.first()
|
||||
hunt_id = row[0] if row else None
|
||||
if hunt_id:
|
||||
hp_job = job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
from app.db.models import ProcessingTask
|
||||
async with async_session_factory() as db:
|
||||
existing = await db.execute(
|
||||
select(ProcessingTask.id).where(ProcessingTask.job_id == hp_job.id)
|
||||
)
|
||||
if existing.first() is None:
|
||||
db.add(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset_id,
|
||||
job_id=hp_job.id,
|
||||
stage="host_profile",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
await db.commit()
|
||||
except Exception as persist_err:
|
||||
logger.warning(f"Failed to persist chained HOST_PROFILE task: {persist_err}")
|
||||
|
||||
logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to chain host profile after triage: {e}")
|
||||
|
||||
return {"dataset_id": dataset_id}
|
||||
|
||||
|
||||
async def _handle_host_profile(job: Job):
|
||||
"""Host profiling handler."""
|
||||
from app.services.host_profiler import profile_all_hosts, profile_host
|
||||
hunt_id = job.params.get("hunt_id")
|
||||
hostname = job.params.get("hostname")
|
||||
@@ -277,7 +413,6 @@ async def _handle_host_profile(job: Job):
|
||||
|
||||
|
||||
async def _handle_report(job: Job):
|
||||
"""Report generation handler."""
|
||||
from app.services.report_generator import generate_report
|
||||
hunt_id = job.params.get("hunt_id")
|
||||
job.message = f"Generating report for hunt {hunt_id}"
|
||||
@@ -286,7 +421,6 @@ async def _handle_report(job: Job):
|
||||
|
||||
|
||||
async def _handle_anomaly(job: Job):
|
||||
"""Anomaly detection handler."""
|
||||
from app.services.anomaly_detector import detect_anomalies
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
k = job.params.get("k", 3)
|
||||
@@ -297,7 +431,6 @@ async def _handle_anomaly(job: Job):
|
||||
|
||||
|
||||
async def _handle_query(job: Job):
|
||||
"""Data query handler (non-streaming)."""
|
||||
from app.services.data_query import query_dataset
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
question = job.params.get("question", "")
|
||||
@@ -307,10 +440,152 @@ async def _handle_query(job: Job):
|
||||
return {"answer": answer}
|
||||
|
||||
|
||||
async def _handle_host_inventory(job: Job):
|
||||
from app.db import async_session_factory
|
||||
from app.services.host_inventory import build_host_inventory, inventory_cache
|
||||
|
||||
hunt_id = job.params.get("hunt_id")
|
||||
if not hunt_id:
|
||||
raise ValueError("hunt_id required")
|
||||
|
||||
inventory_cache.set_building(hunt_id)
|
||||
job.message = f"Building host inventory for hunt {hunt_id}"
|
||||
|
||||
try:
|
||||
async with async_session_factory() as db:
|
||||
result = await build_host_inventory(hunt_id, db)
|
||||
inventory_cache.put(hunt_id, result)
|
||||
job.message = f"Built inventory: {result['stats']['total_hosts']} hosts"
|
||||
return {"hunt_id": hunt_id, "total_hosts": result["stats"]["total_hosts"]}
|
||||
except Exception:
|
||||
inventory_cache.clear_building(hunt_id)
|
||||
raise
|
||||
|
||||
|
||||
async def _handle_keyword_scan(job: Job):
|
||||
"""AUP keyword scan handler."""
|
||||
from app.db import async_session_factory
|
||||
from app.services.scanner import KeywordScanner, keyword_scan_cache
|
||||
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
job.message = f"Running AUP keyword scan on dataset {dataset_id}"
|
||||
|
||||
async with async_session_factory() as db:
|
||||
scanner = KeywordScanner(db)
|
||||
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||
|
||||
# Cache dataset-only result for fast API reuse
|
||||
if dataset_id:
|
||||
keyword_scan_cache.put(dataset_id, result)
|
||||
|
||||
hits = result.get("total_hits", 0)
|
||||
job.message = f"Keyword scan complete: {hits} hits"
|
||||
logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
|
||||
return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
|
||||
|
||||
|
||||
async def _handle_ioc_extract(job: Job):
|
||||
"""IOC extraction handler."""
|
||||
from app.db import async_session_factory
|
||||
from app.services.ioc_extractor import extract_iocs_from_dataset
|
||||
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
job.message = f"Extracting IOCs from dataset {dataset_id}"
|
||||
|
||||
async with async_session_factory() as db:
|
||||
iocs = await extract_iocs_from_dataset(dataset_id, db)
|
||||
|
||||
total = sum(len(v) for v in iocs.values())
|
||||
job.message = f"IOC extraction complete: {total} IOCs found"
|
||||
logger.info(f"IOC extract for {dataset_id}: {total} IOCs")
|
||||
return {"dataset_id": dataset_id, "total_iocs": total, "breakdown": {k: len(v) for k, v in iocs.items()}}
|
||||
|
||||
|
||||
async def _on_pipeline_job_complete(job: Job):
|
||||
"""Update Dataset.processing_status when all pipeline jobs finish."""
|
||||
if job.job_type not in PIPELINE_JOB_TYPES:
|
||||
return
|
||||
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
if not dataset_id:
|
||||
return
|
||||
|
||||
pipeline_jobs = job_queue.find_pipeline_jobs(dataset_id)
|
||||
if not pipeline_jobs:
|
||||
return
|
||||
|
||||
all_done = all(
|
||||
j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||
for j in pipeline_jobs
|
||||
)
|
||||
if not all_done:
|
||||
return
|
||||
|
||||
any_failed = any(j.status == JobStatus.FAILED for j in pipeline_jobs)
|
||||
new_status = "completed_with_errors" if any_failed else "completed"
|
||||
|
||||
try:
|
||||
from app.db import async_session_factory
|
||||
from app.db.models import Dataset
|
||||
from sqlalchemy import update
|
||||
|
||||
async with async_session_factory() as db:
|
||||
await db.execute(
|
||||
update(Dataset)
|
||||
.where(Dataset.id == dataset_id)
|
||||
.values(processing_status=new_status)
|
||||
)
|
||||
await db.commit()
|
||||
logger.info(f"Dataset {dataset_id} processing_status -> {new_status}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update processing_status for {dataset_id}: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
async def reconcile_stale_processing_tasks() -> int:
|
||||
"""Mark queued/running processing tasks from prior runs as failed."""
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import update
|
||||
|
||||
try:
|
||||
from app.db import async_session_factory
|
||||
from app.db.models import ProcessingTask
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
update(ProcessingTask)
|
||||
.where(ProcessingTask.status.in_(["queued", "running"]))
|
||||
.values(
|
||||
status="failed",
|
||||
error="Recovered after service restart before task completion",
|
||||
message="Recovered stale task after restart",
|
||||
completed_at=now,
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
updated = int(result.rowcount or 0)
|
||||
|
||||
if updated:
|
||||
logger.warning(
|
||||
"Reconciled %d stale processing tasks (queued/running -> failed) during startup",
|
||||
updated,
|
||||
)
|
||||
return updated
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reconcile stale processing tasks: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def register_all_handlers():
|
||||
"""Register all job handlers."""
|
||||
"""Register all job handlers and completion callbacks."""
|
||||
job_queue.register_handler(JobType.TRIAGE, _handle_triage)
|
||||
job_queue.register_handler(JobType.HOST_PROFILE, _handle_host_profile)
|
||||
job_queue.register_handler(JobType.REPORT, _handle_report)
|
||||
job_queue.register_handler(JobType.ANOMALY, _handle_anomaly)
|
||||
job_queue.register_handler(JobType.QUERY, _handle_query)
|
||||
job_queue.register_handler(JobType.HOST_INVENTORY, _handle_host_inventory)
|
||||
job_queue.register_handler(JobType.KEYWORD_SCAN, _handle_keyword_scan)
|
||||
job_queue.register_handler(JobType.IOC_EXTRACT, _handle_ioc_extract)
|
||||
job_queue.on_completion(_on_pipeline_job_complete)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""AUP Keyword Scanner — searches dataset rows, hunts, annotations, and
|
||||
"""AUP Keyword Scanner searches dataset rows, hunts, annotations, and
|
||||
messages for keyword matches.
|
||||
|
||||
Scanning is done in Python (not SQL LIKE on JSON columns) for portability
|
||||
@@ -8,24 +8,49 @@ across SQLite / PostgreSQL and to provide per-cell match context.
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
|
||||
from app.db.models import (
|
||||
KeywordTheme,
|
||||
Keyword,
|
||||
DatasetRow,
|
||||
Dataset,
|
||||
Hunt,
|
||||
Annotation,
|
||||
Message,
|
||||
Conversation,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BATCH_SIZE = 500
|
||||
BATCH_SIZE = 200
|
||||
|
||||
|
||||
def _infer_hostname_and_user(data: dict) -> tuple[str | None, str | None]:
|
||||
"""Best-effort extraction of hostname and user from a dataset row."""
|
||||
if not data:
|
||||
return None, None
|
||||
|
||||
host_keys = (
|
||||
'hostname', 'host_name', 'host', 'computer_name', 'computer',
|
||||
'fqdn', 'client_id', 'agent_id', 'endpoint_id',
|
||||
)
|
||||
user_keys = (
|
||||
'username', 'user_name', 'user', 'account_name',
|
||||
'logged_in_user', 'samaccountname', 'sam_account_name',
|
||||
)
|
||||
|
||||
def pick(keys):
|
||||
for k in keys:
|
||||
for actual_key, v in data.items():
|
||||
if actual_key.lower() == k and v not in (None, ''):
|
||||
return str(v)
|
||||
return None
|
||||
|
||||
return pick(host_keys), pick(user_keys)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -39,6 +64,8 @@ class ScanHit:
|
||||
matched_value: str
|
||||
row_index: int | None = None
|
||||
dataset_name: str | None = None
|
||||
hostname: str | None = None
|
||||
username: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -50,21 +77,54 @@ class ScanResult:
|
||||
rows_scanned: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeywordScanCacheEntry:
|
||||
dataset_id: str
|
||||
result: dict
|
||||
built_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
|
||||
|
||||
class KeywordScanCache:
|
||||
"""In-memory per-dataset cache for dataset-only keyword scans.
|
||||
|
||||
This enables fast-path reads when users run AUP scans against datasets that
|
||||
were already scanned during upload pipeline processing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._entries: dict[str, KeywordScanCacheEntry] = {}
|
||||
|
||||
def put(self, dataset_id: str, result: dict):
|
||||
self._entries[dataset_id] = KeywordScanCacheEntry(dataset_id=dataset_id, result=result)
|
||||
|
||||
def get(self, dataset_id: str) -> KeywordScanCacheEntry | None:
|
||||
return self._entries.get(dataset_id)
|
||||
|
||||
def invalidate_dataset(self, dataset_id: str):
|
||||
self._entries.pop(dataset_id, None)
|
||||
|
||||
def clear(self):
|
||||
self._entries.clear()
|
||||
|
||||
|
||||
keyword_scan_cache = KeywordScanCache()
|
||||
|
||||
|
||||
class KeywordScanner:
|
||||
"""Scans multiple data sources for keyword/regex matches."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────
|
||||
# Public API
|
||||
|
||||
async def scan(
|
||||
self,
|
||||
dataset_ids: list[str] | None = None,
|
||||
theme_ids: list[str] | None = None,
|
||||
scan_hunts: bool = True,
|
||||
scan_annotations: bool = True,
|
||||
scan_messages: bool = True,
|
||||
scan_hunts: bool = False,
|
||||
scan_annotations: bool = False,
|
||||
scan_messages: bool = False,
|
||||
) -> dict:
|
||||
"""Run a full AUP scan and return dict matching ScanResponse."""
|
||||
# Load themes + keywords
|
||||
@@ -103,7 +163,7 @@ class KeywordScanner:
|
||||
"rows_scanned": result.rows_scanned,
|
||||
}
|
||||
|
||||
# ── Internal ──────────────────────────────────────────────────────
|
||||
# Internal
|
||||
|
||||
async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
|
||||
q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712
|
||||
@@ -143,6 +203,8 @@ class KeywordScanner:
|
||||
hits: list[ScanHit],
|
||||
row_index: int | None = None,
|
||||
dataset_name: str | None = None,
|
||||
hostname: str | None = None,
|
||||
username: str | None = None,
|
||||
) -> None:
|
||||
"""Check text against all compiled patterns, append hits."""
|
||||
if not text:
|
||||
@@ -150,8 +212,7 @@ class KeywordScanner:
|
||||
for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
|
||||
for kw_value, pat in keyword_patterns:
|
||||
if pat.search(text):
|
||||
# Truncate matched_value for display
|
||||
matched_preview = text[:200] + ("…" if len(text) > 200 else "")
|
||||
matched_preview = text[:200] + ("" if len(text) > 200 else "")
|
||||
hits.append(ScanHit(
|
||||
theme_name=theme_name,
|
||||
theme_color=theme_color,
|
||||
@@ -162,13 +223,14 @@ class KeywordScanner:
|
||||
matched_value=matched_preview,
|
||||
row_index=row_index,
|
||||
dataset_name=dataset_name,
|
||||
hostname=hostname,
|
||||
username=username,
|
||||
))
|
||||
|
||||
async def _scan_datasets(
|
||||
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
|
||||
) -> None:
|
||||
"""Scan dataset rows in batches."""
|
||||
# Build dataset name lookup
|
||||
"""Scan dataset rows in batches using keyset pagination (no OFFSET)."""
|
||||
ds_q = select(Dataset.id, Dataset.name)
|
||||
if dataset_ids:
|
||||
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
|
||||
@@ -178,15 +240,27 @@ class KeywordScanner:
|
||||
if not ds_map:
|
||||
return
|
||||
|
||||
# Iterate rows in batches
|
||||
offset = 0
|
||||
row_q_base = select(DatasetRow).where(
|
||||
DatasetRow.dataset_id.in_(list(ds_map.keys()))
|
||||
).order_by(DatasetRow.id)
|
||||
import asyncio
|
||||
|
||||
max_rows = max(0, int(settings.SCANNER_MAX_ROWS_PER_SCAN))
|
||||
budget_reached = False
|
||||
|
||||
for ds_id, ds_name in ds_map.items():
|
||||
if max_rows and result.rows_scanned >= max_rows:
|
||||
budget_reached = True
|
||||
break
|
||||
|
||||
last_id = 0
|
||||
while True:
|
||||
if max_rows and result.rows_scanned >= max_rows:
|
||||
budget_reached = True
|
||||
break
|
||||
rows_result = await self.db.execute(
|
||||
row_q_base.offset(offset).limit(BATCH_SIZE)
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == ds_id)
|
||||
.where(DatasetRow.id > last_id)
|
||||
.order_by(DatasetRow.id)
|
||||
.limit(BATCH_SIZE)
|
||||
)
|
||||
rows = rows_result.scalars().all()
|
||||
if not rows:
|
||||
@@ -195,21 +269,38 @@ class KeywordScanner:
|
||||
for row in rows:
|
||||
result.rows_scanned += 1
|
||||
data = row.data or {}
|
||||
hostname, username = _infer_hostname_and_user(data)
|
||||
for col_name, cell_value in data.items():
|
||||
if cell_value is None:
|
||||
continue
|
||||
text = str(cell_value)
|
||||
self._match_text(
|
||||
text, patterns, "dataset_row", row.id,
|
||||
col_name, result.hits,
|
||||
text,
|
||||
patterns,
|
||||
"dataset_row",
|
||||
row.id,
|
||||
col_name,
|
||||
result.hits,
|
||||
row_index=row.row_index,
|
||||
dataset_name=ds_map.get(row.dataset_id),
|
||||
dataset_name=ds_name,
|
||||
hostname=hostname,
|
||||
username=username,
|
||||
)
|
||||
|
||||
offset += BATCH_SIZE
|
||||
last_id = rows[-1].id
|
||||
await asyncio.sleep(0)
|
||||
if len(rows) < BATCH_SIZE:
|
||||
break
|
||||
|
||||
if budget_reached:
|
||||
break
|
||||
|
||||
if budget_reached:
|
||||
logger.warning(
|
||||
"AUP scan row budget reached (%d rows). Returning partial results.",
|
||||
result.rows_scanned,
|
||||
)
|
||||
|
||||
async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
|
||||
"""Scan hunt names and descriptions."""
|
||||
hunts_result = await self.db.execute(select(Hunt))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""
|
||||
"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -15,7 +15,7 @@ from app.db.models import Dataset, DatasetRow, TriageResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M"
|
||||
DEFAULT_FAST_MODEL = settings.DEFAULT_FAST_MODEL
|
||||
ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate"
|
||||
|
||||
ARTIFACT_FOCUS = {
|
||||
@@ -80,7 +80,7 @@ async def triage_dataset(dataset_id: str) -> None:
|
||||
rows_result = await db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == dataset_id)
|
||||
.order_by(DatasetRow.row_number)
|
||||
.order_by(DatasetRow.row_index)
|
||||
.offset(offset)
|
||||
.limit(batch_size)
|
||||
)
|
||||
|
||||
124
backend/tests/test_agent_policy_execution.py
Normal file
124
backend/tests/test_agent_policy_execution.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Tests for execution-mode behavior in /api/agent/assist."""
|
||||
|
||||
import io
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_assist_policy_query_executes_scan(client):
|
||||
# 1) Create hunt
|
||||
h = await client.post("/api/hunts", json={"name": "Policy Hunt"})
|
||||
assert h.status_code == 200
|
||||
hunt_id = h.json()["id"]
|
||||
|
||||
# 2) Upload browser-history-like CSV
|
||||
csv_bytes = (
|
||||
b"User,visited_url,title,ClientId,Fqdn\n"
|
||||
b"Alice,https://www.pornhub.com/view_video.php,site,HOST-A,host-a.local\n"
|
||||
b"Bob,https://news.example.org/article,news,HOST-B,host-b.local\n"
|
||||
)
|
||||
files = {"file": ("web_history.csv", io.BytesIO(csv_bytes), "text/csv")}
|
||||
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
|
||||
assert up.status_code == 200
|
||||
|
||||
# 3) Ensure policy theme/keyword exists
|
||||
t = await client.post(
|
||||
"/api/keywords/themes",
|
||||
json={
|
||||
"name": "Adult Content",
|
||||
"color": "#e91e63",
|
||||
"enabled": True,
|
||||
},
|
||||
)
|
||||
assert t.status_code in (201, 409)
|
||||
|
||||
themes = await client.get("/api/keywords/themes")
|
||||
assert themes.status_code == 200
|
||||
adult = next(x for x in themes.json()["themes"] if x["name"] == "Adult Content")
|
||||
|
||||
k = await client.post(
|
||||
f"/api/keywords/themes/{adult['id']}/keywords",
|
||||
json={"value": "pornhub", "is_regex": False},
|
||||
)
|
||||
assert k.status_code in (201, 409)
|
||||
|
||||
# 4) Execution-mode query
|
||||
q = await client.post(
|
||||
"/api/agent/assist",
|
||||
json={
|
||||
"query": "Analyze browser history for policy-violating domains and summarize by user and host.",
|
||||
"hunt_id": hunt_id,
|
||||
},
|
||||
)
|
||||
assert q.status_code == 200
|
||||
body = q.json()
|
||||
|
||||
assert body["model_used"] == "execution:keyword_scanner"
|
||||
assert body["execution"] is not None
|
||||
assert body["execution"]["policy_hits"] >= 1
|
||||
assert len(body["execution"]["top_user_hosts"]) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_assist_execution_preference_off_stays_advisory(client):
|
||||
h = await client.post("/api/hunts", json={"name": "No Exec Hunt"})
|
||||
assert h.status_code == 200
|
||||
hunt_id = h.json()["id"]
|
||||
|
||||
q = await client.post(
|
||||
"/api/agent/assist",
|
||||
json={
|
||||
"query": "Analyze browser history for policy-violating domains and summarize by user and host.",
|
||||
"hunt_id": hunt_id,
|
||||
"execution_preference": "off",
|
||||
},
|
||||
)
|
||||
assert q.status_code == 200
|
||||
body = q.json()
|
||||
assert body["model_used"] != "execution:keyword_scanner"
|
||||
assert body["execution"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_assist_execution_preference_force_executes(client):
|
||||
# Create hunt + dataset even when the query text is not policy-specific
|
||||
h = await client.post("/api/hunts", json={"name": "Force Exec Hunt"})
|
||||
assert h.status_code == 200
|
||||
hunt_id = h.json()["id"]
|
||||
|
||||
csv_bytes = (
|
||||
b"User,visited_url,title,ClientId,Fqdn\n"
|
||||
b"Alice,https://www.pornhub.com/view_video.php,site,HOST-A,host-a.local\n"
|
||||
)
|
||||
files = {"file": ("web_history.csv", io.BytesIO(csv_bytes), "text/csv")}
|
||||
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
|
||||
assert up.status_code == 200
|
||||
|
||||
t = await client.post(
|
||||
"/api/keywords/themes",
|
||||
json={"name": "Adult Content", "color": "#e91e63", "enabled": True},
|
||||
)
|
||||
assert t.status_code in (201, 409)
|
||||
|
||||
themes = await client.get("/api/keywords/themes")
|
||||
assert themes.status_code == 200
|
||||
adult = next(x for x in themes.json()["themes"] if x["name"] == "Adult Content")
|
||||
k = await client.post(
|
||||
f"/api/keywords/themes/{adult['id']}/keywords",
|
||||
json={"value": "pornhub", "is_regex": False},
|
||||
)
|
||||
assert k.status_code in (201, 409)
|
||||
|
||||
q = await client.post(
|
||||
"/api/agent/assist",
|
||||
json={
|
||||
"query": "Summarize notable activity in this hunt.",
|
||||
"hunt_id": hunt_id,
|
||||
"execution_preference": "force",
|
||||
},
|
||||
)
|
||||
assert q.status_code == 200
|
||||
body = q.json()
|
||||
assert body["model_used"] == "execution:keyword_scanner"
|
||||
assert body["execution"] is not None
|
||||
@@ -77,6 +77,26 @@ class TestHuntEndpoints:
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_hunt_progress(self, client):
|
||||
create = await client.post("/api/hunts", json={"name": "Progress Hunt"})
|
||||
hunt_id = create.json()["id"]
|
||||
|
||||
# attach one dataset so progress has scope
|
||||
from tests.conftest import SAMPLE_CSV
|
||||
import io
|
||||
files = {"file": ("progress.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
|
||||
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
|
||||
assert up.status_code == 200
|
||||
|
||||
res = await client.get(f"/api/hunts/{hunt_id}/progress")
|
||||
assert res.status_code == 200
|
||||
body = res.json()
|
||||
assert body["hunt_id"] == hunt_id
|
||||
assert "progress_percent" in body
|
||||
assert "dataset_total" in body
|
||||
assert "network_status" in body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestDatasetEndpoints:
|
||||
"""Test dataset upload and retrieval."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Tests for CSV parser and normalizer services."""
|
||||
"""Tests for CSV parser and normalizer services."""
|
||||
|
||||
import pytest
|
||||
from app.services.csv_parser import parse_csv_bytes, detect_encoding, detect_delimiter, infer_column_types
|
||||
@@ -43,8 +43,9 @@ class TestCSVParser:
|
||||
assert len(rows) == 2
|
||||
|
||||
def test_parse_empty_file(self):
|
||||
with pytest.raises(Exception):
|
||||
parse_csv_bytes(b"")
|
||||
rows, meta = parse_csv_bytes(b"")
|
||||
assert len(rows) == 0
|
||||
assert meta["row_count"] == 0
|
||||
|
||||
def test_detect_encoding_utf8(self):
|
||||
enc = detect_encoding(SAMPLE_CSV)
|
||||
@@ -53,17 +54,15 @@ class TestCSVParser:
|
||||
|
||||
def test_infer_column_types(self):
|
||||
types = infer_column_types(
|
||||
["192.168.1.1", "10.0.0.1", "8.8.8.8"],
|
||||
"src_ip",
|
||||
[{"src_ip": "192.168.1.1"}, {"src_ip": "10.0.0.1"}, {"src_ip": "8.8.8.8"}],
|
||||
)
|
||||
assert types == "ip"
|
||||
assert types["src_ip"] == "ip"
|
||||
|
||||
def test_infer_column_types_hash(self):
|
||||
types = infer_column_types(
|
||||
["d41d8cd98f00b204e9800998ecf8427e"],
|
||||
"hash",
|
||||
[{"hash": "d41d8cd98f00b204e9800998ecf8427e"}],
|
||||
)
|
||||
assert types == "hash_md5"
|
||||
assert types["hash"] == "hash_md5"
|
||||
|
||||
|
||||
class TestNormalizer:
|
||||
@@ -94,7 +93,7 @@ class TestNormalizer:
|
||||
start, end = detect_time_range(rows, column_mapping)
|
||||
# Should detect time range from timestamp column
|
||||
if start:
|
||||
assert "2025" in start
|
||||
assert "2025" in str(start)
|
||||
|
||||
def test_normalize_rows(self):
|
||||
rows = [{"SourceAddr": "10.0.0.1", "ProcessName": "cmd.exe"}]
|
||||
@@ -102,3 +101,6 @@ class TestNormalizer:
|
||||
normalized = normalize_rows(rows, mapping)
|
||||
assert len(normalized) == 1
|
||||
assert normalized[0].get("src_ip") == "10.0.0.1"
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -197,3 +197,27 @@ async def test_quick_scan(client: AsyncClient):
|
||||
assert "total_hits" in data
|
||||
# powershell should match at least one row
|
||||
assert data["total_hits"] > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quick_scan_cache_hit(client: AsyncClient):
|
||||
"""Second quick scan should return cache hit metadata."""
|
||||
theme_res = await client.post("/api/keywords/themes", json={"name": "Quick Cache Theme", "color": "#00aa00"})
|
||||
tid = theme_res.json()["id"]
|
||||
await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "chrome.exe"})
|
||||
|
||||
from tests.conftest import SAMPLE_CSV
|
||||
import io
|
||||
files = {"file": ("cache_quick.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
|
||||
upload = await client.post("/api/datasets/upload", files=files)
|
||||
ds_id = upload.json()["id"]
|
||||
|
||||
first = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
|
||||
assert first.status_code == 200
|
||||
assert first.json().get("cache_status") in ("miss", "hit")
|
||||
|
||||
second = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
|
||||
assert second.status_code == 200
|
||||
body = second.json()
|
||||
assert body.get("cache_used") is True
|
||||
assert body.get("cache_status") == "hit"
|
||||
|
||||
84
backend/tests/test_network.py
Normal file
84
backend/tests/test_network.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Tests for network inventory endpoints and cache/polling behavior."""
|
||||
|
||||
import io
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.host_inventory import inventory_cache
|
||||
from tests.conftest import SAMPLE_CSV
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inventory_status_none_for_unknown_hunt(client):
|
||||
hunt_id = "hunt-does-not-exist"
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
inventory_cache.clear_building(hunt_id)
|
||||
|
||||
res = await client.get(f"/api/network/inventory-status?hunt_id={hunt_id}")
|
||||
assert res.status_code == 200
|
||||
body = res.json()
|
||||
assert body["hunt_id"] == hunt_id
|
||||
assert body["status"] == "none"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_host_inventory_cold_cache_returns_202(client):
|
||||
# Create hunt and upload dataset linked to that hunt
|
||||
hunt = await client.post("/api/hunts", json={"name": "Net Hunt"})
|
||||
hunt_id = hunt.json()["id"]
|
||||
|
||||
files = {"file": ("network.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
|
||||
up = await client.post("/api/datasets/upload", files=files, params={"hunt_id": hunt_id})
|
||||
assert up.status_code == 200
|
||||
|
||||
# Ensure cache is cold for this hunt
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
inventory_cache.clear_building(hunt_id)
|
||||
|
||||
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
|
||||
assert res.status_code == 202
|
||||
body = res.json()
|
||||
assert body["status"] == "building"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_host_inventory_ready_cache_returns_200(client):
|
||||
hunt = await client.post("/api/hunts", json={"name": "Ready Hunt"})
|
||||
hunt_id = hunt.json()["id"]
|
||||
|
||||
mock_inventory = {
|
||||
"hosts": [
|
||||
{
|
||||
"id": "host-1",
|
||||
"hostname": "HOST-1",
|
||||
"fqdn": "HOST-1.local",
|
||||
"client_id": "C.1234abcd",
|
||||
"ips": ["10.0.0.10"],
|
||||
"os": "Windows 10",
|
||||
"users": ["alice"],
|
||||
"datasets": ["test"],
|
||||
"row_count": 5,
|
||||
}
|
||||
],
|
||||
"connections": [],
|
||||
"stats": {
|
||||
"total_hosts": 1,
|
||||
"hosts_with_ips": 1,
|
||||
"hosts_with_users": 1,
|
||||
"total_datasets_scanned": 1,
|
||||
"total_rows_scanned": 5,
|
||||
},
|
||||
}
|
||||
|
||||
inventory_cache.put(hunt_id, mock_inventory)
|
||||
|
||||
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
|
||||
assert res.status_code == 200
|
||||
body = res.json()
|
||||
assert body["stats"]["total_hosts"] == 1
|
||||
assert len(body["hosts"]) == 1
|
||||
assert body["hosts"][0]["hostname"] == "HOST-1"
|
||||
|
||||
status_res = await client.get(f"/api/network/inventory-status?hunt_id={hunt_id}")
|
||||
assert status_res.status_code == 200
|
||||
assert status_res.json()["status"] == "ready"
|
||||
82
backend/tests/test_network_scale.py
Normal file
82
backend/tests/test_network_scale.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Scale-oriented network endpoint tests (summary/subgraph/backpressure)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.config import settings
|
||||
from app.services.host_inventory import inventory_cache
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_summary_from_cache(client):
|
||||
hunt_id = "scale-hunt-summary"
|
||||
inv = {
|
||||
"hosts": [
|
||||
{"id": "h1", "hostname": "H1", "ips": ["10.0.0.1"], "users": ["a"], "row_count": 50},
|
||||
{"id": "h2", "hostname": "H2", "ips": [], "users": [], "row_count": 10},
|
||||
],
|
||||
"connections": [
|
||||
{"source": "h1", "target": "8.8.8.8", "count": 7},
|
||||
{"source": "h1", "target": "h2", "count": 3},
|
||||
],
|
||||
"stats": {"total_hosts": 2, "total_rows_scanned": 60},
|
||||
}
|
||||
inventory_cache.put(hunt_id, inv)
|
||||
|
||||
res = await client.get(f"/api/network/summary?hunt_id={hunt_id}&top_n=1")
|
||||
assert res.status_code == 200
|
||||
body = res.json()
|
||||
assert body["stats"]["total_hosts"] == 2
|
||||
assert len(body["top_hosts"]) == 1
|
||||
assert body["top_hosts"][0]["id"] == "h1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_subgraph_truncates(client):
|
||||
hunt_id = "scale-hunt-subgraph"
|
||||
inv = {
|
||||
"hosts": [
|
||||
{"id": f"h{i}", "hostname": f"H{i}", "ips": [], "users": [], "row_count": 100 - i}
|
||||
for i in range(1, 8)
|
||||
],
|
||||
"connections": [
|
||||
{"source": "h1", "target": "h2", "count": 20},
|
||||
{"source": "h1", "target": "h3", "count": 15},
|
||||
{"source": "h2", "target": "h4", "count": 5},
|
||||
{"source": "h3", "target": "h5", "count": 4},
|
||||
],
|
||||
"stats": {"total_hosts": 7, "total_rows_scanned": 999},
|
||||
}
|
||||
inventory_cache.put(hunt_id, inv)
|
||||
|
||||
res = await client.get(f"/api/network/subgraph?hunt_id={hunt_id}&max_hosts=3&max_edges=2")
|
||||
assert res.status_code == 200
|
||||
body = res.json()
|
||||
assert len(body["hosts"]) <= 3
|
||||
assert len(body["connections"]) <= 2
|
||||
assert body["stats"]["truncated"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manual_job_submit_backpressure_returns_429(client):
|
||||
old = settings.JOB_QUEUE_MAX_BACKLOG
|
||||
settings.JOB_QUEUE_MAX_BACKLOG = 0
|
||||
try:
|
||||
res = await client.post("/api/analysis/jobs/submit/triage", json={"params": {"dataset_id": "abc"}})
|
||||
assert res.status_code == 429
|
||||
finally:
|
||||
settings.JOB_QUEUE_MAX_BACKLOG = old
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_host_inventory_deferred_when_queue_backlogged(client):
|
||||
hunt_id = "deferred-hunt"
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
inventory_cache.clear_building(hunt_id)
|
||||
|
||||
old = settings.JOB_QUEUE_MAX_BACKLOG
|
||||
settings.JOB_QUEUE_MAX_BACKLOG = 0
|
||||
try:
|
||||
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
|
||||
assert res.status_code == 202
|
||||
body = res.json()
|
||||
assert body["status"] == "deferred"
|
||||
finally:
|
||||
settings.JOB_QUEUE_MAX_BACKLOG = old
|
||||
203
backend/tests/test_new_features.py
Normal file
203
backend/tests/test_new_features.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Tests for new feature API routes: MITRE, Timeline, Playbooks, Saved Searches."""
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
|
||||
class TestMitreRoutes:
|
||||
"""Tests for /api/mitre endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mitre_coverage_empty(self, client):
|
||||
resp = await client.get("/api/mitre/coverage")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "tactics" in data
|
||||
assert "technique_count" in data
|
||||
assert data["technique_count"] == 0
|
||||
assert len(data["tactics"]) == 14 # 14 MITRE tactics
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mitre_coverage_with_hunt_filter(self, client):
|
||||
resp = await client.get("/api/mitre/coverage?hunt_id=nonexistent")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["technique_count"] == 0
|
||||
|
||||
|
||||
class TestTimelineRoutes:
|
||||
"""Tests for /api/timeline endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeline_hunt_not_found(self, client):
|
||||
resp = await client.get("/api/timeline/hunt/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeline_with_hunt(self, client):
|
||||
# Create a hunt first
|
||||
hunt_resp = await client.post("/api/hunts", json={"name": "Timeline Test"})
|
||||
assert hunt_resp.status_code in (200, 201)
|
||||
hunt_id = hunt_resp.json()["id"]
|
||||
|
||||
resp = await client.get(f"/api/timeline/hunt/{hunt_id}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["hunt_id"] == hunt_id
|
||||
assert "events" in data
|
||||
assert "datasets" in data
|
||||
|
||||
|
||||
class TestPlaybookRoutes:
|
||||
"""Tests for /api/playbooks endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_playbooks_empty(self, client):
|
||||
resp = await client.get("/api/playbooks")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["playbooks"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_templates(self, client):
|
||||
resp = await client.get("/api/playbooks/templates")
|
||||
assert resp.status_code == 200
|
||||
templates = resp.json()["templates"]
|
||||
assert len(templates) >= 2
|
||||
assert templates[0]["name"] == "Standard Threat Hunt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_playbook(self, client):
|
||||
resp = await client.post("/api/playbooks", json={
|
||||
"name": "My Investigation",
|
||||
"description": "Test playbook",
|
||||
"steps": [
|
||||
{"title": "Step 1", "description": "Upload data", "step_type": "upload", "target_route": "/upload"},
|
||||
{"title": "Step 2", "description": "Triage", "step_type": "analysis", "target_route": "/analysis"},
|
||||
],
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["name"] == "My Investigation"
|
||||
assert len(data["steps"]) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_playbook_crud(self, client):
|
||||
# Create
|
||||
resp = await client.post("/api/playbooks", json={
|
||||
"name": "CRUD Test",
|
||||
"steps": [{"title": "Do something"}],
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
pb_id = resp.json()["id"]
|
||||
|
||||
# Get
|
||||
resp = await client.get(f"/api/playbooks/{pb_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "CRUD Test"
|
||||
assert len(resp.json()["steps"]) == 1
|
||||
|
||||
# Update
|
||||
resp = await client.put(f"/api/playbooks/{pb_id}", json={"name": "Updated"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Delete
|
||||
resp = await client.delete(f"/api/playbooks/{pb_id}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_playbook_step_completion(self, client):
|
||||
# Create with step
|
||||
resp = await client.post("/api/playbooks", json={
|
||||
"name": "Step Test",
|
||||
"steps": [{"title": "Task 1"}],
|
||||
})
|
||||
pb_id = resp.json()["id"]
|
||||
|
||||
# Get to find step ID
|
||||
resp = await client.get(f"/api/playbooks/{pb_id}")
|
||||
steps = resp.json()["steps"]
|
||||
step_id = steps[0]["id"]
|
||||
assert steps[0]["is_completed"] is False
|
||||
|
||||
# Mark complete
|
||||
resp = await client.put(f"/api/playbooks/steps/{step_id}", json={"is_completed": True, "notes": "Done!"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["is_completed"] is True
|
||||
|
||||
|
||||
class TestSavedSearchRoutes:
|
||||
"""Tests for /api/searches endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_empty(self, client):
|
||||
resp = await client.get("/api/searches")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["searches"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_saved_search(self, client):
|
||||
resp = await client.post("/api/searches", json={
|
||||
"name": "Suspicious IPs",
|
||||
"search_type": "ioc_search",
|
||||
"query_params": {"ioc_value": "203.0.113"},
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["name"] == "Suspicious IPs"
|
||||
assert data["search_type"] == "ioc_search"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_crud(self, client):
|
||||
# Create
|
||||
resp = await client.post("/api/searches", json={
|
||||
"name": "Test Query",
|
||||
"search_type": "keyword_scan",
|
||||
"query_params": {"theme": "malware"},
|
||||
})
|
||||
s_id = resp.json()["id"]
|
||||
|
||||
# Get
|
||||
resp = await client.get(f"/api/searches/{s_id}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Update
|
||||
resp = await client.put(f"/api/searches/{s_id}", json={"name": "Updated Query"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Run
|
||||
resp = await client.post(f"/api/searches/{s_id}/run")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "result_count" in data
|
||||
assert "delta" in data
|
||||
|
||||
# Delete
|
||||
resp = await client.delete(f"/api/searches/{s_id}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
|
||||
class TestStixExport:
|
||||
"""Tests for /api/export/stix endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stix_export_hunt_not_found(self, client):
|
||||
resp = await client.get("/api/export/stix/nonexistent-id")
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stix_export_empty_hunt(self, client):
|
||||
"""Export from a real hunt with no data returns valid but minimal bundle."""
|
||||
hunt_resp = await client.post("/api/hunts", json={"name": "STIX Test Hunt"})
|
||||
assert hunt_resp.status_code in (200, 201)
|
||||
hunt_id = hunt_resp.json()["id"]
|
||||
|
||||
resp = await client.get(f"/api/export/stix/{hunt_id}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["type"] == "bundle"
|
||||
assert data["objects"][0]["spec_version"] == "2.1" # spec_version is on objects, not bundle
|
||||
assert "objects" in data
|
||||
# At minimum should have the identity object
|
||||
types = [o["type"] for o in data["objects"]]
|
||||
assert "identity" in types
|
||||
|
||||
Binary file not shown.
@@ -7,24 +7,24 @@ services:
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
# ── LLM Cluster (Wile / Roadrunner via Tailscale) ──
|
||||
# ── LLM Cluster (Wile / Roadrunner via Tailscale) ──
|
||||
TH_WILE_HOST: "100.110.190.12"
|
||||
TH_ROADRUNNER_HOST: "100.110.190.11"
|
||||
TH_OLLAMA_PORT: "11434"
|
||||
TH_OPEN_WEBUI_URL: "https://ai.guapo613.beer"
|
||||
|
||||
# ── Database ──
|
||||
# ── Database ──
|
||||
TH_DATABASE_URL: "sqlite+aiosqlite:///./threathunt.db"
|
||||
|
||||
# ── Auth ──
|
||||
# ── Auth ──
|
||||
TH_JWT_SECRET: "change-me-in-production"
|
||||
|
||||
# ── Enrichment API keys (set your own) ──
|
||||
# ── Enrichment API keys (set your own) ──
|
||||
# TH_VIRUSTOTAL_API_KEY: ""
|
||||
# TH_ABUSEIPDB_API_KEY: ""
|
||||
# TH_SHODAN_API_KEY: ""
|
||||
|
||||
# ── Agent behaviour ──
|
||||
# ── Agent behaviour ──
|
||||
TH_AGENT_MAX_TOKENS: "4096"
|
||||
TH_AGENT_TEMPERATURE: "0.3"
|
||||
volumes:
|
||||
@@ -51,7 +51,7 @@ services:
|
||||
networks:
|
||||
- threathunt
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "--quiet", "--tries=1", "--spider", "http://localhost:3000/"]
|
||||
test: ["CMD", "curl", "-f", "http://127.0.0.1:3000/"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
350
fix_all.py
Normal file
350
fix_all.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""Fix all critical issues: DB locking, keyword scan, network map."""
|
||||
import os, re
|
||||
|
||||
ROOT = r"D:\Projects\Dev\ThreatHunt"
|
||||
|
||||
def fix_file(filepath, replacements):
|
||||
"""Apply text replacements to a file."""
|
||||
path = os.path.join(ROOT, filepath)
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
for old, new, desc in replacements:
|
||||
if old in content:
|
||||
content = content.replace(old, new, 1)
|
||||
print(f" OK: {desc}")
|
||||
else:
|
||||
print(f" SKIP: {desc} (pattern not found)")
|
||||
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
return content
|
||||
|
||||
|
||||
# ================================================================
|
||||
# FIX 1: Database engine - NullPool instead of StaticPool
|
||||
# ================================================================
|
||||
print("\n=== FIX 1: Database engine (NullPool + higher timeouts) ===")
|
||||
|
||||
engine_path = os.path.join(ROOT, "backend", "app", "db", "engine.py")
|
||||
with open(engine_path, "r", encoding="utf-8") as f:
|
||||
engine_content = f.read()
|
||||
|
||||
new_engine = '''"""Database engine, session factory, and base model.
|
||||
|
||||
Uses async SQLAlchemy with aiosqlite for local dev and asyncpg for production PostgreSQL.
|
||||
"""
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config import settings
|
||||
|
||||
_is_sqlite = settings.DATABASE_URL.startswith("sqlite")
|
||||
|
||||
_engine_kwargs: dict = dict(
|
||||
echo=settings.DEBUG,
|
||||
future=True,
|
||||
)
|
||||
|
||||
if _is_sqlite:
|
||||
_engine_kwargs["connect_args"] = {"timeout": 60, "check_same_thread": False}
|
||||
# NullPool: each session gets its own connection.
|
||||
# Combined with WAL mode, this allows concurrent reads while a write is in progress.
|
||||
from sqlalchemy.pool import NullPool
|
||||
_engine_kwargs["poolclass"] = NullPool
|
||||
else:
|
||||
_engine_kwargs["pool_size"] = 5
|
||||
_engine_kwargs["max_overflow"] = 10
|
||||
|
||||
engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)
|
||||
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def _set_sqlite_pragmas(dbapi_conn, connection_record):
|
||||
"""Enable WAL mode and tune busy-timeout for SQLite connections."""
|
||||
if _is_sqlite:
|
||||
cursor = dbapi_conn.cursor()
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute("PRAGMA busy_timeout=30000")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.close()
|
||||
|
||||
|
||||
async_session_factory = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
# Alias expected by other modules
|
||||
async_session = async_session_factory
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all ORM models."""
|
||||
pass
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession: # type: ignore[misc]
|
||||
"""FastAPI dependency that yields an async DB session."""
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""Create all tables (for dev / first-run). In production use Alembic."""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def dispose_db() -> None:
|
||||
"""Dispose of the engine on shutdown."""
|
||||
await engine.dispose()
|
||||
'''
|
||||
|
||||
with open(engine_path, "w", encoding="utf-8") as f:
|
||||
f.write(new_engine)
|
||||
print(" OK: Replaced StaticPool with NullPool")
|
||||
print(" OK: Increased busy_timeout 5000 -> 30000ms")
|
||||
print(" OK: Added check_same_thread=False")
|
||||
print(" OK: Connection timeout 30 -> 60s")
|
||||
|
||||
|
||||
# ================================================================
|
||||
# FIX 2: Keyword scan endpoint - make POST non-blocking (background job)
|
||||
# ================================================================
|
||||
print("\n=== FIX 2: Keyword scan endpoint -> background job ===")
|
||||
|
||||
kw_path = os.path.join(ROOT, "backend", "app", "api", "routes", "keywords.py")
|
||||
with open(kw_path, "r", encoding="utf-8") as f:
|
||||
kw_content = f.read()
|
||||
|
||||
# Replace the scan endpoint to be non-blocking
|
||||
old_scan = '''@router.post("/scan", response_model=ScanResponse)
|
||||
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""Run AUP keyword scan across selected data sources."""
|
||||
scanner = KeywordScanner(db)
|
||||
result = await scanner.scan(
|
||||
dataset_ids=body.dataset_ids,
|
||||
theme_ids=body.theme_ids,
|
||||
scan_hunts=body.scan_hunts,
|
||||
scan_annotations=body.scan_annotations,
|
||||
scan_messages=body.scan_messages,
|
||||
)
|
||||
return result'''
|
||||
|
||||
new_scan = '''@router.post("/scan", response_model=ScanResponse)
|
||||
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""Run AUP keyword scan across selected data sources.
|
||||
|
||||
Uses a dedicated DB session separate from the request session
|
||||
to avoid blocking other API requests on SQLite.
|
||||
"""
|
||||
from app.db import async_session_factory
|
||||
async with async_session_factory() as scan_db:
|
||||
scanner = KeywordScanner(scan_db)
|
||||
result = await scanner.scan(
|
||||
dataset_ids=body.dataset_ids,
|
||||
theme_ids=body.theme_ids,
|
||||
scan_hunts=body.scan_hunts,
|
||||
scan_annotations=body.scan_annotations,
|
||||
scan_messages=body.scan_messages,
|
||||
)
|
||||
return result'''
|
||||
|
||||
if old_scan in kw_content:
|
||||
kw_content = kw_content.replace(old_scan, new_scan, 1)
|
||||
print(" OK: Scan endpoint uses dedicated DB session")
|
||||
else:
|
||||
print(" SKIP: Scan endpoint pattern not found")
|
||||
|
||||
# Also fix quick_scan
|
||||
old_quick = '''@router.get("/scan/quick", response_model=ScanResponse)
|
||||
async def quick_scan(
|
||||
dataset_id: str = Query(..., description="Dataset to scan"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Quick scan a single dataset with all enabled themes."""
|
||||
scanner = KeywordScanner(db)
|
||||
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||
return result'''
|
||||
|
||||
new_quick = '''@router.get("/scan/quick", response_model=ScanResponse)
|
||||
async def quick_scan(
|
||||
dataset_id: str = Query(..., description="Dataset to scan"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Quick scan a single dataset with all enabled themes."""
|
||||
from app.db import async_session_factory
|
||||
async with async_session_factory() as scan_db:
|
||||
scanner = KeywordScanner(scan_db)
|
||||
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||
return result'''
|
||||
|
||||
if old_quick in kw_content:
|
||||
kw_content = kw_content.replace(old_quick, new_quick, 1)
|
||||
print(" OK: Quick scan uses dedicated DB session")
|
||||
else:
|
||||
print(" SKIP: Quick scan pattern not found")
|
||||
|
||||
with open(kw_path, "w", encoding="utf-8") as f:
|
||||
f.write(kw_content)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# FIX 3: Scanner service - smaller batches, yield between batches
|
||||
# ================================================================
|
||||
print("\n=== FIX 3: Scanner service - smaller batches + async yield ===")
|
||||
|
||||
scanner_path = os.path.join(ROOT, "backend", "app", "services", "scanner.py")
|
||||
with open(scanner_path, "r", encoding="utf-8") as f:
|
||||
scanner_content = f.read()
|
||||
|
||||
# Change batch size and add yield between batches
|
||||
old_batch = "BATCH_SIZE = 500"
|
||||
new_batch = "BATCH_SIZE = 200"
|
||||
|
||||
if old_batch in scanner_content:
|
||||
scanner_content = scanner_content.replace(old_batch, new_batch, 1)
|
||||
print(" OK: Reduced batch size 500 -> 200")
|
||||
|
||||
# Add asyncio.sleep(0) between batches to yield to other tasks
|
||||
old_batch_loop = ''' offset += BATCH_SIZE
|
||||
if len(rows) < BATCH_SIZE:
|
||||
break'''
|
||||
|
||||
new_batch_loop = ''' offset += BATCH_SIZE
|
||||
# Yield to event loop between batches so other requests aren't starved
|
||||
import asyncio
|
||||
await asyncio.sleep(0)
|
||||
if len(rows) < BATCH_SIZE:
|
||||
break'''
|
||||
|
||||
if old_batch_loop in scanner_content:
|
||||
scanner_content = scanner_content.replace(old_batch_loop, new_batch_loop, 1)
|
||||
print(" OK: Added async yield between scan batches")
|
||||
else:
|
||||
print(" SKIP: Batch loop pattern not found")
|
||||
|
||||
with open(scanner_path, "w", encoding="utf-8") as f:
|
||||
f.write(scanner_content)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# FIX 4: Job queue workers - increase from 3 to 5
|
||||
# ================================================================
|
||||
print("\n=== FIX 4: Job queue - more workers ===")
|
||||
|
||||
jq_path = os.path.join(ROOT, "backend", "app", "services", "job_queue.py")
|
||||
with open(jq_path, "r", encoding="utf-8") as f:
|
||||
jq_content = f.read()
|
||||
|
||||
old_workers = "job_queue = JobQueue(max_workers=3)"
|
||||
new_workers = "job_queue = JobQueue(max_workers=5)"
|
||||
|
||||
if old_workers in jq_content:
|
||||
jq_content = jq_content.replace(old_workers, new_workers, 1)
|
||||
print(" OK: Workers 3 -> 5")
|
||||
|
||||
with open(jq_path, "w", encoding="utf-8") as f:
|
||||
f.write(jq_content)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# FIX 5: main.py - always re-run pipeline on startup for ALL datasets
|
||||
# ================================================================
|
||||
print("\n=== FIX 5: Startup reprocessing - all datasets, not just 'ready' ===")
|
||||
|
||||
main_path = os.path.join(ROOT, "backend", "app", "main.py")
|
||||
with open(main_path, "r", encoding="utf-8") as f:
|
||||
main_content = f.read()
|
||||
|
||||
# The current startup only reprocesses datasets with status="ready"
|
||||
# But after previous runs, they're all "completed" - so nothing happens
|
||||
# Fix: reprocess datasets that have NO triage/anomaly results in DB
|
||||
old_reprocess = ''' # Reprocess datasets that were never fully processed (status still "ready")
|
||||
async with async_session_factory() as reprocess_db:
|
||||
from sqlalchemy import select
|
||||
from app.db.models import Dataset
|
||||
stmt = select(Dataset.id).where(Dataset.processing_status == "ready")
|
||||
result = await reprocess_db.execute(stmt)
|
||||
unprocessed_ids = [row[0] for row in result.all()]
|
||||
for ds_id in unprocessed_ids:
|
||||
job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)
|
||||
if unprocessed_ids:
|
||||
logger.info(f"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets")
|
||||
# Mark them as processing
|
||||
async with async_session_factory() as update_db:
|
||||
from sqlalchemy import update
|
||||
from app.db.models import Dataset
|
||||
await update_db.execute(
|
||||
update(Dataset)
|
||||
.where(Dataset.id.in_(unprocessed_ids))
|
||||
.values(processing_status="processing")
|
||||
)
|
||||
await update_db.commit()'''
|
||||
|
||||
new_reprocess = ''' # Check which datasets still need processing
|
||||
# (no anomaly results = never fully processed)
|
||||
async with async_session_factory() as reprocess_db:
|
||||
from sqlalchemy import select, exists
|
||||
from app.db.models import Dataset, AnomalyResult
|
||||
# Find datasets that have zero anomaly results (pipeline never ran or failed)
|
||||
has_anomaly = (
|
||||
select(AnomalyResult.id)
|
||||
.where(AnomalyResult.dataset_id == Dataset.id)
|
||||
.limit(1)
|
||||
.correlate(Dataset)
|
||||
.exists()
|
||||
)
|
||||
stmt = select(Dataset.id).where(~has_anomaly)
|
||||
result = await reprocess_db.execute(stmt)
|
||||
unprocessed_ids = [row[0] for row in result.all()]
|
||||
|
||||
if unprocessed_ids:
|
||||
for ds_id in unprocessed_ids:
|
||||
job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)
|
||||
logger.info(f"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets")
|
||||
async with async_session_factory() as update_db:
|
||||
from sqlalchemy import update
|
||||
from app.db.models import Dataset
|
||||
await update_db.execute(
|
||||
update(Dataset)
|
||||
.where(Dataset.id.in_(unprocessed_ids))
|
||||
.values(processing_status="processing")
|
||||
)
|
||||
await update_db.commit()
|
||||
else:
|
||||
logger.info("All datasets already processed - skipping startup pipeline")'''
|
||||
|
||||
if old_reprocess in main_content:
|
||||
main_content = main_content.replace(old_reprocess, new_reprocess, 1)
|
||||
print(" OK: Startup checks for actual results, not just status field")
|
||||
else:
|
||||
print(" SKIP: Reprocess block not found")
|
||||
|
||||
with open(main_path, "w", encoding="utf-8") as f:
|
||||
f.write(main_content)
|
||||
|
||||
|
||||
print("\n=== ALL FIXES APPLIED ===")
|
||||
30
fix_keywords.py
Normal file
30
fix_keywords.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import re
|
||||
|
||||
path = "backend/app/api/routes/keywords.py"
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
c = f.read()
|
||||
|
||||
# Fix POST /scan - remove dedicated session, use injected db
|
||||
old1 = ' """Run AUP keyword scan across selected data sources.\n \n Uses a dedicated DB session separate from the request session\n to avoid blocking other API requests on SQLite.\n """\n from app.db import async_session_factory\n async with async_session_factory() as scan_db:\n scanner = KeywordScanner(scan_db)\n result = await scanner.scan(\n dataset_ids=body.dataset_ids,\n theme_ids=body.theme_ids,\n scan_hunts=body.scan_hunts,\n scan_annotations=body.scan_annotations,\n scan_messages=body.scan_messages,\n )\n return result'
|
||||
|
||||
new1 = ' """Run AUP keyword scan across selected data sources."""\n scanner = KeywordScanner(db)\n result = await scanner.scan(\n dataset_ids=body.dataset_ids,\n theme_ids=body.theme_ids,\n scan_hunts=body.scan_hunts,\n scan_annotations=body.scan_annotations,\n scan_messages=body.scan_messages,\n )\n return result'
|
||||
|
||||
if old1 in c:
|
||||
c = c.replace(old1, new1, 1)
|
||||
print("OK: reverted POST /scan")
|
||||
else:
|
||||
print("SKIP: POST /scan not found")
|
||||
|
||||
# Fix GET /scan/quick
|
||||
old2 = ' """Quick scan a single dataset with all enabled themes."""\n from app.db import async_session_factory\n async with async_session_factory() as scan_db:\n scanner = KeywordScanner(scan_db)\n result = await scanner.scan(dataset_ids=[dataset_id])\n return result'
|
||||
|
||||
new2 = ' """Quick scan a single dataset with all enabled themes."""\n scanner = KeywordScanner(db)\n result = await scanner.scan(dataset_ids=[dataset_id])\n return result'
|
||||
|
||||
if old2 in c:
|
||||
c = c.replace(old2, new2, 1)
|
||||
print("OK: reverted GET /scan/quick")
|
||||
else:
|
||||
print("SKIP: GET /scan/quick not found")
|
||||
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(c)
|
||||
@@ -16,6 +16,12 @@ server {
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_read_timeout 300s;
|
||||
|
||||
# SSE streaming support for agent assist
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
proxy_set_header Connection '';
|
||||
chunked_transfer_encoding off;
|
||||
}
|
||||
|
||||
# SPA fallback serve index.html for all non-file routes
|
||||
|
||||
58
frontend/package-lock.json
generated
58
frontend/package-lock.json
generated
@@ -24,9 +24,13 @@
|
||||
"react-markdown": "^10.1.0",
|
||||
"react-router-dom": "^7.13.0",
|
||||
"react-scripts": "5.0.1",
|
||||
<<<<<<< HEAD
|
||||
"recharts": "^3.7.0"
|
||||
=======
|
||||
"recharts": "^3.7.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"yaml": "^2.8.2"
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/cytoscape": "^3.21.9",
|
||||
@@ -3978,6 +3982,8 @@
|
||||
"@types/node": "*"
|
||||
}
|
||||
},
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
"node_modules/@types/cytoscape": {
|
||||
"version": "3.21.9",
|
||||
"resolved": "https://registry.npmjs.org/@types/cytoscape/-/cytoscape-3.21.9.tgz",
|
||||
@@ -3985,6 +3991,7 @@
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
"node_modules/@types/d3-array": {
|
||||
"version": "3.2.2",
|
||||
"resolved": "https://registry.npmjs.org/@types/d3-array/-/d3-array-3.2.2.tgz",
|
||||
@@ -4048,6 +4055,8 @@
|
||||
"integrity": "sha512-Ps3T8E8dZDam6fUyNiMkekK3XUsaUEik+idO9/YjPtfj2qruF8tFBXS7XhtE4iIXBLxhmLjP3SXpLhVf21I9Lw==",
|
||||
"license": "MIT"
|
||||
},
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
"node_modules/@types/debug": {
|
||||
"version": "4.1.12",
|
||||
"resolved": "https://registry.npmjs.org/@types/debug/-/debug-4.1.12.tgz",
|
||||
@@ -4057,6 +4066,7 @@
|
||||
"@types/ms": "*"
|
||||
}
|
||||
},
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
"node_modules/@types/eslint": {
|
||||
"version": "8.56.12",
|
||||
"resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.56.12.tgz",
|
||||
@@ -4415,12 +4425,15 @@
|
||||
"integrity": "sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==",
|
||||
"license": "MIT"
|
||||
},
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
"node_modules/@types/unist": {
|
||||
"version": "3.0.3",
|
||||
"resolved": "https://registry.npmjs.org/@types/unist/-/unist-3.0.3.tgz",
|
||||
"integrity": "sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==",
|
||||
"license": "MIT"
|
||||
},
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
"node_modules/@types/use-sync-external-store": {
|
||||
"version": "0.0.6",
|
||||
"resolved": "https://registry.npmjs.org/@types/use-sync-external-store/-/use-sync-external-store-0.0.6.tgz",
|
||||
@@ -7026,6 +7039,8 @@
|
||||
"integrity": "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
"node_modules/cytoscape": {
|
||||
"version": "3.33.1",
|
||||
"resolved": "https://registry.npmjs.org/cytoscape/-/cytoscape-3.33.1.tgz",
|
||||
@@ -7060,6 +7075,7 @@
|
||||
"cytoscape": "^3.2.22"
|
||||
}
|
||||
},
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
"node_modules/d3-array": {
|
||||
"version": "3.2.4",
|
||||
"resolved": "https://registry.npmjs.org/d3-array/-/d3-array-3.2.4.tgz",
|
||||
@@ -7081,6 +7097,8 @@
|
||||
"node": ">=12"
|
||||
}
|
||||
},
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
"node_modules/d3-dispatch": {
|
||||
"version": "1.0.6",
|
||||
"resolved": "https://registry.npmjs.org/d3-dispatch/-/d3-dispatch-1.0.6.tgz",
|
||||
@@ -7097,6 +7115,7 @@
|
||||
"d3-selection": "1"
|
||||
}
|
||||
},
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
"node_modules/d3-ease": {
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/d3-ease/-/d3-ease-3.0.1.tgz",
|
||||
@@ -7152,12 +7171,15 @@
|
||||
"node": ">=12"
|
||||
}
|
||||
},
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
"node_modules/d3-selection": {
|
||||
"version": "1.4.2",
|
||||
"resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-1.4.2.tgz",
|
||||
"integrity": "sha512-SJ0BqYihzOjDnnlfyeHT0e30k0K1+5sR3d5fNueCNeuhZTnGw4M4o8mqJchSwgKMXCNFo+e2VTChiSJ0vYtXkg==",
|
||||
"license": "BSD-3-Clause"
|
||||
},
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
"node_modules/d3-shape": {
|
||||
"version": "3.2.0",
|
||||
"resolved": "https://registry.npmjs.org/d3-shape/-/d3-shape-3.2.0.tgz",
|
||||
@@ -7203,6 +7225,8 @@
|
||||
"node": ">=12"
|
||||
}
|
||||
},
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
"node_modules/dagre": {
|
||||
"version": "0.8.5",
|
||||
"resolved": "https://registry.npmjs.org/dagre/-/dagre-0.8.5.tgz",
|
||||
@@ -7213,6 +7237,7 @@
|
||||
"lodash": "^4.17.15"
|
||||
}
|
||||
},
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
"node_modules/damerau-levenshtein": {
|
||||
"version": "1.0.8",
|
||||
"resolved": "https://registry.npmjs.org/damerau-levenshtein/-/damerau-levenshtein-1.0.8.tgz",
|
||||
@@ -7313,6 +7338,8 @@
|
||||
"integrity": "sha512-qIMFpTMZmny+MMIitAB6D7iVPEorVw6YQRWkvarTkT4tBeSLLiHzcwj6q0MmYSFCiVpiqPJTJEYIrpcPzVEIvg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
"node_modules/decode-named-character-reference": {
|
||||
"version": "1.3.0",
|
||||
"resolved": "https://registry.npmjs.org/decode-named-character-reference/-/decode-named-character-reference-1.3.0.tgz",
|
||||
@@ -7326,6 +7353,7 @@
|
||||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
"node_modules/dedent": {
|
||||
"version": "0.7.0",
|
||||
"resolved": "https://registry.npmjs.org/dedent/-/dedent-0.7.0.tgz",
|
||||
@@ -15841,6 +15869,29 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/react-redux": {
|
||||
"version": "9.2.0",
|
||||
"resolved": "https://registry.npmjs.org/react-redux/-/react-redux-9.2.0.tgz",
|
||||
"integrity": "sha512-ROY9fvHhwOD9ySfrF0wmvu//bKCQ6AeZZq1nJNtbDC+kk5DuSuNX/n6YWYF/SYy7bSba4D4FSz8DJeKY/S/r+g==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/use-sync-external-store": "^0.0.6",
|
||||
"use-sync-external-store": "^1.4.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "^18.2.25 || ^19",
|
||||
"react": "^18.0 || ^19",
|
||||
"redux": "^5.0.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"redux": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/react-refresh": {
|
||||
"version": "0.11.0",
|
||||
"resolved": "https://registry.npmjs.org/react-refresh/-/react-refresh-0.11.0.tgz",
|
||||
@@ -16074,8 +16125,12 @@
|
||||
"version": "5.0.1",
|
||||
"resolved": "https://registry.npmjs.org/redux/-/redux-5.0.1.tgz",
|
||||
"integrity": "sha512-M9/ELqF6fy8FwmkpnF0S3YKOqMyoWJ4+CS5Efg2ct3oY9daQvd/Pc71FpGZsVsbl3Cpb+IIcjBDUnnyBdQbq4w==",
|
||||
<<<<<<< HEAD
|
||||
"license": "MIT"
|
||||
=======
|
||||
"license": "MIT",
|
||||
"peer": true
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
},
|
||||
"node_modules/redux-thunk": {
|
||||
"version": "3.1.0",
|
||||
@@ -18783,6 +18838,8 @@
|
||||
"node": ">= 0.8"
|
||||
}
|
||||
},
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
"node_modules/vfile": {
|
||||
"version": "6.0.3",
|
||||
"resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.3.tgz",
|
||||
@@ -18811,6 +18868,7 @@
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
"node_modules/victory-vendor": {
|
||||
"version": "37.3.6",
|
||||
"resolved": "https://registry.npmjs.org/victory-vendor/-/victory-vendor-37.3.6.tgz",
|
||||
|
||||
@@ -19,9 +19,13 @@
|
||||
"react-markdown": "^10.1.0",
|
||||
"react-router-dom": "^7.13.0",
|
||||
"react-scripts": "5.0.1",
|
||||
<<<<<<< HEAD
|
||||
"recharts": "^3.7.0"
|
||||
=======
|
||||
"recharts": "^3.7.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"yaml": "^2.8.2"
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
},
|
||||
"scripts": {
|
||||
"start": "react-scripts start",
|
||||
|
||||
@@ -2,10 +2,11 @@
|
||||
* ThreatHunt — MUI-powered analyst-assist platform.
|
||||
*/
|
||||
|
||||
import React, { useState, useCallback } from 'react';
|
||||
import React, { useState, useCallback, Suspense } from 'react';
|
||||
import { BrowserRouter, Routes, Route, useNavigate, useLocation } from 'react-router-dom';
|
||||
import { ThemeProvider, CssBaseline, Box, AppBar, Toolbar, Typography, IconButton,
|
||||
Drawer, List, ListItemButton, ListItemIcon, ListItemText, Divider, Chip } from '@mui/material';
|
||||
Drawer, List, ListItemButton, ListItemIcon, ListItemText, Divider, Chip,
|
||||
CircularProgress } from '@mui/material';
|
||||
import MenuIcon from '@mui/icons-material/Menu';
|
||||
import DashboardIcon from '@mui/icons-material/Dashboard';
|
||||
import SearchIcon from '@mui/icons-material/Search';
|
||||
@@ -18,6 +19,13 @@ import ScienceIcon from '@mui/icons-material/Science';
|
||||
import CompareArrowsIcon from '@mui/icons-material/CompareArrows';
|
||||
import GppMaybeIcon from '@mui/icons-material/GppMaybe';
|
||||
import HubIcon from '@mui/icons-material/Hub';
|
||||
<<<<<<< HEAD
|
||||
import AssessmentIcon from '@mui/icons-material/Assessment';
|
||||
import TimelineIcon from '@mui/icons-material/Timeline';
|
||||
import PlaylistAddCheckIcon from '@mui/icons-material/PlaylistAddCheck';
|
||||
import BookmarksIcon from '@mui/icons-material/Bookmarks';
|
||||
import ShieldIcon from '@mui/icons-material/Shield';
|
||||
=======
|
||||
import DevicesIcon from '@mui/icons-material/Devices';
|
||||
import AccountTreeIcon from '@mui/icons-material/AccountTree';
|
||||
import TimelineIcon from '@mui/icons-material/Timeline';
|
||||
@@ -29,9 +37,11 @@ import WorkIcon from '@mui/icons-material/Work';
|
||||
import NotificationsActiveIcon from '@mui/icons-material/NotificationsActive';
|
||||
import MenuBookIcon from '@mui/icons-material/MenuBook';
|
||||
import PlaylistPlayIcon from '@mui/icons-material/PlaylistPlay';
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
import { SnackbarProvider } from 'notistack';
|
||||
import theme from './theme';
|
||||
|
||||
/* -- Eager imports (lightweight, always needed) -- */
|
||||
import Dashboard from './components/Dashboard';
|
||||
import HuntManager from './components/HuntManager';
|
||||
import DatasetViewer from './components/DatasetViewer';
|
||||
@@ -42,6 +52,16 @@ import AnnotationPanel from './components/AnnotationPanel';
|
||||
import HypothesisTracker from './components/HypothesisTracker';
|
||||
import CorrelationView from './components/CorrelationView';
|
||||
import AUPScanner from './components/AUPScanner';
|
||||
<<<<<<< HEAD
|
||||
|
||||
/* -- Lazy imports (heavy: charts, network graph, new feature pages) -- */
|
||||
const NetworkMap = React.lazy(() => import('./components/NetworkMap'));
|
||||
const AnalysisDashboard = React.lazy(() => import('./components/AnalysisDashboard'));
|
||||
const MitreMatrix = React.lazy(() => import('./components/MitreMatrix'));
|
||||
const TimelineView = React.lazy(() => import('./components/TimelineView'));
|
||||
const PlaybookManager = React.lazy(() => import('./components/PlaybookManager'));
|
||||
const SavedSearches = React.lazy(() => import('./components/SavedSearches'));
|
||||
=======
|
||||
import NetworkMap from './components/NetworkMap';
|
||||
import NetworkPicture from './components/NetworkPicture';
|
||||
import ProcessTree from './components/ProcessTree';
|
||||
@@ -54,12 +74,31 @@ import CaseManager from './components/CaseManager';
|
||||
import AlertPanel from './components/AlertPanel';
|
||||
import InvestigationNotebook from './components/InvestigationNotebook';
|
||||
import PlaybookManager from './components/PlaybookManager';
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
const DRAWER_WIDTH = 240;
|
||||
|
||||
interface NavItem { label: string; path: string; icon: React.ReactNode }
|
||||
|
||||
const NAV: NavItem[] = [
|
||||
<<<<<<< HEAD
|
||||
{ label: 'Dashboard', path: '/', icon: <DashboardIcon /> },
|
||||
{ label: 'Hunts', path: '/hunts', icon: <SearchIcon /> },
|
||||
{ label: 'Datasets', path: '/datasets', icon: <StorageIcon /> },
|
||||
{ label: 'Upload', path: '/upload', icon: <UploadFileIcon /> },
|
||||
{ label: 'AI Analysis', path: '/analysis', icon: <AssessmentIcon /> },
|
||||
{ label: 'Agent', path: '/agent', icon: <SmartToyIcon /> },
|
||||
{ label: 'Enrichment', path: '/enrichment', icon: <SecurityIcon /> },
|
||||
{ label: 'Annotations', path: '/annotations', icon: <BookmarkIcon /> },
|
||||
{ label: 'Hypotheses', path: '/hypotheses', icon: <ScienceIcon /> },
|
||||
{ label: 'Correlation', path: '/correlation', icon: <CompareArrowsIcon /> },
|
||||
{ label: 'Network Map', path: '/network', icon: <HubIcon /> },
|
||||
{ label: 'AUP Scanner', path: '/aup', icon: <GppMaybeIcon /> },
|
||||
{ label: 'MITRE Matrix', path: '/mitre', icon: <ShieldIcon /> },
|
||||
{ label: 'Timeline', path: '/timeline', icon: <TimelineIcon /> },
|
||||
{ label: 'Playbooks', path: '/playbooks', icon: <PlaylistAddCheckIcon /> },
|
||||
{ label: 'Saved Searches', path: '/saved-searches', icon: <BookmarksIcon /> },
|
||||
=======
|
||||
{ label: 'Dashboard', path: '/', icon: <DashboardIcon /> },
|
||||
{ label: 'Hunts', path: '/hunts', icon: <SearchIcon /> },
|
||||
{ label: 'Datasets', path: '/datasets', icon: <StorageIcon /> },
|
||||
@@ -82,8 +121,17 @@ const NAV: NavItem[] = [
|
||||
{ label: 'Notebooks', path: '/notebooks', icon: <MenuBookIcon /> },
|
||||
{ label: 'Playbooks', path: '/playbooks', icon: <PlaylistPlayIcon /> },
|
||||
{ label: 'AUP Scanner', path: '/aup', icon: <GppMaybeIcon /> },
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
];
|
||||
|
||||
function LazyFallback() {
|
||||
return (
|
||||
<Box sx={{ display: 'flex', justifyContent: 'center', alignItems: 'center', minHeight: 200 }}>
|
||||
<CircularProgress />
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
function Shell() {
|
||||
const [open, setOpen] = useState(true);
|
||||
const navigate = useNavigate();
|
||||
@@ -102,7 +150,7 @@ function Shell() {
|
||||
<Typography variant="h6" noWrap sx={{ flexGrow: 1 }}>
|
||||
ThreatHunt
|
||||
</Typography>
|
||||
<Chip label="v0.3.0" size="small" color="primary" variant="outlined" />
|
||||
<Chip label="v0.4.0" size="small" color="primary" variant="outlined" />
|
||||
</Toolbar>
|
||||
</AppBar>
|
||||
|
||||
@@ -137,6 +185,28 @@ function Shell() {
|
||||
ml: open ? 0 : `-${DRAWER_WIDTH}px`,
|
||||
transition: 'margin 225ms cubic-bezier(0,0,0.2,1)',
|
||||
}}>
|
||||
<<<<<<< HEAD
|
||||
<Suspense fallback={<LazyFallback />}>
|
||||
<Routes>
|
||||
<Route path="/" element={<Dashboard />} />
|
||||
<Route path="/hunts" element={<HuntManager />} />
|
||||
<Route path="/datasets" element={<DatasetViewer />} />
|
||||
<Route path="/upload" element={<FileUpload />} />
|
||||
<Route path="/analysis" element={<AnalysisDashboard />} />
|
||||
<Route path="/agent" element={<AgentPanel />} />
|
||||
<Route path="/enrichment" element={<EnrichmentPanel />} />
|
||||
<Route path="/annotations" element={<AnnotationPanel />} />
|
||||
<Route path="/hypotheses" element={<HypothesisTracker />} />
|
||||
<Route path="/correlation" element={<CorrelationView />} />
|
||||
<Route path="/network" element={<NetworkMap />} />
|
||||
<Route path="/aup" element={<AUPScanner />} />
|
||||
<Route path="/mitre" element={<MitreMatrix />} />
|
||||
<Route path="/timeline" element={<TimelineView />} />
|
||||
<Route path="/playbooks" element={<PlaybookManager />} />
|
||||
<Route path="/saved-searches" element={<SavedSearches />} />
|
||||
</Routes>
|
||||
</Suspense>
|
||||
=======
|
||||
<Routes>
|
||||
<Route path="/" element={<Dashboard />} />
|
||||
<Route path="/hunts" element={<HuntManager />} />
|
||||
@@ -161,6 +231,7 @@ function Shell() {
|
||||
<Route path="/playbooks" element={<PlaybookManager />} />
|
||||
<Route path="/aup" element={<AUPScanner />} />
|
||||
</Routes>
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
|
||||
@@ -71,8 +71,24 @@ export interface Hunt {
|
||||
dataset_count: number; hypothesis_count: number;
|
||||
}
|
||||
|
||||
<<<<<<< HEAD
|
||||
export interface HuntProgress {
|
||||
hunt_id: string;
|
||||
status: 'idle' | 'processing' | 'ready';
|
||||
progress_percent: number;
|
||||
dataset_total: number;
|
||||
dataset_completed: number;
|
||||
dataset_processing: number;
|
||||
dataset_errors: number;
|
||||
active_jobs: number;
|
||||
queued_jobs: number;
|
||||
network_status: 'none' | 'building' | 'ready';
|
||||
stages: Record<string, any>;
|
||||
}
|
||||
=======
|
||||
/** Alias kept for backward-compat with components that import HuntOut */
|
||||
export type HuntOut = Hunt;
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
export const hunts = {
|
||||
list: (skip = 0, limit = 50) =>
|
||||
@@ -83,6 +99,7 @@ export const hunts = {
|
||||
update: (id: string, data: Partial<{ name: string; description: string; status: string }>) =>
|
||||
api<Hunt>(`/api/hunts/${id}`, { method: 'PUT', body: JSON.stringify(data) }),
|
||||
delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),
|
||||
progress: (id: string) => api<HuntProgress>(`/api/hunts/${id}/progress`),
|
||||
};
|
||||
|
||||
// ── Datasets ─────────────────────────────────────────────────────────
|
||||
@@ -167,6 +184,8 @@ export interface AssistRequest {
|
||||
active_hypotheses?: string[]; annotations_summary?: string;
|
||||
enrichment_summary?: string; mode?: 'quick' | 'deep' | 'debate';
|
||||
model_override?: string; conversation_id?: string; hunt_id?: string;
|
||||
execution_preference?: 'auto' | 'force' | 'off';
|
||||
learning_mode?: boolean;
|
||||
}
|
||||
|
||||
export interface AssistResponse {
|
||||
@@ -175,6 +194,15 @@ export interface AssistResponse {
|
||||
sans_references: string[]; model_used: string; node_used: string;
|
||||
latency_ms: number; perspectives: Record<string, any>[] | null;
|
||||
conversation_id: string | null;
|
||||
execution?: {
|
||||
scope: string;
|
||||
datasets_scanned: string[];
|
||||
policy_hits: number;
|
||||
result_count: number;
|
||||
top_domains: string[];
|
||||
top_user_hosts: string[];
|
||||
elapsed_ms: number;
|
||||
} | null;
|
||||
}
|
||||
|
||||
export interface NodeInfo { url: string; available: boolean }
|
||||
@@ -570,10 +598,12 @@ export interface ScanHit {
|
||||
theme_name: string; theme_color: string; keyword: string;
|
||||
source_type: string; source_id: string | number; field: string;
|
||||
matched_value: string; row_index: number | null; dataset_name: string | null;
|
||||
hostname?: string | null; username?: string | null;
|
||||
}
|
||||
export interface ScanResponse {
|
||||
total_hits: number; hits: ScanHit[]; themes_scanned: number;
|
||||
keywords_scanned: number; rows_scanned: number;
|
||||
cache_used?: boolean; cache_status?: string; cached_at?: string | null;
|
||||
}
|
||||
|
||||
export const keywords = {
|
||||
@@ -607,6 +637,7 @@ export const keywords = {
|
||||
scan: (opts: {
|
||||
dataset_ids?: string[]; theme_ids?: string[];
|
||||
scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;
|
||||
prefer_cache?: boolean; force_rescan?: boolean;
|
||||
}) =>
|
||||
api<ScanResponse>('/api/keywords/scan', {
|
||||
method: 'POST', body: JSON.stringify(opts),
|
||||
@@ -865,6 +896,224 @@ export const notebooks = {
|
||||
api(`/api/notebooks/${id}`, { method: 'DELETE' }),
|
||||
};
|
||||
|
||||
<<<<<<< HEAD
|
||||
export interface HostInventory {
|
||||
hosts: InventoryHost[];
|
||||
connections: InventoryConnection[];
|
||||
stats: InventoryStats;
|
||||
}
|
||||
|
||||
export interface InventoryStatus {
|
||||
hunt_id: string;
|
||||
status: 'ready' | 'building' | 'none';
|
||||
}
|
||||
|
||||
export interface NetworkSummaryHost {
|
||||
id: string;
|
||||
hostname: string;
|
||||
row_count: number;
|
||||
ip_count: number;
|
||||
user_count: number;
|
||||
}
|
||||
|
||||
export interface NetworkSummary {
|
||||
stats: InventoryStats;
|
||||
top_hosts: NetworkSummaryHost[];
|
||||
top_edges: InventoryConnection[];
|
||||
status?: 'building' | 'deferred';
|
||||
message?: string;
|
||||
}
|
||||
|
||||
export const network = {
|
||||
hostInventory: (huntId: string, force = false) =>
|
||||
api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`),
|
||||
summary: (huntId: string, topN = 20) =>
|
||||
api<NetworkSummary | { status: 'building' | 'deferred'; message?: string }>(`/api/network/summary?hunt_id=${encodeURIComponent(huntId)}&top_n=${topN}`),
|
||||
subgraph: (huntId: string, maxHosts = 250, maxEdges = 1500, nodeId?: string) => {
|
||||
let qs = `/api/network/subgraph?hunt_id=${encodeURIComponent(huntId)}&max_hosts=${maxHosts}&max_edges=${maxEdges}`;
|
||||
if (nodeId) qs += `&node_id=${encodeURIComponent(nodeId)}`;
|
||||
return api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(qs);
|
||||
},
|
||||
inventoryStatus: (huntId: string) =>
|
||||
api<InventoryStatus>(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`),
|
||||
rebuildInventory: (huntId: string) =>
|
||||
api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }),
|
||||
};
|
||||
|
||||
// -- MITRE ATT&CK Coverage (Feature 1) --
|
||||
|
||||
export interface MitreTechnique {
|
||||
id: string;
|
||||
tactic: string;
|
||||
sources: { type: string; risk_score?: number; hostname?: string; title?: string }[];
|
||||
count: number;
|
||||
}
|
||||
|
||||
export interface MitreCoverage {
|
||||
tactics: string[];
|
||||
technique_count: number;
|
||||
detection_count: number;
|
||||
tactic_coverage: Record<string, { techniques: MitreTechnique[]; count: number }>;
|
||||
all_techniques: MitreTechnique[];
|
||||
}
|
||||
|
||||
export const mitre = {
|
||||
coverage: (huntId?: string) => {
|
||||
const q = huntId ? `?hunt_id=${encodeURIComponent(huntId)}` : '';
|
||||
return api<MitreCoverage>(`/api/mitre/coverage${q}`);
|
||||
},
|
||||
};
|
||||
|
||||
// -- Timeline (Feature 2) --
|
||||
|
||||
export interface TimelineEvent {
|
||||
timestamp: string;
|
||||
dataset_id: string;
|
||||
dataset_name: string;
|
||||
artifact_type: string;
|
||||
row_index: number;
|
||||
hostname: string;
|
||||
process: string;
|
||||
summary: string;
|
||||
data: Record<string, string>;
|
||||
}
|
||||
|
||||
export interface TimelineData {
|
||||
hunt_id: string;
|
||||
hunt_name: string;
|
||||
event_count: number;
|
||||
datasets: { id: string; name: string; artifact_type: string; row_count: number }[];
|
||||
events: TimelineEvent[];
|
||||
}
|
||||
|
||||
export const timeline = {
|
||||
getHuntTimeline: (huntId: string, limit = 2000) =>
|
||||
api<TimelineData>(`/api/timeline/hunt/${huntId}?limit=${limit}`),
|
||||
};
|
||||
|
||||
// -- Playbooks (Feature 3) --
|
||||
|
||||
export interface PlaybookStep {
|
||||
id: number;
|
||||
order_index: number;
|
||||
title: string;
|
||||
description: string | null;
|
||||
step_type: string;
|
||||
target_route: string | null;
|
||||
is_completed: boolean;
|
||||
completed_at: string | null;
|
||||
notes: string | null;
|
||||
}
|
||||
|
||||
export interface PlaybookSummary {
|
||||
id: string;
|
||||
name: string;
|
||||
description: string | null;
|
||||
is_template: boolean;
|
||||
hunt_id: string | null;
|
||||
status: string;
|
||||
total_steps: number;
|
||||
completed_steps: number;
|
||||
created_at: string | null;
|
||||
}
|
||||
|
||||
export interface PlaybookDetail {
|
||||
id: string;
|
||||
name: string;
|
||||
description: string | null;
|
||||
is_template: boolean;
|
||||
hunt_id: string | null;
|
||||
status: string;
|
||||
created_at: string | null;
|
||||
steps: PlaybookStep[];
|
||||
}
|
||||
|
||||
export interface PlaybookTemplate {
|
||||
name: string;
|
||||
description: string;
|
||||
steps: { title: string; description: string; step_type: string; target_route: string }[];
|
||||
}
|
||||
|
||||
export const playbooks = {
|
||||
list: (huntId?: string) => {
|
||||
const q = huntId ? `?hunt_id=${encodeURIComponent(huntId)}` : '';
|
||||
return api<{ playbooks: PlaybookSummary[] }>(`/api/playbooks${q}`);
|
||||
},
|
||||
templates: () => api<{ templates: PlaybookTemplate[] }>('/api/playbooks/templates'),
|
||||
get: (id: string) => api<PlaybookDetail>(`/api/playbooks/${id}`),
|
||||
create: (data: { name: string; description?: string; hunt_id?: string; is_template?: boolean; steps?: { title: string; description?: string; step_type?: string; target_route?: string }[] }) =>
|
||||
api<PlaybookDetail>('/api/playbooks', { method: 'POST', body: JSON.stringify(data) }),
|
||||
update: (id: string, data: { name?: string; description?: string; status?: string }) =>
|
||||
api(`/api/playbooks/${id}`, { method: 'PUT', body: JSON.stringify(data) }),
|
||||
delete: (id: string) => api(`/api/playbooks/${id}`, { method: 'DELETE' }),
|
||||
updateStep: (stepId: number, data: { is_completed?: boolean; notes?: string }) =>
|
||||
api(`/api/playbooks/steps/${stepId}`, { method: 'PUT', body: JSON.stringify(data) }),
|
||||
};
|
||||
|
||||
// -- Saved Searches (Feature 5) --
|
||||
|
||||
export interface SavedSearchData {
|
||||
id: string;
|
||||
name: string;
|
||||
description: string | null;
|
||||
search_type: string;
|
||||
query_params: Record<string, any>;
|
||||
hunt_id?: string | null;
|
||||
threshold: number | null;
|
||||
last_run_at: string | null;
|
||||
last_result_count: number | null;
|
||||
created_at: string | null;
|
||||
}
|
||||
|
||||
export interface SearchRunResult {
|
||||
search_id: string;
|
||||
search_name: string;
|
||||
search_type: string;
|
||||
result_count: number;
|
||||
previous_count: number;
|
||||
delta: number;
|
||||
results: any[];
|
||||
}
|
||||
|
||||
export const savedSearches = {
|
||||
list: (searchType?: string) => {
|
||||
const q = searchType ? `?search_type=${encodeURIComponent(searchType)}` : '';
|
||||
return api<{ searches: SavedSearchData[] }>(`/api/searches${q}`);
|
||||
},
|
||||
get: (id: string) => api<SavedSearchData>(`/api/searches/${id}`),
|
||||
create: (data: { name: string; description?: string; search_type: string; query_params: Record<string, any>; threshold?: number; hunt_id?: string }) =>
|
||||
api<SavedSearchData>('/api/searches', { method: 'POST', body: JSON.stringify(data) }),
|
||||
update: (id: string, data: { name?: string; description?: string; search_type?: string; query_params?: Record<string, any>; threshold?: number; hunt_id?: string }) =>
|
||||
api(`/api/searches/${id}`, { method: 'PUT', body: JSON.stringify(data) }),
|
||||
delete: (id: string) => api(`/api/searches/${id}`, { method: 'DELETE' }),
|
||||
run: (id: string) => api<SearchRunResult>(`/api/searches/${id}/run`, { method: 'POST' }),
|
||||
};
|
||||
|
||||
// -- STIX Export --
|
||||
|
||||
export const stixExport = {
|
||||
/** Download a STIX 2.1 bundle JSON for a given hunt */
|
||||
download: async (huntId: string): Promise<void> => {
|
||||
const headers: Record<string, string> = {};
|
||||
if (authToken) headers['Authorization'] = `Bearer ${authToken}`;
|
||||
const res = await fetch(`${BASE}/api/export/stix/${huntId}`, { headers });
|
||||
if (!res.ok) {
|
||||
const body = await res.json().catch(() => ({}));
|
||||
throw new Error(body.detail || `HTTP ${res.status}`);
|
||||
}
|
||||
const blob = await res.blob();
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = `hunt-${huntId}-stix-bundle.json`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
},
|
||||
};
|
||||
|
||||
=======
|
||||
export const playbooks = {
|
||||
templates: () =>
|
||||
api<{ templates: PlaybookTemplate[] }>('/api/notebooks/playbooks/templates'),
|
||||
@@ -887,3 +1136,4 @@ export const playbooks = {
|
||||
abortRun: (runId: string) =>
|
||||
api<PlaybookRunData>(`/api/notebooks/playbooks/runs/${runId}/abort`, { method: 'POST' }),
|
||||
};
|
||||
>>>>>>> 7c454036c7ef6a3d6517f98cbee643fd0238e0b2
|
||||
|
||||
@@ -188,11 +188,13 @@ const RESULT_COLUMNS: GridColDef[] = [
|
||||
),
|
||||
},
|
||||
{ field: 'keyword', headerName: 'Keyword', width: 140 },
|
||||
{ field: 'source_type', headerName: 'Source', width: 120 },
|
||||
{ field: 'dataset_name', headerName: 'Dataset', width: 150 },
|
||||
{ field: 'dataset_name', headerName: 'Dataset', width: 170 },
|
||||
{ field: 'hostname', headerName: 'Hostname', width: 170, valueGetter: (v, row) => row.hostname || '' },
|
||||
{ field: 'username', headerName: 'User', width: 160, valueGetter: (v, row) => row.username || '' },
|
||||
{ field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 220 },
|
||||
{ field: 'field', headerName: 'Field', width: 130 },
|
||||
{ field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 200 },
|
||||
{ field: 'row_index', headerName: 'Row #', width: 80, type: 'number' },
|
||||
{ field: 'source_type', headerName: 'Source', width: 120 },
|
||||
{ field: 'row_index', headerName: 'Row #', width: 90, type: 'number' },
|
||||
];
|
||||
|
||||
export default function AUPScanner() {
|
||||
@@ -210,9 +212,9 @@ export default function AUPScanner() {
|
||||
// Scan options
|
||||
const [selectedDs, setSelectedDs] = useState<Set<string>>(new Set());
|
||||
const [selectedThemes, setSelectedThemes] = useState<Set<string>>(new Set());
|
||||
const [scanHunts, setScanHunts] = useState(true);
|
||||
const [scanAnnotations, setScanAnnotations] = useState(true);
|
||||
const [scanMessages, setScanMessages] = useState(true);
|
||||
const [scanHunts, setScanHunts] = useState(false);
|
||||
const [scanAnnotations, setScanAnnotations] = useState(false);
|
||||
const [scanMessages, setScanMessages] = useState(false);
|
||||
|
||||
// Load themes + hunts
|
||||
const loadData = useCallback(async () => {
|
||||
@@ -224,9 +226,13 @@ export default function AUPScanner() {
|
||||
]);
|
||||
setThemes(tRes.themes);
|
||||
setHuntList(hRes.hunts);
|
||||
if (!selectedHuntId && hRes.hunts.length > 0) {
|
||||
const best = hRes.hunts.find(h => h.dataset_count > 0) || hRes.hunts[0];
|
||||
setSelectedHuntId(best.id);
|
||||
}
|
||||
} catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
|
||||
setLoading(false);
|
||||
}, [enqueueSnackbar]);
|
||||
}, [enqueueSnackbar, selectedHuntId]);
|
||||
|
||||
useEffect(() => { loadData(); }, [loadData]);
|
||||
|
||||
@@ -237,7 +243,7 @@ export default function AUPScanner() {
|
||||
datasets.list(0, 500, selectedHuntId).then(res => {
|
||||
if (cancelled) return;
|
||||
setDsList(res.datasets);
|
||||
setSelectedDs(new Set(res.datasets.map(d => d.id)));
|
||||
setSelectedDs(new Set(res.datasets.slice(0, 3).map(d => d.id)));
|
||||
}).catch(() => {});
|
||||
return () => { cancelled = true; };
|
||||
}, [selectedHuntId]);
|
||||
@@ -251,6 +257,15 @@ export default function AUPScanner() {
|
||||
|
||||
// Run scan
|
||||
const runScan = useCallback(async () => {
|
||||
if (!selectedHuntId) {
|
||||
enqueueSnackbar('Please select a hunt before running AUP scan', { variant: 'warning' });
|
||||
return;
|
||||
}
|
||||
if (selectedDs.size === 0) {
|
||||
enqueueSnackbar('No datasets selected for this hunt', { variant: 'warning' });
|
||||
return;
|
||||
}
|
||||
|
||||
setScanning(true);
|
||||
setScanResult(null);
|
||||
try {
|
||||
@@ -260,6 +275,7 @@ export default function AUPScanner() {
|
||||
scan_hunts: scanHunts,
|
||||
scan_annotations: scanAnnotations,
|
||||
scan_messages: scanMessages,
|
||||
prefer_cache: true,
|
||||
});
|
||||
setScanResult(res);
|
||||
enqueueSnackbar(`Scan complete — ${res.total_hits} hits found`, {
|
||||
@@ -267,7 +283,7 @@ export default function AUPScanner() {
|
||||
});
|
||||
} catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
|
||||
setScanning(false);
|
||||
}, [selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);
|
||||
}, [selectedHuntId, selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);
|
||||
|
||||
if (loading) return <Box sx={{ p: 4, textAlign: 'center' }}><CircularProgress /></Box>;
|
||||
|
||||
@@ -316,9 +332,38 @@ export default function AUPScanner() {
|
||||
)}
|
||||
{!selectedHuntId && (
|
||||
<Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
|
||||
All datasets will be scanned if no hunt is selected
|
||||
Select a hunt to enable scoped scanning
|
||||
</Typography>
|
||||
)}
|
||||
|
||||
<FormControl size="small" fullWidth sx={{ mt: 1.2 }} disabled={!selectedHuntId || dsList.length === 0}>
|
||||
<InputLabel id="aup-dataset-label">Datasets</InputLabel>
|
||||
<Select
|
||||
labelId="aup-dataset-label"
|
||||
multiple
|
||||
value={Array.from(selectedDs)}
|
||||
label="Datasets"
|
||||
renderValue={(selected) => `${(selected as string[]).length} selected`}
|
||||
onChange={(e) => setSelectedDs(new Set(e.target.value as string[]))}
|
||||
>
|
||||
{dsList.map(d => (
|
||||
<MenuItem key={d.id} value={d.id}>
|
||||
<Checkbox size="small" checked={selectedDs.has(d.id)} />
|
||||
<Typography variant="body2" sx={{ ml: 0.5 }}>
|
||||
{d.name} ({d.row_count.toLocaleString()} rows)
|
||||
</Typography>
|
||||
</MenuItem>
|
||||
))}
|
||||
</Select>
|
||||
</FormControl>
|
||||
|
||||
{selectedHuntId && dsList.length > 0 && (
|
||||
<Stack direction="row" spacing={1} sx={{ mt: 1 }}>
|
||||
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.slice(0, 3).map(d => d.id)))}>Top 3</Button>
|
||||
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.map(d => d.id)))}>All</Button>
|
||||
<Button size="small" onClick={() => setSelectedDs(new Set())}>Clear</Button>
|
||||
</Stack>
|
||||
)}
|
||||
</Box>
|
||||
|
||||
{/* Theme selector */}
|
||||
@@ -372,7 +417,7 @@ export default function AUPScanner() {
|
||||
<Button
|
||||
variant="contained" color="warning" size="large"
|
||||
startIcon={scanning ? <CircularProgress size={20} color="inherit" /> : <PlayArrowIcon />}
|
||||
onClick={runScan} disabled={scanning}
|
||||
onClick={runScan} disabled={scanning || !selectedHuntId || selectedDs.size === 0}
|
||||
>
|
||||
{scanning ? 'Scanning…' : 'Run Scan'}
|
||||
</Button>
|
||||
@@ -392,6 +437,15 @@ export default function AUPScanner() {
|
||||
<strong>{scanResult.total_hits}</strong> hits across{' '}
|
||||
<strong>{scanResult.rows_scanned}</strong> rows |{' '}
|
||||
{scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned
|
||||
{scanResult.cache_status && (
|
||||
<Chip
|
||||
size="small"
|
||||
label={scanResult.cache_status === 'hit' ? 'Cached' : 'Live'}
|
||||
sx={{ ml: 1, height: 20 }}
|
||||
color={scanResult.cache_status === 'hit' ? 'success' : 'default'}
|
||||
variant="outlined"
|
||||
/>
|
||||
)}
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user