9 Commits

Author SHA1 Message Date
483176c06b chore: checkpoint all local changes 2026-02-23 14:36:33 -05:00
13bd9ec9e0 feat: Add Playbook Manager, Saved Searches, and Timeline View components
- Implemented PlaybookManager for creating and managing investigation playbooks with templates.
- Added SavedSearches component for managing bookmarked queries and recurring scans.
- Introduced TimelineView for visualizing forensic event timelines with zoomable charts.
- Enhanced backend processing with auto-queued jobs for dataset uploads and improved database concurrency.
- Updated frontend components for better user experience and performance optimizations.
- Documented changes in update log for future reference.
2026-02-23 14:35:49 -05:00
5a2ad8ec1c feat: Add Playbook Manager, Saved Searches, and Timeline View components
- Implemented PlaybookManager for creating and managing investigation playbooks with templates.
- Added SavedSearches component for managing bookmarked queries and recurring scans.
- Introduced TimelineView for visualizing forensic event timelines with zoomable charts.
- Enhanced backend processing with auto-queued jobs for dataset uploads and improved database concurrency.
- Updated frontend components for better user experience and performance optimizations.
- Documented changes in update log for future reference.
2026-02-23 14:23:07 -05:00
37a9584d0c docs: update changelog and add robust dev-up startup script 2026-02-23 14:22:17 -05:00
7c454036c7 Merge origin/main (v0.3.1) into local main (v0.4.0) — keep local versions for all conflicts 2026-02-20 14:35:08 -05:00
365cf87c90 version 0.4.0 2026-02-20 14:32:42 -05:00
bb562a91ca version 0.3.1 2026-02-20 07:16:17 -05:00
04a9946891 feat: host-centric network map, analysis dashboard, deduped inventory
- Rewrote NetworkMap to use deduplicated host inventory (163 hosts from 394K rows)
- New host_inventory.py service: scans datasets, groups by FQDN/ClientId, extracts IPs/users/OS
- New /api/network/host-inventory endpoint
- Added AnalysisDashboard with 6 tabs (IOC, anomaly, host profile, query, triage, reports)
- Added 16 analysis API endpoints with job queue and load balancer
- Added 4 AI/analysis ORM models (ProcessingJob, AnalysisResult, HostProfile, IOCEntry)
- Filters system accounts (DWM-*, UMFD-*, LOCAL/NETWORK SERVICE)
- Infers OS from hostname patterns (W10-* -> Windows 10)
- Canvas 2D force-directed graph with host/external-IP node types
- Click popover shows hostname, FQDN, IPs, OS, users, datasets, connections
2026-02-20 07:16:17 -05:00
ab8038867a version 0.3.0 2026-02-19 15:42:20 -05:00
183 changed files with 48848 additions and 1460 deletions

26
.gitignore vendored
View File

@@ -1,4 +1,4 @@
# ── Python ──────────────────────────────────── # ── Python ────────────────────────────────────
__pycache__/ __pycache__/
*.py[cod] *.py[cod]
*$py.class *$py.class
@@ -8,34 +8,34 @@ build/
*.egg *.egg
.eggs/ .eggs/
# ── Virtual environments ───────────────────── # ── Virtual environments ─────────────────────
venv/ venv/
.venv/ .venv/
env/ env/
# ── IDE / Editor ───────────────────────────── # ── IDE / Editor ─────────────────────────────
.vscode/ .vscode/
.idea/ .idea/
*.swp *.swp
*.swo *.swo
*~ *~
# ── OS ──────────────────────────────────────── # ── OS ────────────────────────────────────────
.DS_Store .DS_Store
Thumbs.db Thumbs.db
# ── Environment / Secrets ──────────────────── # ── Environment / Secrets ────────────────────
.env .env
*.env.local *.env.local
# ── Database ───────────────────────────────── # ── Database ─────────────────────────────────
*.db *.db
*.sqlite3 *.sqlite3
# ── Uploads ────────────────────────────────── # ── Uploads ──────────────────────────────────
uploads/ uploads/
# ── Node / Frontend ────────────────────────── # ── Node / Frontend ──────────────────────────
node_modules/ node_modules/
frontend/build/ frontend/build/
frontend/.env.local frontend/.env.local
@@ -43,14 +43,18 @@ npm-debug.log*
yarn-debug.log* yarn-debug.log*
yarn-error.log* yarn-error.log*
# ── Docker ─────────────────────────────────── # ── Docker ───────────────────────────────────
docker-compose.override.yml docker-compose.override.yml
# ── Test / Coverage ────────────────────────── # ── Test / Coverage ──────────────────────────
.coverage .coverage
htmlcov/ htmlcov/
.pytest_cache/ .pytest_cache/
.mypy_cache/ .mypy_cache/
# ── Alembic ────────────────────────────────── # ── Alembic ──────────────────────────────────
alembic/versions/*.pyc alembic/versions/*.pyc
*.db-wal
*.db-shm

View File

@@ -0,0 +1 @@
[ 656ms] [WARNING] No routes matched location "/network-map" @ http://localhost:3000/static/js/main.c0a7ab6d.js:1

View File

@@ -0,0 +1 @@
[ 4269ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.6d916bcf.js:1

View File

@@ -0,0 +1 @@
[ 496ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.28ae077d.js:1

View File

@@ -0,0 +1,76 @@
[ 402ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
[ 60389ms] [ERROR] Failed to load resource: the server responded with a status of 500 (Internal Server Error) @ http://localhost:3000/api/analysis/process-tree?hunt_id=4bb956a4225e45459a464da1146d3cf5:0
[ 114742ms] [ERROR] Failed to load resource: the server responded with a status of 500 (Internal Server Error) @ http://localhost:3000/api/analysis/process-tree?hunt_id=4bb956a4225e45459a464da1146d3cf5:0
[ 116603ms] [ERROR] Failed to load resource: the server responded with a status of 500 (Internal Server Error) @ http://localhost:3000/api/analysis/process-tree?hunt_id=4bb956a4225e45459a464da1146d3cf5:0
[ 362021ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
[ 379006ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
[ 379019ms] [ERROR] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227378)
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228635)
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:229095)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228898)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785) @ http://localhost:3000/static/js/main.cb47c3a0.js:1
[ 379021ms] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227378)
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228635)
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:229095)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228898)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
[ 382647ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
[ 386088ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
[ 386343ms] [ERROR] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227378)
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228635)
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:229095)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228898)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785) @ http://localhost:3000/static/js/main.cb47c3a0.js:1
[ 386345ms] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227378)
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228635)
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:229095)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228898)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
[ 397704ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
[ 519009ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
[ 519273ms] [ERROR] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227378)
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228635)
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:229095)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228898)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785) @ http://localhost:3000/static/js/main.cb47c3a0.js:1
[ 519274ms] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227378)
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228635)
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:229095)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228898)
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)

View File

@@ -0,0 +1 @@
[ 1803ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.b2c21c5a.js:1

View File

@@ -0,0 +1,48 @@
[ 2196ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.0e63bc98.js:1
[ 46100ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.0e63bc98.js:1
[ 46117ms] [ERROR] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227378)
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228635)
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:229095)
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785)
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:228898)
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785) @ http://localhost:3000/static/js/main.0e63bc98.js:1
[ 46118ms] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227378)
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228635)
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:229095)
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785)
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:228898)
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785)
[ 52506ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.0e63bc98.js:1
[ 54912ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.0e63bc98.js:1
[ 54928ms] [ERROR] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227378)
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228635)
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:229095)
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785)
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:228898)
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785) @ http://localhost:3000/static/js/main.0e63bc98.js:1
[ 54929ms] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227378)
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228635)
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:229095)
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785)
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:228898)
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785)

View File

@@ -0,0 +1,7 @@
[ 2548ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.c311038e.js:1
[ 32912ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.c311038e.js:1
[ 55583ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.c311038e.js:1
[ 58208ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.c311038e.js:1
[ 1168933ms] [ERROR] Failed to load resource: the server responded with a status of 504 (Gateway Time-out) @ http://localhost:3000/api/analysis/llm-analyze:0
[ 1477343ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.c311038e.js:1
[ 1482908ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.c311038e.js:1

View File

@@ -0,0 +1,7 @@
[ 9612ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/:0
[ 17464ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/enterprise:0
[ 20742ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/enterprise:0
[ 53258ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/pricing:0
[ 59240ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/features/copilot#pricing:0
[ 67668ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/features/spark?utm_source=web-copilot-ce-cta&utm_campaign=spark-launch-sep-2025:0
[ 72166ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/features/spark?utm_source=web-copilot-ce-cta&utm_campaign=spark-launch-sep-2025:0

File diff suppressed because it is too large Load Diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 558 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 607 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 341 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 193 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 184 KiB

View File

@@ -17,7 +17,7 @@ COPY frontend/tsconfig.json ./
# Build application # Build application
RUN npm run build RUN npm run build
# Production stage nginx reverse-proxy + static files # Production stage — nginx reverse-proxy + static files
FROM nginx:alpine FROM nginx:alpine
# Copy built React app # Copy built React app

View File

@@ -0,0 +1,148 @@
"""One-shot codemod for frontend/src/components/NetworkMap.tsx.

Adds a user-selectable label-visibility mode ('all' | 'highlight' | 'none')
to the canvas network map: a LabelMode type alias, new trailing parameters on
drawLabels/drawGraph, React state, updated draw call sites and effect deps,
and a toolbar <Select> control next to the search field.

All patches are plain-text anchor replacements, so every anchor literal must
match the target file byte-for-byte.
NOTE(review): leading whitespace inside the anchor strings may have been
lost in transit — verify each anchor against the current NetworkMap.tsx
before running.
"""
from pathlib import Path

# File this codemod rewrites in place.
TARGET = Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')


def apply_patches(t: str) -> tuple[str, list[str]]:
    """Apply every anchor/replacement patch to *t*.

    Returns ``(patched_text, labels_of_applied_patches)``. Patches whose
    anchor is absent (file drifted, or patch already applied) are skipped
    here; main() decides whether that is an error.
    """
    applied: list[str] = []

    def sub(label: str, old: str, new: str) -> None:
        # Replace *old* with *new* when present, recording the patch label.
        nonlocal t
        if old in t:
            t = t.replace(old, new)
            applied.append(label)

    # 1) Add the LabelMode type alias next to the graph types (guarded so a
    #    rerun does not insert it twice).
    marker = (
        "interface GEdge { source: string; target: string; weight: number }\n"
        "interface Graph { nodes: GNode[]; edges: GEdge[] }\n"
    )
    if marker in t and "type LabelMode" not in t:
        t = t.replace(marker, marker + "\ntype LabelMode = 'all' | 'highlight' | 'none';\n")
        applied.append('label-mode-type')

    # 2) Extend the drawLabels signature with a labelMode parameter.
    old_sig = """function drawLabels(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
search: string, matchSet: Set<string>, vp: Viewport,
simplify: boolean,
) {
"""
    new_sig = """function drawLabels(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
search: string, matchSet: Set<string>, vp: Viewport,
simplify: boolean, labelMode: LabelMode,
) {
"""
    sub('drawLabels-signature', old_sig, new_sig)

    # 3) Early-exit guards inside drawLabels for the 'none' and 'highlight'
    #    modes; 'all' bypasses the simplify short-circuit.
    old_guard = """ const dimmed = search.length > 0;
if (simplify && !search && !hovered && !selected) {
return;
}
"""
    new_guard = """ if (labelMode === 'none') return;
const dimmed = search.length > 0;
if (labelMode === 'highlight' && !search && !hovered && !selected) return;
if (simplify && labelMode !== 'all' && !search && !hovered && !selected) {
return;
}
"""
    sub('drawLabels-guards', old_guard, new_guard)

    # In non-'all' modes only highlighted nodes keep their labels.
    old_show = """ const isHighlight = hovered === n.id || selected === n.id || matchSet.has(n.id);
const show = isHighlight || n.meta.type === 'host' || n.count >= 2;
if (!show) continue;
"""
    new_show = """ const isHighlight = hovered === n.id || selected === n.id || matchSet.has(n.id);
const show = labelMode === 'all'
? (isHighlight || n.meta.type === 'host' || n.count >= 2)
: isHighlight;
if (!show) continue;
"""
    sub('drawLabels-show-filter', old_show, new_show)

    # 4) Thread labelMode through drawGraph and its drawLabels call site.
    old_graph_sig = """function drawGraph(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null, search: string,
vp: Viewport, animTime: number, dpr: number,
) {
"""
    new_graph_sig = """function drawGraph(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null, search: string,
vp: Viewport, animTime: number, dpr: number, labelMode: LabelMode,
) {
"""
    sub('drawGraph-signature', old_graph_sig, new_graph_sig)
    sub(
        'drawLabels-call',
        "drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify);",
        "drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify, labelMode);",
    )

    # 5) React state for the selected label mode (default: 'highlight').
    state_anchor = " const [selectedNode, setSelectedNode] = useState<GNode | null>(null);\n const [search, setSearch] = useState('');\n"
    state_new = " const [selectedNode, setSelectedNode] = useState<GNode | null>(null);\n const [search, setSearch] = useState('');\n const [labelMode, setLabelMode] = useState<LabelMode>('highlight');\n"
    sub('label-mode-state', state_anchor, state_new)

    # 6) Pass labelMode at both drawGraph call sites (animation tick + redraw).
    sub(
        'tick-draw-call',
        "drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? null, searchRef.current, vpRef.current, ts, dpr);",
        "drawGraph(ctx, g, hoveredRef.current, selectedNodeRef.current?.id ?? null, searchRef.current, vpRef.current, ts, dpr, labelMode);",
    )
    sub(
        'redraw-draw-call',
        "if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? null, search, vpRef.current, animTimeRef.current, dpr);",
        "if (ctx) drawGraph(ctx, graph, hovered, selectedNode?.id ?? null, search, vpRef.current, animTimeRef.current, dpr, labelMode);",
    )

    # 7) Include labelMode in the redraw effect's dependency list. Two anchor
    #    spellings are tried because the original formatting is uncertain.
    old_dep = "] , [graph, hovered, selectedNode, search]);"
    if old_dep in t:
        t = t.replace(old_dep, "] , [graph, hovered, selectedNode, search, labelMode]);")
        applied.append('redraw-deps')
    else:
        fallback = " }, [graph, hovered, selectedNode, search]);"
        if fallback in t:
            t = t.replace(fallback, " }, [graph, hovered, selectedNode, search, labelMode]);")
            applied.append('redraw-deps')

    # 8) Insert the label-mode <Select> right after the search TextField. The
    #    replacement is built from search_block itself so the two literals
    #    cannot drift apart.
    search_block = """ <TextField
size="small"
placeholder="Search hosts, IPs, users\u2026"
value={search}
onChange={e => setSearch(e.target.value)}
sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }}
slotProps={{
input: {
startAdornment: <SearchIcon sx={{ mr: 0.5, fontSize: 18, color: 'text.secondary' }} />,
},
}}
/>
"""
    label_selector = """<FormControl size="small" sx={{ minWidth: 140 }}>
<InputLabel id="label-mode-selector">Labels</InputLabel>
<Select
labelId="label-mode-selector"
value={labelMode}
label="Labels"
onChange={e => setLabelMode(e.target.value as LabelMode)}
sx={{ '& .MuiSelect-select': { py: 0.8 } }}
>
<MenuItem value="none">None</MenuItem>
<MenuItem value="highlight">Selected/Search</MenuItem>
<MenuItem value="all">All</MenuItem>
</Select>
</FormControl>
"""
    sub('toolbar-label-selector', search_block, search_block + label_selector)

    return t, applied


def main() -> None:
    """Read the target file, patch it, and write it back only when changed."""
    original = TARGET.read_text(encoding='utf-8')
    patched, applied = apply_patches(original)
    if patched != original:
        TARGET.write_text(patched, encoding='utf-8')
    if applied:
        print('added network map label filter control and renderer modes: ' + ', '.join(applied))
    else:
        # The original script printed success unconditionally even when no
        # anchor matched; fail loudly instead so drift does not go unnoticed.
        raise SystemExit('no anchors matched - NetworkMap.tsx already patched or has drifted')


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,18 @@
"""One-shot codemod: add the SCANNER_MAX_ROWS_PER_SCAN setting to
backend/app/config.py, directly after SCANNER_BATCH_SIZE in the scanner
settings block.

NOTE(review): the anchor text must match config.py byte-for-byte; leading
whitespace may have been lost in transit — verify before running.
"""
from pathlib import Path

# File this codemod rewrites in place.
TARGET = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')

# Anchor: the existing scanner-settings block we extend.
OLD_BLOCK = ''' # -- Scanner settings -----------------------------------------------
SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch")
'''

# Replacement: the same block plus the new row-budget field. Built from
# OLD_BLOCK so the two literals cannot drift apart.
NEW_BLOCK = OLD_BLOCK + '''SCANNER_MAX_ROWS_PER_SCAN: int = Field(
default=300000,
description="Global row budget for a single AUP scan request (0 = unlimited)",
)
'''


def apply(text: str) -> str:
    """Return *text* with the new setting inserted after the anchor block.

    Idempotent: if SCANNER_MAX_ROWS_PER_SCAN is already present the text is
    returned unchanged. (The original script duplicated the field on a
    second run, because OLD_BLOCK is a strict prefix of NEW_BLOCK and so
    still matched after the first application.)

    Raises SystemExit when the anchor block cannot be found.
    """
    if 'SCANNER_MAX_ROWS_PER_SCAN' in text:
        return text
    if OLD_BLOCK not in text:
        raise SystemExit('scanner settings block not found')
    return text.replace(OLD_BLOCK, NEW_BLOCK)


def main() -> None:
    """Read config.py, patch it, and write it back only when changed."""
    original = TARGET.read_text(encoding='utf-8')
    patched = apply(original)
    if patched != original:
        TARGET.write_text(patched, encoding='utf-8')
        print('added SCANNER_MAX_ROWS_PER_SCAN config')
    else:
        print('SCANNER_MAX_ROWS_PER_SCAN already present; nothing to do')


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,46 @@
"""One-shot patch: make the frontend network map scale-aware.

1) client.ts      - add NetworkSummary types plus summary()/subgraph() calls.
2) NetworkMap.tsx - add large-hunt constants, inject a loadScaleAwareGraph
   helper (summary first; focused subgraph for large hunts; full inventory
   otherwise), and reduce loadGraph to a cache-checking delegator.

Fixes applied in review:
- The long `helper` literal was broken across two physical lines (a raw
  newline inside a double-quoted string — a SyntaxError); rejoined.
- `import re` was buried inside an `if` body; hoisted to module top.

NOTE(review): leading whitespace was stripped from this paste; the embedded
TS/TSX literals must keep the target files' exact indentation (the escaped
\\n sequences inside them are intact) — verify against the original script.
"""
import re
from pathlib import Path

root = Path(r"d:\Projects\Dev\ThreatHunt")
# -------- client.ts --------
client = root / "frontend/src/api/client.ts"
text = client.read_text(encoding="utf-8")
# Idempotence guard: skip when the summary types were already added.
if "export interface NetworkSummary" not in text:
    insert_after = "export interface InventoryStatus {\n hunt_id: string;\n status: 'ready' | 'building' | 'none';\n}\n"
    addition = insert_after + "\nexport interface NetworkSummaryHost {\n id: string;\n hostname: string;\n row_count: number;\n ip_count: number;\n user_count: number;\n}\n\nexport interface NetworkSummary {\n stats: InventoryStats;\n top_hosts: NetworkSummaryHost[];\n top_edges: InventoryConnection[];\n status?: 'building' | 'deferred';\n message?: string;\n}\n"
    text = text.replace(insert_after, addition)
# Extend the network API object with summary() and subgraph() endpoints.
net_old = """export const network = {\n hostInventory: (huntId: string, force = false) =>\n api<HostInventory>(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`),\n inventoryStatus: (huntId: string) =>\n api<InventoryStatus>(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`),\n rebuildInventory: (huntId: string) =>\n api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }),\n};"""
net_new = """export const network = {\n hostInventory: (huntId: string, force = false) =>\n api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(`/api/network/host-inventory?hunt_id=${encodeURIComponent(huntId)}${force ? '&force=true' : ''}`),\n summary: (huntId: string, topN = 20) =>\n api<NetworkSummary | { status: 'building' | 'deferred'; message?: string }>(`/api/network/summary?hunt_id=${encodeURIComponent(huntId)}&top_n=${topN}`),\n subgraph: (huntId: string, maxHosts = 250, maxEdges = 1500, nodeId?: string) => {\n let qs = `/api/network/subgraph?hunt_id=${encodeURIComponent(huntId)}&max_hosts=${maxHosts}&max_edges=${maxEdges}`;\n if (nodeId) qs += `&node_id=${encodeURIComponent(nodeId)}`;\n return api<HostInventory | { status: 'building' | 'deferred'; message?: string }>(qs);\n },\n inventoryStatus: (huntId: string) =>\n api<InventoryStatus>(`/api/network/inventory-status?hunt_id=${encodeURIComponent(huntId)}`),\n rebuildInventory: (huntId: string) =>\n api<{ job_id: string; status: string }>(`/api/network/rebuild-inventory?hunt_id=${encodeURIComponent(huntId)}`, { method: 'POST' }),\n};"""
if net_old in text:
    text = text.replace(net_old, net_new)
client.write_text(text, encoding="utf-8")
# -------- NetworkMap.tsx --------
nm = root / "frontend/src/components/NetworkMap.tsx"
text = nm.read_text(encoding="utf-8")
# add constants
if "LARGE_HUNT_HOST_THRESHOLD" not in text:
    text = text.replace("let lastSelectedHuntId = '';\n", "let lastSelectedHuntId = '';\nconst LARGE_HUNT_HOST_THRESHOLD = 400;\nconst LARGE_HUNT_SUBGRAPH_HOSTS = 350;\nconst LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n")
# inject helper in component after sleep
marker = " const sleep = (ms: number) => new Promise<void>(resolve => setTimeout(resolve, ms));\n"
if "loadScaleAwareGraph" not in text:
    helper = marker + "\n const loadScaleAwareGraph = useCallback(async (huntId: string, forceRefresh = false) => {\n setLoading(true); setError(''); setGraph(null); setStats(null);\n setSelectedNode(null); setPopoverAnchor(null);\n\n const waitReadyThen = async <T,>(fn: () => Promise<T>): Promise<T> => {\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const out: any = await fn();\n if (out && !out.status) return out as T;\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') {\n const out2: any = await fn();\n if (out2 && !out2.status) return out2 as T;\n }\n if (Date.now() - startedAt > 5 * 60 * 1000) throw new Error('Network data build timed out after 5 minutes');\n const jitter = Math.floor(Math.random() * 250);\n await sleep(delayMs + jitter);\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n };\n\n try {\n setProgress('Loading network summary');\n const summary: any = await waitReadyThen(() => network.summary(huntId, 20));\n const totalHosts = summary?.stats?.total_hosts || 0;\n\n if (totalHosts > LARGE_HUNT_HOST_THRESHOLD) {\n setProgress(`Large hunt detected (${totalHosts} hosts). Loading focused subgraph`);\n const sub: any = await waitReadyThen(() => network.subgraph(huntId, LARGE_HUNT_SUBGRAPH_HOSTS, LARGE_HUNT_SUBGRAPH_EDGES));\n if (!sub?.hosts || sub.hosts.length === 0) {\n setError('No hosts found for subgraph.');\n return;\n }\n const { w, h } = canvasSizeRef.current;\n const g = buildGraphFromInventory(sub.hosts, sub.connections || [], w, h);\n simulate(g, w / 2, h / 2, 60);\n simAlphaRef.current = 0.3;\n setStats(summary.stats);\n graphCache.set(huntId, { graph: g, stats: summary.stats, ts: Date.now() });\n setGraph(g);\n return;\n }\n\n // Small/medium hunts: load full inventory\n setProgress('Loading host inventory');\n const inv: any = await waitReadyThen(() => network.hostInventory(huntId, forceRefresh));\n if (!inv?.hosts || inv.hosts.length === 0) {\n setError('No hosts found. Upload CSV files with host-identifying columns (ClientId, Fqdn, Hostname) to this hunt.');\n return;\n }\n const { w, h } = canvasSizeRef.current;\n const g = buildGraphFromInventory(inv.hosts, inv.connections || [], w, h);\n simulate(g, w / 2, h / 2, 60);\n simAlphaRef.current = 0.3;\n setStats(summary.stats || inv.stats);\n graphCache.set(huntId, { graph: g, stats: summary.stats || inv.stats, ts: Date.now() });\n setGraph(g);\n } catch (e: any) {\n console.error('[NetworkMap] scale-aware load error:', e);\n setError(e.message || 'Failed to load network data');\n } finally {\n setLoading(false);\n setProgress('');\n }\n }, []);\n"
    text = text.replace(marker, helper)
# simplify existing loadGraph function body to delegate
pattern_start = text.find(" // Load host inventory for selected hunt (with cache).")
if pattern_start != -1:
    # replace the whole loadGraph useCallback block by simple delegator
    block_re = re.compile(r" // Load host inventory for selected hunt \(with cache\)\.[\s\S]*?\n \}, \[\]\); // Stable - reads canvasSizeRef, no state deps\n", re.M)
    repl = " // Load graph data for selected hunt (delegates to scale-aware loader).\n const loadGraph = useCallback(async (huntId: string, forceRefresh = false) => {\n if (!huntId) return;\n\n // Check module-level cache first (5 min TTL)\n if (!forceRefresh) {\n const cached = graphCache.get(huntId);\n if (cached && Date.now() - cached.ts < 5 * 60 * 1000) {\n setGraph(cached.graph);\n setStats(cached.stats);\n setError('');\n simAlphaRef.current = 0;\n return;\n }\n }\n\n await loadScaleAwareGraph(huntId, forceRefresh);\n // eslint-disable-next-line react-hooks/exhaustive-deps\n }, []); // Stable - reads canvasSizeRef, no state deps\n"
    text = block_re.sub(repl, text, count=1)
nm.write_text(text, encoding="utf-8")
print("Patched frontend client + NetworkMap for scale-aware loading")

206
_apply_phase1_patch.py Normal file
View File

@@ -0,0 +1,206 @@
"""One-shot patch (phase 1): backend job-queue hardening + scan-scope defaults.

1) config.py    - add job-queue backlog/retention/cleanup settings.
2) scanner.py   - default scan scope to dataset-only (hunts/annotations/messages off).
3) keywords.py  - matching dataset-only defaults on the API route.
4) job_queue.py - submit() dedupe, backlog warning, cleanup cap + periodic loop.
5) NetworkMap.tsx - polling with exponential backoff, jitter, and a max wait.

NOTE(review): leading whitespace was stripped from this paste — both this
script's own indentation and the indentation inside every embedded literal.
Structure below is reconstructed; the literals must carry the target files'
exact indentation for the `replace()` anchors to match — verify against the
original script before running.
"""
from pathlib import Path

root = Path(r"d:\Projects\Dev\ThreatHunt")
# 1) config.py additions
cfg = root / "backend/app/config.py"
text = cfg.read_text(encoding="utf-8")
# Anchor on the existing scanner-settings block; append job-queue settings.
needle = " # -- Scanner settings -----------------------------------------------\n SCANNER_BATCH_SIZE: int = Field(default=500, description=\"Rows per scanner batch\")\n"
insert = " # -- Scanner settings -----------------------------------------------\n SCANNER_BATCH_SIZE: int = Field(default=500, description=\"Rows per scanner batch\")\n\n # -- Job queue settings ----------------------------------------------\n JOB_QUEUE_MAX_BACKLOG: int = Field(\n default=2000, description=\"Soft cap for queued background jobs\"\n )\n JOB_QUEUE_RETAIN_COMPLETED: int = Field(\n default=3000, description=\"Maximum completed/failed jobs to retain in memory\"\n )\n JOB_QUEUE_CLEANUP_INTERVAL_SECONDS: int = Field(\n default=60, description=\"How often to run in-memory job cleanup\"\n )\n JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(\n default=3600, description=\"Age threshold for in-memory completed job cleanup\"\n )\n"
if needle in text:
    text = text.replace(needle, insert)
cfg.write_text(text, encoding="utf-8")
# 2) scanner.py default scope = dataset-only
scanner = root / "backend/app/services/scanner.py"
text = scanner.read_text(encoding="utf-8")
text = text.replace(" scan_hunts: bool = True,", " scan_hunts: bool = False,")
text = text.replace(" scan_annotations: bool = True,", " scan_annotations: bool = False,")
text = text.replace(" scan_messages: bool = True,", " scan_messages: bool = False,")
scanner.write_text(text, encoding="utf-8")
# 3) keywords.py defaults = dataset-only
kw = root / "backend/app/api/routes/keywords.py"
text = kw.read_text(encoding="utf-8")
text = text.replace(" scan_hunts: bool = True", " scan_hunts: bool = False")
text = text.replace(" scan_annotations: bool = True", " scan_annotations: bool = False")
text = text.replace(" scan_messages: bool = True", " scan_messages: bool = False")
kw.write_text(text, encoding="utf-8")
# 4) job_queue.py dedupe + periodic cleanup
jq = root / "backend/app/services/job_queue.py"
text = jq.read_text(encoding="utf-8")
# Import settings (needed by the new backlog/retention logic below).
text = text.replace(
    "from typing import Any, Callable, Coroutine, Optional\n",
    "from typing import Any, Callable, Coroutine, Optional\n\nfrom app.config import settings\n"
)
# Track the cleanup task alongside the existing callback list.
text = text.replace(
    " self._completion_callbacks: list[Callable[[Job], Coroutine]] = []\n",
    " self._completion_callbacks: list[Callable[[Job], Coroutine]] = []\n self._cleanup_task: asyncio.Task | None = None\n"
)
# start(): also launch the periodic cleanup loop (restart it if it died).
start_old = ''' async def start(self):
if self._started:
return
self._started = True
for i in range(self._max_workers):
task = asyncio.create_task(self._worker(i))
self._workers.append(task)
logger.info(f"Job queue started with {self._max_workers} workers")
'''
start_new = ''' async def start(self):
if self._started:
return
self._started = True
for i in range(self._max_workers):
task = asyncio.create_task(self._worker(i))
self._workers.append(task)
if not self._cleanup_task or self._cleanup_task.done():
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info(f"Job queue started with {self._max_workers} workers")
'''
text = text.replace(start_old, start_new)
# stop(): cancel and await the cleanup task as well as the workers.
stop_old = ''' async def stop(self):
self._started = False
for w in self._workers:
w.cancel()
await asyncio.gather(*self._workers, return_exceptions=True)
self._workers.clear()
logger.info("Job queue stopped")
'''
stop_new = ''' async def stop(self):
self._started = False
for w in self._workers:
w.cancel()
await asyncio.gather(*self._workers, return_exceptions=True)
self._workers.clear()
if self._cleanup_task:
self._cleanup_task.cancel()
await asyncio.gather(self._cleanup_task, return_exceptions=True)
self._cleanup_task = None
logger.info("Job queue stopped")
'''
text = text.replace(stop_old, stop_new)
# submit(): dedupe identical queued/running work; warn (but accept) on backlog.
submit_old = ''' def submit(self, job_type: JobType, **params) -> Job:
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
self._jobs[job.id] = job
self._queue.put_nowait(job.id)
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
return job
'''
submit_new = ''' def submit(self, job_type: JobType, **params) -> Job:
# Soft backpressure: prefer dedupe over queue amplification
dedupe_job = self._find_active_duplicate(job_type, params)
if dedupe_job is not None:
logger.info(
f"Job deduped: reusing {dedupe_job.id} ({job_type.value}) params={params}"
)
return dedupe_job
if self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG:
logger.warning(
"Job queue backlog high (%d >= %d). Accepting job but system may be degraded.",
self._queue.qsize(), settings.JOB_QUEUE_MAX_BACKLOG,
)
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
self._jobs[job.id] = job
self._queue.put_nowait(job.id)
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
return job
'''
text = text.replace(submit_old, submit_new)
# Add _find_active_duplicate() right after get_job() (anchor kept verbatim).
insert_methods_after = " def get_job(self, job_id: str) -> Job | None:\n return self._jobs.get(job_id)\n"
new_methods = ''' def get_job(self, job_id: str) -> Job | None:
return self._jobs.get(job_id)
def _find_active_duplicate(self, job_type: JobType, params: dict) -> Job | None:
"""Return queued/running job with same key workload to prevent duplicate storms."""
key_fields = ["dataset_id", "hunt_id", "hostname", "question", "mode"]
sig = tuple((k, params.get(k)) for k in key_fields if params.get(k) is not None)
if not sig:
return None
for j in self._jobs.values():
if j.job_type != job_type:
continue
if j.status not in (JobStatus.QUEUED, JobStatus.RUNNING):
continue
other_sig = tuple((k, j.params.get(k)) for k in key_fields if j.params.get(k) is not None)
if sig == other_sig:
return j
return None
'''
text = text.replace(insert_methods_after, new_methods)
# cleanup(): add a retention cap on terminal jobs, plus a periodic async loop.
cleanup_old = ''' def cleanup(self, max_age_seconds: float = 3600):
now = time.time()
to_remove = [
jid for jid, j in self._jobs.items()
if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
and (now - j.created_at) > max_age_seconds
]
for jid in to_remove:
del self._jobs[jid]
if to_remove:
logger.info(f"Cleaned up {len(to_remove)} old jobs")
'''
cleanup_new = ''' def cleanup(self, max_age_seconds: float = 3600):
now = time.time()
terminal_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
to_remove = [
jid for jid, j in self._jobs.items()
if j.status in terminal_states and (now - j.created_at) > max_age_seconds
]
# Also cap retained terminal jobs to avoid unbounded memory growth
terminal_jobs = sorted(
[j for j in self._jobs.values() if j.status in terminal_states],
key=lambda j: j.created_at,
reverse=True,
)
overflow = terminal_jobs[settings.JOB_QUEUE_RETAIN_COMPLETED :]
to_remove.extend([j.id for j in overflow])
removed = 0
for jid in set(to_remove):
if jid in self._jobs:
del self._jobs[jid]
removed += 1
if removed:
logger.info(f"Cleaned up {removed} old jobs")
async def _cleanup_loop(self):
interval = max(10, settings.JOB_QUEUE_CLEANUP_INTERVAL_SECONDS)
while self._started:
try:
self.cleanup(max_age_seconds=settings.JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS)
except Exception as e:
logger.warning(f"Job queue cleanup loop error: {e}")
await asyncio.sleep(interval)
'''
text = text.replace(cleanup_old, cleanup_new)
jq.write_text(text, encoding="utf-8")
# 5) NetworkMap polling backoff/jitter max wait
nm = root / "frontend/src/components/NetworkMap.tsx"
text = nm.read_text(encoding="utf-8")
# Replace the fixed 2s poll loop with backoff + jitter + 5-minute timeout.
text = text.replace(
    " // Poll until ready, then re-fetch\n for (;;) {\n await new Promise(r => setTimeout(r, 2000));\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') break;\n }\n",
    " // Poll until ready (exponential backoff), then re-fetch\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const jitter = Math.floor(Math.random() * 250);\n await new Promise(r => setTimeout(r, delayMs + jitter));\n const st = await network.inventoryStatus(huntId);\n if (st.status === 'ready') break;\n if (Date.now() - startedAt > 5 * 60 * 1000) {\n throw new Error('Host inventory build timed out after 5 minutes');\n }\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n"
)
# Same treatment for the waitUntilReady helper (adds timeout + error surface).
text = text.replace(
    " const waitUntilReady = async (): Promise<boolean> => {\n // Poll inventory-status every 2s until 'ready' (or cancelled)\n setProgress('Host inventory is being prepared in the background');\n setLoading(true);\n for (;;) {\n await new Promise(r => setTimeout(r, 2000));\n if (cancelled) return false;\n try {\n const st = await network.inventoryStatus(selectedHuntId);\n if (cancelled) return false;\n if (st.status === 'ready') return true;\n // still building or none (job may not have started yet) - keep polling\n } catch { if (cancelled) return false; }\n }\n };\n",
    " const waitUntilReady = async (): Promise<boolean> => {\n // Poll inventory-status with exponential backoff until 'ready' (or cancelled)\n setProgress('Host inventory is being prepared in the background');\n setLoading(true);\n let delayMs = 1500;\n const startedAt = Date.now();\n for (;;) {\n const jitter = Math.floor(Math.random() * 250);\n await new Promise(r => setTimeout(r, delayMs + jitter));\n if (cancelled) return false;\n try {\n const st = await network.inventoryStatus(selectedHuntId);\n if (cancelled) return false;\n if (st.status === 'ready') return true;\n if (Date.now() - startedAt > 5 * 60 * 1000) {\n setError('Host inventory build timed out. Please retry.');\n return false;\n }\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n // still building or none (job may not have started yet) - keep polling\n } catch {\n if (cancelled) return false;\n delayMs = Math.min(10000, Math.floor(delayMs * 1.5));\n }\n }\n };\n"
)
nm.write_text(text, encoding="utf-8")
print("Patched: config.py, scanner.py, keywords.py, job_queue.py, NetworkMap.tsx")

207
_apply_phase2_patch.py Normal file
View File

@@ -0,0 +1,207 @@
"""One-shot patch (phase 2): startup throttling + scale guards.

- config.py         : startup warmup/reprocess caps, subgraph host/edge caps.
- job_queue.py      : is_backlogged()/can_accept() helpers after get_stats().
- host_inventory.py : offset pagination -> keyset pagination on row_index.
- network.py        : /summary and /subgraph endpoints plus enqueue backpressure.
- analysis.py       : 429 on manual job submit when the queue is saturated.
- main.py           : cap warm-up hunts and startup dataset reprocessing.

NOTE(review): leading whitespace was stripped from this paste — the block
nesting below is reconstructed and every embedded literal must carry the
target files' exact indentation for replace() to match — verify against the
original script before running.
"""
from pathlib import Path
import re
root = Path(r"d:\Projects\Dev\ThreatHunt")
# ---------- config.py ----------
cfg = root / "backend/app/config.py"
text = cfg.read_text(encoding="utf-8")
# Anchor on the last phase-1 job-queue field; append startup + network caps.
marker = " JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(\n default=3600, description=\"Age threshold for in-memory completed job cleanup\"\n )\n"
add = marker + "\n # -- Startup throttling ------------------------------------------------\n STARTUP_WARMUP_MAX_HUNTS: int = Field(\n default=5, description=\"Max hunts to warm inventory cache for at startup\"\n )\n STARTUP_REPROCESS_MAX_DATASETS: int = Field(\n default=25, description=\"Max unprocessed datasets to enqueue at startup\"\n )\n\n # -- Network API scale guards -----------------------------------------\n NETWORK_SUBGRAPH_MAX_HOSTS: int = Field(\n default=400, description=\"Hard cap for hosts returned by network subgraph endpoint\"\n )\n NETWORK_SUBGRAPH_MAX_EDGES: int = Field(\n default=3000, description=\"Hard cap for edges returned by network subgraph endpoint\"\n )\n"
if marker in text and "STARTUP_WARMUP_MAX_HUNTS" not in text:
    text = text.replace(marker, add)
cfg.write_text(text, encoding="utf-8")
# ---------- job_queue.py ----------
jq = root / "backend/app/services/job_queue.py"
text = jq.read_text(encoding="utf-8")
# add helper methods after get_stats
anchor = " def get_stats(self) -> dict:\n by_status = {}\n for j in self._jobs.values():\n by_status[j.status.value] = by_status.get(j.status.value, 0) + 1\n return {\n \"total\": len(self._jobs),\n \"queued\": self._queue.qsize(),\n \"by_status\": by_status,\n \"workers\": self._max_workers,\n \"active_workers\": sum(1 for j in self._jobs.values() if j.status == JobStatus.RUNNING),\n }\n"
if "def is_backlogged(" not in text:
    insert = anchor + "\n def is_backlogged(self) -> bool:\n return self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG\n\n def can_accept(self, reserve: int = 0) -> bool:\n return (self._queue.qsize() + max(0, reserve)) < settings.JOB_QUEUE_MAX_BACKLOG\n"
    text = text.replace(anchor, insert)
jq.write_text(text, encoding="utf-8")
# ---------- host_inventory.py keyset pagination ----------
hi = root / "backend/app/services/host_inventory.py"
text = hi.read_text(encoding="utf-8")
# Old: OFFSET pagination (re-scans skipped rows each batch on large tables).
old = ''' batch_size = 5000
offset = 0
while True:
rr = await db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == ds.id)
.order_by(DatasetRow.row_index)
.offset(offset).limit(batch_size)
)
rows = rr.scalars().all()
if not rows:
break
'''
# New: keyset pagination on row_index (constant cost per batch).
new = ''' batch_size = 5000
last_row_index = -1
while True:
rr = await db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == ds.id)
.where(DatasetRow.row_index > last_row_index)
.order_by(DatasetRow.row_index)
.limit(batch_size)
)
rows = rr.scalars().all()
if not rows:
break
'''
if old in text:
    # NOTE(review): nesting reconstructed — both replaces are assumed to be
    # guarded by this `if` since the second only applies with the first.
    text = text.replace(old, new)
    text = text.replace(" offset += batch_size\n if len(rows) < batch_size:\n break\n", " last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n")
hi.write_text(text, encoding="utf-8")
# ---------- network.py add summary/subgraph + backpressure ----------
net = root / "backend/app/api/routes/network.py"
text = net.read_text(encoding="utf-8")
# NOTE(review): this replace is a no-op (old == new) — it looks like an
# intended import change (e.g. adding JSONResponse?) was lost; verify.
text = text.replace("from fastapi import APIRouter, Depends, HTTPException, Query", "from fastapi import APIRouter, Depends, HTTPException, Query")
if "from app.config import settings" not in text:
    text = text.replace("from app.db import get_db\n", "from app.config import settings\nfrom app.db import get_db\n")
# add helpers and endpoints before inventory-status endpoint
if "def _build_summary" not in text:
    helper_block = '''
def _build_summary(inv: dict, top_n: int = 20) -> dict:
hosts = inv.get("hosts", [])
conns = inv.get("connections", [])
top_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:top_n]
top_edges = sorted(conns, key=lambda c: c.get("count", 0), reverse=True)[:top_n]
return {
"stats": inv.get("stats", {}),
"top_hosts": [
{
"id": h.get("id"),
"hostname": h.get("hostname"),
"row_count": h.get("row_count", 0),
"ip_count": len(h.get("ips", [])),
"user_count": len(h.get("users", [])),
}
for h in top_hosts
],
"top_edges": top_edges,
}
def _build_subgraph(inv: dict, node_id: str | None, max_hosts: int, max_edges: int) -> dict:
hosts = inv.get("hosts", [])
conns = inv.get("connections", [])
max_hosts = max(1, min(max_hosts, settings.NETWORK_SUBGRAPH_MAX_HOSTS))
max_edges = max(1, min(max_edges, settings.NETWORK_SUBGRAPH_MAX_EDGES))
if node_id:
rel_edges = [c for c in conns if c.get("source") == node_id or c.get("target") == node_id]
rel_edges = sorted(rel_edges, key=lambda c: c.get("count", 0), reverse=True)[:max_edges]
ids = {node_id}
for c in rel_edges:
ids.add(c.get("source"))
ids.add(c.get("target"))
rel_hosts = [h for h in hosts if h.get("id") in ids][:max_hosts]
else:
rel_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:max_hosts]
allowed = {h.get("id") for h in rel_hosts}
rel_edges = [
c for c in sorted(conns, key=lambda c: c.get("count", 0), reverse=True)
if c.get("source") in allowed and c.get("target") in allowed
][:max_edges]
return {
"hosts": rel_hosts,
"connections": rel_edges,
"stats": {
**inv.get("stats", {}),
"subgraph_hosts": len(rel_hosts),
"subgraph_connections": len(rel_edges),
"truncated": len(rel_hosts) < len(hosts) or len(rel_edges) < len(conns),
},
}
@router.get("/summary")
async def get_inventory_summary(
hunt_id: str = Query(..., description="Hunt ID"),
top_n: int = Query(20, ge=1, le=200),
):
"""Return a lightweight summary view for large hunts."""
cached = inventory_cache.get(hunt_id)
if cached is None:
if not inventory_cache.is_building(hunt_id):
if job_queue.is_backlogged():
return JSONResponse(
status_code=202,
content={"status": "deferred", "message": "Queue busy, retry shortly"},
)
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(status_code=202, content={"status": "building"})
return _build_summary(cached, top_n=top_n)
@router.get("/subgraph")
async def get_inventory_subgraph(
hunt_id: str = Query(..., description="Hunt ID"),
node_id: str | None = Query(None, description="Optional focal node"),
max_hosts: int = Query(200, ge=1, le=5000),
max_edges: int = Query(1500, ge=1, le=20000),
):
"""Return a bounded subgraph for scale-safe rendering."""
cached = inventory_cache.get(hunt_id)
if cached is None:
if not inventory_cache.is_building(hunt_id):
if job_queue.is_backlogged():
return JSONResponse(
status_code=202,
content={"status": "deferred", "message": "Queue busy, retry shortly"},
)
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(status_code=202, content={"status": "building"})
return _build_subgraph(cached, node_id=node_id, max_hosts=max_hosts, max_edges=max_edges)
'''
    text = text.replace("\n\n@router.get(\"/inventory-status\")", helper_block + "\n\n@router.get(\"/inventory-status\")")
# add backpressure in host-inventory enqueue points
text = text.replace(
    " if not inventory_cache.is_building(hunt_id):\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)",
    " if not inventory_cache.is_building(hunt_id):\n if job_queue.is_backlogged():\n return JSONResponse(status_code=202, content={\"status\": \"deferred\", \"message\": \"Queue busy, retry shortly\"})\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)"
)
text = text.replace(
    " if not inventory_cache.is_building(hunt_id):\n logger.info(f\"Cache miss for {hunt_id}, triggering background build\")\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)",
    " if not inventory_cache.is_building(hunt_id):\n logger.info(f\"Cache miss for {hunt_id}, triggering background build\")\n if job_queue.is_backlogged():\n return JSONResponse(status_code=202, content={\"status\": \"deferred\", \"message\": \"Queue busy, retry shortly\"})\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)"
)
net.write_text(text, encoding="utf-8")
# ---------- analysis.py backpressure on manual submit ----------
analysis = root / "backend/app/api/routes/analysis.py"
text = analysis.read_text(encoding="utf-8")
# Reject manual job submission with 429 when the queue cannot accept work.
text = text.replace(
    " job = job_queue.submit(jt, **params)\n return {\"job_id\": job.id, \"status\": job.status.value, \"job_type\": job_type}",
    " if not job_queue.can_accept():\n raise HTTPException(status_code=429, detail=\"Job queue is busy. Retry shortly.\")\n job = job_queue.submit(jt, **params)\n return {\"job_id\": job.id, \"status\": job.status.value, \"job_type\": job_type}"
)
analysis.write_text(text, encoding="utf-8")
# ---------- main.py startup throttles ----------
main = root / "backend/app/main.py"
text = main.read_text(encoding="utf-8")
# Cap how many hunts get inventory warm-up at startup.
text = text.replace(
    " for hid in hunt_ids:\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)\n if hunt_ids:\n logger.info(f\"Queued host inventory warm-up for {len(hunt_ids)} hunts\")",
    " warm_hunts = hunt_ids[: settings.STARTUP_WARMUP_MAX_HUNTS]\n for hid in warm_hunts:\n job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)\n if warm_hunts:\n logger.info(f\"Queued host inventory warm-up for {len(warm_hunts)} hunts (total hunts with data: {len(hunt_ids)})\")"
)
# Cap how many unprocessed datasets are re-enqueued (and marked) at startup.
text = text.replace(
    " if unprocessed_ids:\n for ds_id in unprocessed_ids:\n job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)\n job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)\n job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)\n job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)\n logger.info(f\"Queued processing pipeline for {len(unprocessed_ids)} unprocessed datasets\")\n async with async_session_factory() as update_db:\n from sqlalchemy import update\n from app.db.models import Dataset\n await update_db.execute(\n update(Dataset)\n .where(Dataset.id.in_(unprocessed_ids))\n .values(processing_status=\"processing\")\n )\n await update_db.commit()",
    " if unprocessed_ids:\n to_reprocess = unprocessed_ids[: settings.STARTUP_REPROCESS_MAX_DATASETS]\n for ds_id in to_reprocess:\n job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)\n job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)\n job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)\n job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)\n logger.info(f\"Queued processing pipeline for {len(to_reprocess)} datasets at startup (unprocessed total: {len(unprocessed_ids)})\")\n async with async_session_factory() as update_db:\n from sqlalchemy import update\n from app.db.models import Dataset\n await update_db.execute(\n update(Dataset)\n .where(Dataset.id.in_(to_reprocess))\n .values(processing_status=\"processing\")\n )\n await update_db.commit()"
)
main.write_text(text, encoding="utf-8")
print("Patched Phase 2 files")

View File

@@ -0,0 +1,75 @@
"""One-shot patch: scoped dataset selection for AUPScanner.tsx.

1) On hunt change, pre-select only the first 3 datasets instead of all.
2) Insert a multi-select "Datasets" control (with Top 3 / All / Clear
   shortcuts) under the hunt info block.

NOTE(review): leading whitespace was stripped from this paste; the TSX
literals below must keep AUPScanner.tsx's exact indentation for the
`in` checks / replaces to match — verify against the original script.
"""
from pathlib import Path

# Hard-coded dev checkout path — one-off local maintenance script.
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
t=p.read_text(encoding='utf-8')
# default selection when hunt changes: first 3 datasets instead of all
old=''' datasets.list(0, 500, selectedHuntId).then(res => {
if (cancelled) return;
setDsList(res.datasets);
setSelectedDs(new Set(res.datasets.map(d => d.id)));
}).catch(() => {});
'''
new=''' datasets.list(0, 500, selectedHuntId).then(res => {
if (cancelled) return;
setDsList(res.datasets);
setSelectedDs(new Set(res.datasets.slice(0, 3).map(d => d.id)));
}).catch(() => {});
'''
# Fail loudly if the anchor has drifted rather than silently writing nothing.
if old not in t:
    raise SystemExit('hunt-change dataset init block not found')
t=t.replace(old,new)
# insert dataset scope multi-select under hunt info
anchor=''' {!selectedHuntId && (
<Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
All datasets will be scanned if no hunt is selected
</Typography>
)}
</Box>
{/* Theme selector */}
'''
insert=''' {!selectedHuntId && (
<Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
Select a hunt to enable scoped scanning
</Typography>
)}
<FormControl size="small" fullWidth sx={{ mt: 1.2 }} disabled={!selectedHuntId || dsList.length === 0}>
<InputLabel id="aup-dataset-label">Datasets</InputLabel>
<Select
labelId="aup-dataset-label"
multiple
value={Array.from(selectedDs)}
label="Datasets"
renderValue={(selected) => `${(selected as string[]).length} selected`}
onChange={(e) => setSelectedDs(new Set(e.target.value as string[]))}
>
{dsList.map(d => (
<MenuItem key={d.id} value={d.id}>
<Checkbox size="small" checked={selectedDs.has(d.id)} />
<Typography variant="body2" sx={{ ml: 0.5 }}>
{d.name} ({d.row_count.toLocaleString()} rows)
</Typography>
</MenuItem>
))}
</Select>
</FormControl>
{selectedHuntId && dsList.length > 0 && (
<Stack direction="row" spacing={1} sx={{ mt: 1 }}>
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.slice(0, 3).map(d => d.id)))}>Top 3</Button>
<Button size="small" onClick={() => setSelectedDs(new Set(dsList.map(d => d.id)))}>All</Button>
<Button size="small" onClick={() => setSelectedDs(new Set())}>Clear</Button>
</Stack>
)}
</Box>
{/* Theme selector */}
'''
if anchor not in t:
    raise SystemExit('dataset scope anchor not found')
t=t.replace(anchor,insert)
p.write_text(t,encoding='utf-8')
print('added AUP dataset multi-select scoping and safer defaults')

View File

@@ -0,0 +1,182 @@
"""One-shot patch: enrich scanner ScanHit results with hostname/username context.

Edits backend/app/services/scanner.py in place via exact-text replacement.
Every step now fails fast with SystemExit when its anchor is missing, so the
target file is never left half-patched. (Previously step 2 could silently skip
inserting the helper while step 4 still injected a call to it, leaving
scanner.py raising NameError at import time.)

NOTE(review): anchor indentation below assumes standard 4-space nesting in
scanner.py — confirm against the target file before running.
"""
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
t = p.read_text(encoding='utf-8')

# 1) Extend ScanHit dataclass with optional host/user context fields.
old = '''@dataclass
class ScanHit:
    theme_name: str
    theme_color: str
    keyword: str
    source_type: str  # dataset_row | hunt | annotation | message
    source_id: str | int
    field: str
    matched_value: str
    row_index: int | None = None
    dataset_name: str | None = None
'''
new = '''@dataclass
class ScanHit:
    theme_name: str
    theme_color: str
    keyword: str
    source_type: str  # dataset_row | hunt | annotation | message
    source_id: str | int
    field: str
    matched_value: str
    row_index: int | None = None
    dataset_name: str | None = None
    hostname: str | None = None
    username: str | None = None
'''
if old not in t:
    raise SystemExit('ScanHit dataclass block not found')
t = t.replace(old, new)

# 2) Insert a module-level helper that infers hostname/user from a row dict.
insert_after = '''BATCH_SIZE = 200
@dataclass
class ScanHit:
'''
helper = '''BATCH_SIZE = 200
def _infer_hostname_and_user(data: dict) -> tuple[str | None, str | None]:
    """Best-effort extraction of hostname and user from a dataset row."""
    if not data:
        return None, None
    host_keys = (
        'hostname', 'host_name', 'host', 'computer_name', 'computer',
        'fqdn', 'client_id', 'agent_id', 'endpoint_id',
    )
    user_keys = (
        'username', 'user_name', 'user', 'account_name',
        'logged_in_user', 'samaccountname', 'sam_account_name',
    )
    def pick(keys):
        for k in keys:
            for actual_key, v in data.items():
                if actual_key.lower() == k and v not in (None, ''):
                    return str(v)
        return None
    return pick(host_keys), pick(user_keys)
@dataclass
class ScanHit:
'''
if '_infer_hostname_and_user' not in t:
    # FIX: this step used to no-op silently when the anchor was missing, while
    # step 4 below unconditionally wires in a call to the helper — which would
    # leave scanner.py broken (NameError). Fail fast instead, matching the
    # behavior of every other step in this script.
    if insert_after not in t:
        raise SystemExit('BATCH_SIZE/ScanHit anchor for helper insertion not found')
    t = t.replace(insert_after, helper)

# 3) Extend the _match_text signature and the ScanHit construction it performs.
old_sig = '''    def _match_text(
        self,
        text: str,
        patterns: dict,
        source_type: str,
        source_id: str | int,
        field_name: str,
        hits: list[ScanHit],
        row_index: int | None = None,
        dataset_name: str | None = None,
    ) -> None:
'''
new_sig = '''    def _match_text(
        self,
        text: str,
        patterns: dict,
        source_type: str,
        source_id: str | int,
        field_name: str,
        hits: list[ScanHit],
        row_index: int | None = None,
        dataset_name: str | None = None,
        hostname: str | None = None,
        username: str | None = None,
    ) -> None:
'''
if old_sig not in t:
    raise SystemExit('_match_text signature not found')
t = t.replace(old_sig, new_sig)
old_hit = '''                hits.append(ScanHit(
                    theme_name=theme_name,
                    theme_color=theme_color,
                    keyword=kw_value,
                    source_type=source_type,
                    source_id=source_id,
                    field=field_name,
                    matched_value=matched_preview,
                    row_index=row_index,
                    dataset_name=dataset_name,
                ))
'''
new_hit = '''                hits.append(ScanHit(
                    theme_name=theme_name,
                    theme_color=theme_color,
                    keyword=kw_value,
                    source_type=source_type,
                    source_id=source_id,
                    field=field_name,
                    matched_value=matched_preview,
                    row_index=row_index,
                    dataset_name=dataset_name,
                    hostname=hostname,
                    username=username,
                ))
'''
if old_hit not in t:
    raise SystemExit('ScanHit append block not found')
t = t.replace(old_hit, new_hit)

# 4) Pass the inferred hostname/username through the dataset-row scan path.
old_call = '''        for row in rows:
            result.rows_scanned += 1
            data = row.data or {}
            for col_name, cell_value in data.items():
                if cell_value is None:
                    continue
                text = str(cell_value)
                self._match_text(
                    text,
                    patterns,
                    "dataset_row",
                    row.id,
                    col_name,
                    result.hits,
                    row_index=row.row_index,
                    dataset_name=ds_name,
                )
'''
new_call = '''        for row in rows:
            result.rows_scanned += 1
            data = row.data or {}
            hostname, username = _infer_hostname_and_user(data)
            for col_name, cell_value in data.items():
                if cell_value is None:
                    continue
                text = str(cell_value)
                self._match_text(
                    text,
                    patterns,
                    "dataset_row",
                    row.id,
                    col_name,
                    result.hits,
                    row_index=row.row_index,
                    dataset_name=ds_name,
                    hostname=hostname,
                    username=username,
                )
'''
if old_call not in t:
    raise SystemExit('dataset _match_text call block not found')
t = t.replace(old_call, new_call)

p.write_text(t, encoding='utf-8')
print('updated scanner hits with hostname+username context')

View File

@@ -0,0 +1,32 @@
"""Patch the backend keywords route: add hostname/username to ScanHit."""
from pathlib import Path

ROUTE_FILE = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')

# The response-model block exactly as it exists before patching.
before = '''class ScanHit(BaseModel):
    theme_name: str
    theme_color: str
    keyword: str
    source_type: str
    source_id: str | int
    field: str
    matched_value: str
    row_index: int | None = None
    dataset_name: str | None = None
'''
# Same block with the two optional host/user context fields appended.
after = '''class ScanHit(BaseModel):
    theme_name: str
    theme_color: str
    keyword: str
    source_type: str
    source_id: str | int
    field: str
    matched_value: str
    row_index: int | None = None
    dataset_name: str | None = None
    hostname: str | None = None
    username: str | None = None
'''

source = ROUTE_FILE.read_text(encoding='utf-8')
if before not in source:
    raise SystemExit('ScanHit pydantic model block not found')
ROUTE_FILE.write_text(source.replace(before, after), encoding='utf-8')
print('extended API ScanHit model with hostname+username')

View File

@@ -0,0 +1,21 @@
"""Patch the frontend API client: add hostname/username to the ScanHit type."""
from pathlib import Path

CLIENT_TS = Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/api/client.ts')

# Current interface text (exact match required for the replace to apply).
before = '''export interface ScanHit {
  theme_name: string; theme_color: string; keyword: string;
  source_type: string; source_id: string | number; field: string;
  matched_value: string; row_index: number | null; dataset_name: string | null;
}
'''
# Interface extended with the optional host/user context columns.
after = '''export interface ScanHit {
  theme_name: string; theme_color: string; keyword: string;
  source_type: string; source_id: string | number; field: string;
  matched_value: string; row_index: number | null; dataset_name: string | null;
  hostname?: string | null; username?: string | null;
}
'''

text = CLIENT_TS.read_text(encoding='utf-8')
if before not in text:
    raise SystemExit('frontend ScanHit interface block not found')
CLIENT_TS.write_text(text.replace(before, after), encoding='utf-8')
print('extended frontend ScanHit type with hostname+username')

View File

@@ -0,0 +1,57 @@
"""Patch the keywords scan API: add a scope guard and batch the cache-miss path.

Edits backend/app/api/routes/keywords.py in place via exact-text replacement.

NOTE(review): anchor indentation reconstructed to the standard 4-space
convention — verify it matches keywords.py exactly before running.
"""
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
t = p.read_text(encoding='utf-8')
# add fast guard against unscoped global dataset scans
insert_after = '''async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):\n    scanner = KeywordScanner(db)\n\n'''
if insert_after not in t:
    raise SystemExit('run_scan header block not found')
# Idempotency: the guard's error message doubles as the "already applied" marker.
if 'Select at least one dataset' not in t:
    guard = '''    if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:\n        raise HTTPException(400, "Select at least one dataset or enable additional sources (hunts/annotations/messages)")\n\n'''
    t = t.replace(insert_after, insert_after + guard)
# Replace the per-dataset cache-miss loop with a single batched scanner.scan().
# NOTE(review): the batched result is no longer written back into
# keyword_scan_cache per dataset, so repeat requests for the same scope will
# re-scan — confirm this trade-off is intended.
old = '''    if missing:
        missing_entries: list[dict] = []
        for dataset_id in missing:
            partial = await scanner.scan(dataset_ids=[dataset_id], theme_ids=body.theme_ids)
            keyword_scan_cache.put(dataset_id, partial)
            missing_entries.append({"result": partial, "built_at": None})
        merged = _merge_cached_results(
            cached_entries + missing_entries,
            allowed_theme_names if body.theme_ids else None,
        )
        return {
            "total_hits": merged["total_hits"],
            "hits": merged["hits"],
            "themes_scanned": len(themes),
            "keywords_scanned": keywords_scanned,
            "rows_scanned": merged["rows_scanned"],
            "cache_used": len(cached_entries) > 0,
            "cache_status": "partial" if cached_entries else "miss",
            "cached_at": merged["cached_at"],
        }
'''
new = '''    if missing:
        partial = await scanner.scan(dataset_ids=missing, theme_ids=body.theme_ids)
        merged = _merge_cached_results(
            cached_entries + [{"result": partial, "built_at": None}],
            allowed_theme_names if body.theme_ids else None,
        )
        return {
            "total_hits": merged["total_hits"],
            "hits": merged["hits"],
            "themes_scanned": len(themes),
            "keywords_scanned": keywords_scanned,
            "rows_scanned": merged["rows_scanned"],
            "cache_used": len(cached_entries) > 0,
            "cache_status": "partial" if cached_entries else "miss",
            "cached_at": merged["cached_at"],
        }
'''
if old not in t:
    raise SystemExit('partial-cache missing block not found')
t = t.replace(old, new)
p.write_text(t, encoding='utf-8')
print('hardened keywords scan scope + optimized missing-cache path')

18
_aup_reduce_budget.py Normal file
View File

@@ -0,0 +1,18 @@
"""Lower the default AUP scan row budget in backend config (300k -> 120k)."""
from pathlib import Path

CONFIG = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')

# Setting block as currently written (default=300000).
before = '''    SCANNER_MAX_ROWS_PER_SCAN: int = Field(
        default=300000,
        description="Global row budget for a single AUP scan request (0 = unlimited)",
    )
'''
# Same block with the reduced default.
after = '''    SCANNER_MAX_ROWS_PER_SCAN: int = Field(
        default=120000,
        description="Global row budget for a single AUP scan request (0 = unlimited)",
    )
'''

text = CONFIG.read_text(encoding='utf-8')
if before not in text:
    raise SystemExit('SCANNER_MAX_ROWS_PER_SCAN block not found')
CONFIG.write_text(text.replace(before, after), encoding='utf-8')
print('reduced SCANNER_MAX_ROWS_PER_SCAN default to 120000')

View File

@@ -0,0 +1,42 @@
"""Rework the AUP results grid columns: dataset/hostname/user/matched-value focus."""
from pathlib import Path

GRID_FILE = Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')

# Column definitions exactly as they exist before patching.
before = '''const RESULT_COLUMNS: GridColDef[] = [
  {
    field: 'theme_name', headerName: 'Theme', width: 140,
    renderCell: (params) => (
      <Chip label={params.value} size="small"
        sx={{ bgcolor: params.row.theme_color, color: '#fff', fontWeight: 600 }} />
    ),
  },
  { field: 'keyword', headerName: 'Keyword', width: 140 },
  { field: 'source_type', headerName: 'Source', width: 120 },
  { field: 'dataset_name', headerName: 'Dataset', width: 150 },
  { field: 'field', headerName: 'Field', width: 130 },
  { field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 200 },
  { field: 'row_index', headerName: 'Row #', width: 80, type: 'number' },
];
'''
# Reordered columns, with the new hostname/user columns added.
after = '''const RESULT_COLUMNS: GridColDef[] = [
  {
    field: 'theme_name', headerName: 'Theme', width: 140,
    renderCell: (params) => (
      <Chip label={params.value} size="small"
        sx={{ bgcolor: params.row.theme_color, color: '#fff', fontWeight: 600 }} />
    ),
  },
  { field: 'keyword', headerName: 'Keyword', width: 140 },
  { field: 'dataset_name', headerName: 'Dataset', width: 170 },
  { field: 'hostname', headerName: 'Hostname', width: 170, valueGetter: (v, row) => row.hostname || '' },
  { field: 'username', headerName: 'User', width: 160, valueGetter: (v, row) => row.username || '' },
  { field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 220 },
  { field: 'field', headerName: 'Field', width: 130 },
  { field: 'source_type', headerName: 'Source', width: 120 },
  { field: 'row_index', headerName: 'Row #', width: 90, type: 'number' },
];
'''

text = GRID_FILE.read_text(encoding='utf-8')
if before not in text:
    raise SystemExit('RESULT_COLUMNS block not found')
GRID_FILE.write_text(text.replace(before, after), encoding='utf-8')
print('updated AUP results grid columns with dataset/hostname/user/matched value focus')

40
_edit_aup.py Normal file
View File

@@ -0,0 +1,40 @@
"""Patch AUPScanner.tsx: default extra sources off, prefer cache, add cache chip.

Best-effort by design: the simple toggles below are applied silently (a missing
anchor is a no-op), and only the summary Alert block prints a warning — without
aborting — when its anchor is not found.

NOTE(review): anchor indentation reconstructed to the 2-space TSX convention —
verify against the component before running.
"""
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
t = p.read_text(encoding='utf-8')
# Flip hunts/annotations/messages scanning to opt-in (safer default scope).
t = t.replace('  const [scanHunts, setScanHunts] = useState(true);', '  const [scanHunts, setScanHunts] = useState(false);')
t = t.replace('  const [scanAnnotations, setScanAnnotations] = useState(true);', '  const [scanAnnotations, setScanAnnotations] = useState(false);')
t = t.replace('  const [scanMessages, setScanMessages] = useState(true);', '  const [scanMessages, setScanMessages] = useState(false);')
# Ask the backend to serve cached scan results when available.
t = t.replace('      scan_messages: scanMessages,\n    });', '      scan_messages: scanMessages,\n      prefer_cache: true,\n    });')
# add cache chip in summary alert
old = '''      {scanResult && (
        <Alert severity={scanResult.total_hits > 0 ? 'warning' : 'success'} sx={{ py: 0.5 }}>
          <strong>{scanResult.total_hits}</strong> hits across{' '}
          <strong>{scanResult.rows_scanned}</strong> rows |{' '}
          {scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned
        </Alert>
      )}
'''
new = '''      {scanResult && (
        <Alert severity={scanResult.total_hits > 0 ? 'warning' : 'success'} sx={{ py: 0.5 }}>
          <strong>{scanResult.total_hits}</strong> hits across{' '}
          <strong>{scanResult.rows_scanned}</strong> rows |{' '}
          {scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned
          {scanResult.cache_status && (
            <Chip
              size="small"
              label={scanResult.cache_status === 'hit' ? 'Cached' : 'Live'}
              sx={{ ml: 1, height: 20 }}
              color={scanResult.cache_status === 'hit' ? 'success' : 'default'}
              variant="outlined"
            />
          )}
        </Alert>
      )}
'''
if old in t:
    t = t.replace(old, new)
else:
    print('warning: summary block not replaced')
p.write_text(t, encoding='utf-8')
print('updated AUPScanner.tsx')

36
_edit_client.py Normal file
View File

@@ -0,0 +1,36 @@
"""Patch frontend/src/api/client.ts: HuntProgress type, hunts.progress() method,
cache metadata on ScanResponse, and cache options on keywords.scan.

All structural edits are idempotent (guarded by marker checks) and best-effort:
a replace whose anchor is missing is a silent no-op by design.

FIX over the previous version: removed the unused `import re`.
"""
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/api/client.ts')
t = p.read_text(encoding='utf-8')

# Add the HuntProgress interface right after the Hunt interface.
if 'export interface HuntProgress' not in t:
    insert = '''export interface HuntProgress {
  hunt_id: string;
  status: 'idle' | 'processing' | 'ready';
  progress_percent: number;
  dataset_total: number;
  dataset_completed: number;
  dataset_processing: number;
  dataset_errors: number;
  active_jobs: number;
  queued_jobs: number;
  network_status: 'none' | 'building' | 'ready';
  stages: Record<string, any>;
}
'''
    t = t.replace('export interface Hunt {\n  id: string; name: string; description: string | null; status: string;\n  owner_id: string | null; created_at: string; updated_at: string;\n  dataset_count: number; hypothesis_count: number;\n}\n\n', 'export interface Hunt {\n  id: string; name: string; description: string | null; status: string;\n  owner_id: string | null; created_at: string; updated_at: string;\n  dataset_count: number; hypothesis_count: number;\n}\n\n' + insert)
# Add the hunts.progress() client method after hunts.delete().
if 'progress: (id: string)' not in t:
    t = t.replace("  delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),\n};", "  delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),\n  progress: (id: string) => api<HuntProgress>(`/api/hunts/${id}/progress`),\n};")
# Extend ScanResponse with optional cache metadata.
if 'cache_used?: boolean' not in t:
    t = t.replace('export interface ScanResponse {\n  total_hits: number; hits: ScanHit[]; themes_scanned: number;\n  keywords_scanned: number; rows_scanned: number;\n}\n', 'export interface ScanResponse {\n  total_hits: number; hits: ScanHit[]; themes_scanned: number;\n  keywords_scanned: number; rows_scanned: number;\n  cache_used?: boolean; cache_status?: string; cached_at?: string | null;\n}\n')
# Extend keywords.scan request options (unguarded: once applied, the anchor no
# longer matches, so re-running is harmless).
t = t.replace('    scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;\n  }) =>', '    scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;\n    prefer_cache?: boolean; force_rescan?: boolean;\n  }) =>')
p.write_text(t, encoding='utf-8')
print('updated client.ts')

20
_edit_config_reconcile.py Normal file
View File

@@ -0,0 +1,20 @@
"""Add the STARTUP_RECONCILE_STALE_TASKS flag to backend config."""
from pathlib import Path

CONFIG = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')

# Anchor: the existing startup-reprocess setting the new flag follows.
anchor = '''    STARTUP_REPROCESS_MAX_DATASETS: int = Field(
        default=25, description="Max unprocessed datasets to enqueue at startup"
    )
'''
# Anchor plus the new boolean toggle.
replacement = '''    STARTUP_REPROCESS_MAX_DATASETS: int = Field(
        default=25, description="Max unprocessed datasets to enqueue at startup"
    )
    STARTUP_RECONCILE_STALE_TASKS: bool = Field(
        default=True,
        description="Mark stale queued/running processing tasks as failed on startup",
    )
'''

text = CONFIG.read_text(encoding='utf-8')
if anchor not in text:
    raise SystemExit('startup anchor not found')
CONFIG.write_text(text.replace(anchor, replacement), encoding='utf-8')
print('updated config with STARTUP_RECONCILE_STALE_TASKS')

39
_edit_datasets.py Normal file
View File

@@ -0,0 +1,39 @@
"""Patch datasets route: invalidate the keyword-scan cache when a dataset is deleted."""
from pathlib import Path

ROUTE_FILE = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py')
text = ROUTE_FILE.read_text(encoding='utf-8')

# Ensure the scanner cache import is present (idempotent).
cache_import = 'from app.services.scanner import keyword_scan_cache'
if cache_import not in text:
    text = text.replace(
        'from app.services.host_inventory import inventory_cache',
        'from app.services.host_inventory import inventory_cache\n' + cache_import,
    )

# Delete endpoint as it exists today.
before = '''@router.delete(
    "/{dataset_id}",
    summary="Delete a dataset",
)
async def delete_dataset(
    dataset_id: str,
    db: AsyncSession = Depends(get_db),
):
    repo = DatasetRepository(db)
    deleted = await repo.delete_dataset(dataset_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Dataset not found")
    return {"message": "Dataset deleted", "id": dataset_id}
'''
# Same endpoint, invalidating the cache entry before responding.
after = '''@router.delete(
    "/{dataset_id}",
    summary="Delete a dataset",
)
async def delete_dataset(
    dataset_id: str,
    db: AsyncSession = Depends(get_db),
):
    repo = DatasetRepository(db)
    deleted = await repo.delete_dataset(dataset_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Dataset not found")
    keyword_scan_cache.invalidate_dataset(dataset_id)
    return {"message": "Dataset deleted", "id": dataset_id}
'''

if before not in text:
    raise SystemExit('delete block not found')
ROUTE_FILE.write_text(text.replace(before, after), encoding='utf-8')
print('updated datasets.py')

110
_edit_datasets_tasks.py Normal file
View File

@@ -0,0 +1,110 @@
"""Patch dataset upload: persist a ProcessingTask row for each auto-queued job.

Edits backend/app/api/routes/datasets.py in place via exact-text replacement.

Fixes over the previous version of this script:
- The ProcessingTask import was spliced in by rewriting the existing
  `from app.db.models import ...` line and then partially un-rewriting it,
  which corrupted that line into invalid syntax (the original imported names
  ended up appended after `ProcessingTask` with no comma). A dedicated import
  line is now inserted instead, and only before the first occurrence.
- Removed a no-op replace of the scanner-cache import with itself.

NOTE(review): anchor indentation below assumes the queue block sits at
8-space depth inside the upload endpoint — confirm against datasets.py.
"""
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py')
t = p.read_text(encoding='utf-8')

if 'ProcessingTask' not in t:
    # Add a standalone import line ahead of the first models import, leaving
    # the existing import statement untouched.
    t = t.replace(
        'from app.db.models import',
        'from app.db.models import ProcessingTask\nfrom app.db.models import',
        1,
    )

old = '''        # 1. AI Triage (chains to HOST_PROFILE automatically on completion)
        job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
        jobs_queued.append("triage")
        # 2. Anomaly detection (embedding-based outlier detection)
        job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
        jobs_queued.append("anomaly")
        # 3. AUP keyword scan
        job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
        jobs_queued.append("keyword_scan")
        # 4. IOC extraction
        job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
        jobs_queued.append("ioc_extract")
        # 5. Host inventory (network map) - requires hunt_id
        if hunt_id:
            inventory_cache.invalidate(hunt_id)
            job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
            jobs_queued.append("host_inventory")
'''
new = '''        task_rows: list[ProcessingTask] = []
        # 1. AI Triage (chains to HOST_PROFILE automatically on completion)
        triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
        jobs_queued.append("triage")
        task_rows.append(ProcessingTask(
            hunt_id=hunt_id,
            dataset_id=dataset.id,
            job_id=triage_job.id,
            stage="triage",
            status="queued",
            progress=0.0,
            message="Queued",
        ))
        # 2. Anomaly detection (embedding-based outlier detection)
        anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
        jobs_queued.append("anomaly")
        task_rows.append(ProcessingTask(
            hunt_id=hunt_id,
            dataset_id=dataset.id,
            job_id=anomaly_job.id,
            stage="anomaly",
            status="queued",
            progress=0.0,
            message="Queued",
        ))
        # 3. AUP keyword scan
        kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
        jobs_queued.append("keyword_scan")
        task_rows.append(ProcessingTask(
            hunt_id=hunt_id,
            dataset_id=dataset.id,
            job_id=kw_job.id,
            stage="keyword_scan",
            status="queued",
            progress=0.0,
            message="Queued",
        ))
        # 4. IOC extraction
        ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
        jobs_queued.append("ioc_extract")
        task_rows.append(ProcessingTask(
            hunt_id=hunt_id,
            dataset_id=dataset.id,
            job_id=ioc_job.id,
            stage="ioc_extract",
            status="queued",
            progress=0.0,
            message="Queued",
        ))
        # 5. Host inventory (network map) - requires hunt_id
        if hunt_id:
            inventory_cache.invalidate(hunt_id)
            inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
            jobs_queued.append("host_inventory")
            task_rows.append(ProcessingTask(
                hunt_id=hunt_id,
                dataset_id=dataset.id,
                job_id=inv_job.id,
                stage="host_inventory",
                status="queued",
                progress=0.0,
                message="Queued",
            ))
        if task_rows:
            db.add_all(task_rows)
            await db.flush()
'''
if old not in t:
    raise SystemExit('queue block not found')
t = t.replace(old, new)
p.write_text(t, encoding='utf-8')
print('updated datasets upload queue + processing tasks')

254
_edit_hunts.py Normal file
View File

@@ -0,0 +1,254 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/hunts.py')
new='''"""API routes for hunt management."""
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Hunt, Dataset
from app.services.job_queue import job_queue
from app.services.host_inventory import inventory_cache
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/hunts", tags=["hunts"])
class HuntCreate(BaseModel):
name: str = Field(..., max_length=256)
description: str | None = None
class HuntUpdate(BaseModel):
name: str | None = None
description: str | None = None
status: str | None = None
class HuntResponse(BaseModel):
id: str
name: str
description: str | None
status: str
owner_id: str | None
created_at: str
updated_at: str
dataset_count: int = 0
hypothesis_count: int = 0
class HuntListResponse(BaseModel):
hunts: list[HuntResponse]
total: int
class HuntProgressResponse(BaseModel):
hunt_id: str
status: str
progress_percent: float
dataset_total: int
dataset_completed: int
dataset_processing: int
dataset_errors: int
active_jobs: int
queued_jobs: int
network_status: str
stages: dict
@router.post("", response_model=HuntResponse, summary="Create a new hunt")
async def create_hunt(body: HuntCreate, db: AsyncSession = Depends(get_db)):
hunt = Hunt(name=body.name, description=body.description)
db.add(hunt)
await db.flush()
return HuntResponse(
id=hunt.id,
name=hunt.name,
description=hunt.description,
status=hunt.status,
owner_id=hunt.owner_id,
created_at=hunt.created_at.isoformat(),
updated_at=hunt.updated_at.isoformat(),
)
@router.get("", response_model=HuntListResponse, summary="List hunts")
async def list_hunts(
status: str | None = Query(None),
limit: int = Query(50, ge=1, le=500),
offset: int = Query(0, ge=0),
db: AsyncSession = Depends(get_db),
):
stmt = select(Hunt).order_by(Hunt.updated_at.desc())
if status:
stmt = stmt.where(Hunt.status == status)
stmt = stmt.limit(limit).offset(offset)
result = await db.execute(stmt)
hunts = result.scalars().all()
count_stmt = select(func.count(Hunt.id))
if status:
count_stmt = count_stmt.where(Hunt.status == status)
total = (await db.execute(count_stmt)).scalar_one()
return HuntListResponse(
hunts=[
HuntResponse(
id=h.id,
name=h.name,
description=h.description,
status=h.status,
owner_id=h.owner_id,
created_at=h.created_at.isoformat(),
updated_at=h.updated_at.isoformat(),
dataset_count=len(h.datasets) if h.datasets else 0,
hypothesis_count=len(h.hypotheses) if h.hypotheses else 0,
)
for h in hunts
],
total=total,
)
@router.get("/{hunt_id}", response_model=HuntResponse, summary="Get hunt details")
async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
hunt = result.scalar_one_or_none()
if not hunt:
raise HTTPException(status_code=404, detail="Hunt not found")
return HuntResponse(
id=hunt.id,
name=hunt.name,
description=hunt.description,
status=hunt.status,
owner_id=hunt.owner_id,
created_at=hunt.created_at.isoformat(),
updated_at=hunt.updated_at.isoformat(),
dataset_count=len(hunt.datasets) if hunt.datasets else 0,
hypothesis_count=len(hunt.hypotheses) if hunt.hypotheses else 0,
)
@router.get("/{hunt_id}/progress", response_model=HuntProgressResponse, summary="Get hunt processing progress")
async def get_hunt_progress(hunt_id: str, db: AsyncSession = Depends(get_db)):
hunt = await db.get(Hunt, hunt_id)
if not hunt:
raise HTTPException(status_code=404, detail="Hunt not found")
ds_rows = await db.execute(
select(Dataset.id, Dataset.processing_status)
.where(Dataset.hunt_id == hunt_id)
)
datasets = ds_rows.all()
dataset_ids = {row[0] for row in datasets}
dataset_total = len(datasets)
dataset_completed = sum(1 for _, st in datasets if st == "completed")
dataset_errors = sum(1 for _, st in datasets if st == "completed_with_errors")
dataset_processing = max(0, dataset_total - dataset_completed - dataset_errors)
jobs = job_queue.list_jobs(limit=5000)
relevant_jobs = [
j for j in jobs
if j.get("params", {}).get("hunt_id") == hunt_id
or j.get("params", {}).get("dataset_id") in dataset_ids
]
active_jobs = sum(1 for j in relevant_jobs if j.get("status") == "running")
queued_jobs = sum(1 for j in relevant_jobs if j.get("status") == "queued")
if inventory_cache.get(hunt_id) is not None:
network_status = "ready"
network_ratio = 1.0
elif inventory_cache.is_building(hunt_id):
network_status = "building"
network_ratio = 0.5
else:
network_status = "none"
network_ratio = 0.0
dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
progress_percent = round(overall_ratio * 100.0, 1)
status = "ready"
if dataset_total == 0:
status = "idle"
elif progress_percent < 100:
status = "processing"
stages = {
"datasets": {
"total": dataset_total,
"completed": dataset_completed,
"processing": dataset_processing,
"errors": dataset_errors,
"percent": round(dataset_ratio * 100.0, 1),
},
"network": {
"status": network_status,
"percent": round(network_ratio * 100.0, 1),
},
"jobs": {
"active": active_jobs,
"queued": queued_jobs,
"total_seen": len(relevant_jobs),
},
}
return HuntProgressResponse(
hunt_id=hunt_id,
status=status,
progress_percent=progress_percent,
dataset_total=dataset_total,
dataset_completed=dataset_completed,
dataset_processing=dataset_processing,
dataset_errors=dataset_errors,
active_jobs=active_jobs,
queued_jobs=queued_jobs,
network_status=network_status,
stages=stages,
)
@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
async def update_hunt(
hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)
):
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
hunt = result.scalar_one_or_none()
if not hunt:
raise HTTPException(status_code=404, detail="Hunt not found")
if body.name is not None:
hunt.name = body.name
if body.description is not None:
hunt.description = body.description
if body.status is not None:
hunt.status = body.status
await db.flush()
return HuntResponse(
id=hunt.id,
name=hunt.name,
description=hunt.description,
status=hunt.status,
owner_id=hunt.owner_id,
created_at=hunt.created_at.isoformat(),
updated_at=hunt.updated_at.isoformat(),
)
@router.delete("/{hunt_id}", summary="Delete a hunt")
async def delete_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
hunt = result.scalar_one_or_none()
if not hunt:
raise HTTPException(status_code=404, detail="Hunt not found")
await db.delete(hunt)
return {"message": "Hunt deleted", "id": hunt_id}
'''
p.write_text(new,encoding='utf-8')
print('updated hunts.py')

View File

@@ -0,0 +1,102 @@
"""Patch the hunt progress endpoint: merge persisted ProcessingTask state with
the in-memory job-queue view so progress survives service restarts.

Edits backend/app/api/routes/hunts.py in place via exact-text replacement.

NOTE(review): anchor indentation reconstructed to 4-space convention — verify
it matches hunts.py exactly before running.
"""
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/hunts.py')
t = p.read_text(encoding='utf-8')
# Ensure ProcessingTask is importable in the route module (idempotent).
if 'ProcessingTask' not in t:
    t = t.replace('from app.db.models import Hunt, Dataset', 'from app.db.models import Hunt, Dataset, ProcessingTask')
# 1) Replace the in-memory-only job counters with the merged DB-backed view.
old = '''    jobs = job_queue.list_jobs(limit=5000)
    relevant_jobs = [
        j for j in jobs
        if j.get("params", {}).get("hunt_id") == hunt_id
        or j.get("params", {}).get("dataset_id") in dataset_ids
    ]
    active_jobs = sum(1 for j in relevant_jobs if j.get("status") == "running")
    queued_jobs = sum(1 for j in relevant_jobs if j.get("status") == "queued")
    if inventory_cache.get(hunt_id) is not None:
'''
new = '''    jobs = job_queue.list_jobs(limit=5000)
    relevant_jobs = [
        j for j in jobs
        if j.get("params", {}).get("hunt_id") == hunt_id
        or j.get("params", {}).get("dataset_id") in dataset_ids
    ]
    active_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "running")
    queued_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "queued")
    task_rows = await db.execute(
        select(ProcessingTask.stage, ProcessingTask.status, ProcessingTask.progress)
        .where(ProcessingTask.hunt_id == hunt_id)
    )
    tasks = task_rows.all()
    task_total = len(tasks)
    task_done = sum(1 for _, st, _ in tasks if st in ("completed", "failed", "cancelled"))
    task_running = sum(1 for _, st, _ in tasks if st == "running")
    task_queued = sum(1 for _, st, _ in tasks if st == "queued")
    task_ratio = (task_done / task_total) if task_total > 0 else None
    active_jobs = max(active_jobs_mem, task_running)
    queued_jobs = max(queued_jobs_mem, task_queued)
    stage_rollup: dict[str, dict] = {}
    for stage, status, progress in tasks:
        bucket = stage_rollup.setdefault(stage, {"total": 0, "done": 0, "running": 0, "queued": 0, "progress_sum": 0.0})
        bucket["total"] += 1
        if status in ("completed", "failed", "cancelled"):
            bucket["done"] += 1
        elif status == "running":
            bucket["running"] += 1
        elif status == "queued":
            bucket["queued"] += 1
        bucket["progress_sum"] += float(progress or 0.0)
    for stage_name, bucket in stage_rollup.items():
        total = max(1, bucket["total"])
        bucket["percent"] = round(bucket["progress_sum"] / total, 1)
    if inventory_cache.get(hunt_id) is not None:
'''
if old not in t:
    raise SystemExit('job block not found')
t = t.replace(old, new)
# 2) Weight persisted task completion into the overall progress ratio.
old2 = '''    dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
    overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
    progress_percent = round(overall_ratio * 100.0, 1)
'''
new2 = '''    dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
    if task_ratio is None:
        overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
    else:
        overall_ratio = min(1.0, (dataset_ratio * 0.50) + (task_ratio * 0.35) + (network_ratio * 0.15))
    progress_percent = round(overall_ratio * 100.0, 1)
'''
if old2 not in t:
    raise SystemExit('ratio block not found')
t = t.replace(old2, new2)
# 3) Surface the task totals and per-stage rollup in the stages payload.
old3 = '''        "jobs": {
            "active": active_jobs,
            "queued": queued_jobs,
            "total_seen": len(relevant_jobs),
        },
    }
'''
new3 = '''        "jobs": {
            "active": active_jobs,
            "queued": queued_jobs,
            "total_seen": len(relevant_jobs),
            "task_total": task_total,
            "task_done": task_done,
            "task_percent": round((task_ratio or 0.0) * 100.0, 1) if task_total else None,
        },
        "task_stages": stage_rollup,
    }
'''
if old3 not in t:
    raise SystemExit('stages jobs block not found')
t = t.replace(old3, new3)
p.write_text(t, encoding='utf-8')
print('updated hunt progress to merge persistent processing tasks')

46
_edit_job_queue.py Normal file
View File

@@ -0,0 +1,46 @@
"""Patch the job-queue keyword-scan handler to cache per-dataset scan results.

Edits backend/app/services/job_queue.py in place via exact-text replacement.

NOTE(review): anchor indentation reconstructed to 4-space convention (and the
position of the lines after the `async with` block is inferred) — verify the
anchor matches job_queue.py exactly before running.
"""
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
t = p.read_text(encoding='utf-8')
# Handler exactly as it exists today.
old = '''async def _handle_keyword_scan(job: Job):
    """AUP keyword scan handler."""
    from app.db import async_session_factory
    from app.services.scanner import KeywordScanner
    dataset_id = job.params.get("dataset_id")
    job.message = f"Running AUP keyword scan on dataset {dataset_id}"
    async with async_session_factory() as db:
        scanner = KeywordScanner(db)
        result = await scanner.scan(dataset_ids=[dataset_id])
    hits = result.get("total_hits", 0)
    job.message = f"Keyword scan complete: {hits} hits"
    logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
    return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
'''
# Same handler, writing the dataset-scoped result into keyword_scan_cache.
new = '''async def _handle_keyword_scan(job: Job):
    """AUP keyword scan handler."""
    from app.db import async_session_factory
    from app.services.scanner import KeywordScanner, keyword_scan_cache
    dataset_id = job.params.get("dataset_id")
    job.message = f"Running AUP keyword scan on dataset {dataset_id}"
    async with async_session_factory() as db:
        scanner = KeywordScanner(db)
        result = await scanner.scan(dataset_ids=[dataset_id])
    # Cache dataset-only result for fast API reuse
    if dataset_id:
        keyword_scan_cache.put(dataset_id, result)
    hits = result.get("total_hits", 0)
    job.message = f"Keyword scan complete: {hits} hits"
    logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
    return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
'''
if old not in t:
    raise SystemExit('target block not found')
t = t.replace(old, new)
p.write_text(t, encoding='utf-8')
print('updated job_queue keyword scan handler')

View File

@@ -0,0 +1,13 @@
"""Insert reconcile_stale_processing_tasks() into the job-queue module.

The helper (kept as a single escaped-line string) marks queued/running
processing_tasks rows left over from a previous run as failed. It is spliced
in just above register_all_handlers(); insertion is idempotent.

NOTE(review): indentation inside the escaped string and the marker docstring
is reconstructed to 4-space convention — verify against job_queue.py.
"""
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
t = p.read_text(encoding='utf-8')
# Insertion anchor: the function goes immediately before this definition.
marker = '''def register_all_handlers():
    """Register all job handlers and completion callbacks."""
'''
ins = '''\n\nasync def reconcile_stale_processing_tasks() -> int:\n    """Mark queued/running processing tasks from prior runs as failed."""\n    from datetime import datetime, timezone\n    from sqlalchemy import update\n\n    try:\n        from app.db import async_session_factory\n        from app.db.models import ProcessingTask\n\n        now = datetime.now(timezone.utc)\n        async with async_session_factory() as db:\n            result = await db.execute(\n                update(ProcessingTask)\n                .where(ProcessingTask.status.in_([\"queued\", \"running\"]))\n                .values(\n                    status=\"failed\",\n                    error=\"Recovered after service restart before task completion\",\n                    message=\"Recovered stale task after restart\",\n                    completed_at=now,\n                )\n            )\n            await db.commit()\n            updated = int(result.rowcount or 0)\n\n        if updated:\n            logger.warning(\n                \"Reconciled %d stale processing tasks (queued/running -> failed) during startup\",\n                updated,\n            )\n        return updated\n    except Exception as e:\n        logger.warning(f\"Failed to reconcile stale processing tasks: {e}\")\n        return 0\n\n\n'''
# Only insert when the helper is not already present (idempotent re-runs).
if ins.strip() not in t:
    if marker not in t:
        raise SystemExit('register marker not found')
    t = t.replace(marker, ins + marker)
p.write_text(t, encoding='utf-8')
print('added reconcile_stale_processing_tasks to job_queue')

64
_edit_jobqueue_sync.py Normal file
View File

@@ -0,0 +1,64 @@
# Patch script: wire persistent ProcessingTask syncing into the job queue
# worker loop. Adds _sync_processing_task() and calls it when a job enters
# RUNNING and again when it finishes / fails / is cancelled.
# NOTE(review): indentation below is reconstructed (the dump lost whitespace);
# every anchor must byte-match the current job_queue.py or the script exits.
from pathlib import Path

p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
t=p.read_text(encoding='utf-8')

# Helper (one source line via \n escapes): mirrors a Job's status/progress/
# message/error plus start/finish timestamps into the processing_tasks row
# linked by job_id. Best-effort: DB errors are logged, never raised into the
# worker loop.
ins='''\n\nasync def _sync_processing_task(job: Job):\n    """Persist latest job state into processing_tasks (if linked by job_id)."""\n    from datetime import datetime, timezone\n    from sqlalchemy import update\n\n    try:\n        from app.db import async_session_factory\n        from app.db.models import ProcessingTask\n\n        values = {\n            "status": job.status.value,\n            "progress": float(job.progress),\n            "message": job.message,\n            "error": job.error,\n        }\n        if job.started_at:\n            values["started_at"] = datetime.fromtimestamp(job.started_at, tz=timezone.utc)\n        if job.completed_at:\n            values["completed_at"] = datetime.fromtimestamp(job.completed_at, tz=timezone.utc)\n\n        async with async_session_factory() as db:\n            await db.execute(\n                update(ProcessingTask)\n                .where(ProcessingTask.job_id == job.id)\n                .values(**values)\n            )\n            await db.commit()\n    except Exception as e:\n        logger.warning(f"Failed to sync processing task for job {job.id}: {e}")\n'''

# Insert the helper just before this section divider (guarded for idempotency).
marker='\n\n# -- Singleton + job handlers --\n'
if ins.strip() not in t:
    t=t.replace(marker, ins+marker)

# Worker "job started" block: also bump progress to a visible 5% floor and
# push the initial RUNNING state to the DB before execution begins.
# NOTE(review): the two replacements below are NOT idempotency-guarded --
# re-running this script after a successful pass exits with an error.
old='''        job.status = JobStatus.RUNNING
        job.started_at = time.time()
        job.message = "Running..."
        logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
        try:
'''
new='''        job.status = JobStatus.RUNNING
        job.started_at = time.time()
        if job.progress <= 0:
            job.progress = 5.0
        job.message = "Running..."
        await _sync_processing_task(job)
        logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
        try:
'''
if old not in t:
    raise SystemExit('worker running block not found')
t=t.replace(old,new)

# Worker completion block: stamp completed_at on cancelled jobs too, then
# persist the terminal state before completion callbacks fire.
old2='''            job.completed_at = time.time()
            logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
        except Exception as e:
            if not job.is_cancelled:
                job.status = JobStatus.FAILED
                job.error = str(e)
                job.message = f"Failed: {e}"
                job.completed_at = time.time()
                logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
        # Fire completion callbacks
'''
new2='''            job.completed_at = time.time()
            logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
        except Exception as e:
            if not job.is_cancelled:
                job.status = JobStatus.FAILED
                job.error = str(e)
                job.message = f"Failed: {e}"
                job.completed_at = time.time()
                logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
        if job.is_cancelled and not job.completed_at:
            job.completed_at = time.time()
        await _sync_processing_task(job)
        # Fire completion callbacks
'''
if old2 not in t:
    raise SystemExit('worker completion block not found')
t=t.replace(old2,new2)
p.write_text(t, encoding='utf-8')
print('updated job_queue persistent task syncing')

View File

@@ -0,0 +1,39 @@
# Patch script: after TRIAGE completes and chains a HOST_PROFILE job, also
# persist a processing_tasks row for that chained job so the UI can track it.
# NOTE(review): indentation is reconstructed (the dump lost whitespace); the
# anchor must byte-match the triage completion callback in job_queue.py.
from pathlib import Path

p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/job_queue.py')
t=p.read_text(encoding='utf-8')

# Anchor: tail of the triage completion callback, including the enclosing
# `except` line so the replacement stays inside the original try block.
old='''        if hunt_id:
            job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
            logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
    except Exception as e:
'''
# Replacement: capture the submitted job, then best-effort insert a queued
# ProcessingTask linked via job_id (skipped if a row already exists).
# Persistence failures are logged and never break the chain.
new='''        if hunt_id:
            hp_job = job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
            try:
                from sqlalchemy import select
                from app.db.models import ProcessingTask
                async with async_session_factory() as db:
                    existing = await db.execute(
                        select(ProcessingTask.id).where(ProcessingTask.job_id == hp_job.id)
                    )
                    if existing.first() is None:
                        db.add(ProcessingTask(
                            hunt_id=hunt_id,
                            dataset_id=dataset_id,
                            job_id=hp_job.id,
                            stage="host_profile",
                            status="queued",
                            progress=0.0,
                            message="Queued",
                        ))
                        await db.commit()
            except Exception as persist_err:
                logger.warning(f"Failed to persist chained HOST_PROFILE task: {persist_err}")
            logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
    except Exception as e:
'''
if old not in t:
    raise SystemExit('triage chain block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated triage chain to persist host_profile task row')

321
_edit_keywords.py Normal file
View File

@@ -0,0 +1,321 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
new_text='''"""API routes for AUP keyword themes, keyword CRUD, and scanning."""
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import KeywordTheme, Keyword
from app.services.scanner import KeywordScanner, keyword_scan_cache
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/keywords", tags=["keywords"])
class ThemeCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=128)
color: str = Field(default="#9e9e9e", max_length=16)
enabled: bool = True
class ThemeUpdate(BaseModel):
name: str | None = None
color: str | None = None
enabled: bool | None = None
class KeywordOut(BaseModel):
id: int
theme_id: str
value: str
is_regex: bool
created_at: str
class ThemeOut(BaseModel):
id: str
name: str
color: str
enabled: bool
is_builtin: bool
created_at: str
keyword_count: int
keywords: list[KeywordOut]
class ThemeListResponse(BaseModel):
themes: list[ThemeOut]
total: int
class KeywordCreate(BaseModel):
value: str = Field(..., min_length=1, max_length=256)
is_regex: bool = False
class KeywordBulkCreate(BaseModel):
values: list[str] = Field(..., min_items=1)
is_regex: bool = False
class ScanRequest(BaseModel):
dataset_ids: list[str] | None = None
theme_ids: list[str] | None = None
scan_hunts: bool = False
scan_annotations: bool = False
scan_messages: bool = False
prefer_cache: bool = True
force_rescan: bool = False
class ScanHit(BaseModel):
theme_name: str
theme_color: str
keyword: str
source_type: str
source_id: str | int
field: str
matched_value: str
row_index: int | None = None
dataset_name: str | None = None
class ScanResponse(BaseModel):
total_hits: int
hits: list[ScanHit]
themes_scanned: int
keywords_scanned: int
rows_scanned: int
cache_used: bool = False
cache_status: str = "miss"
cached_at: str | None = None
def _theme_to_out(t: KeywordTheme) -> ThemeOut:
return ThemeOut(
id=t.id,
name=t.name,
color=t.color,
enabled=t.enabled,
is_builtin=t.is_builtin,
created_at=t.created_at.isoformat(),
keyword_count=len(t.keywords),
keywords=[
KeywordOut(
id=k.id,
theme_id=k.theme_id,
value=k.value,
is_regex=k.is_regex,
created_at=k.created_at.isoformat(),
)
for k in t.keywords
],
)
def _merge_cached_results(entries: list[dict], allowed_theme_names: set[str] | None = None) -> dict:
hits: list[dict] = []
total_rows = 0
cached_at: str | None = None
for entry in entries:
result = entry["result"]
total_rows += int(result.get("rows_scanned", 0) or 0)
if entry.get("built_at"):
if not cached_at or entry["built_at"] > cached_at:
cached_at = entry["built_at"]
for h in result.get("hits", []):
if allowed_theme_names is not None and h.get("theme_name") not in allowed_theme_names:
continue
hits.append(h)
return {
"total_hits": len(hits),
"hits": hits,
"rows_scanned": total_rows,
"cached_at": cached_at,
}
@router.get("/themes", response_model=ThemeListResponse)
async def list_themes(db: AsyncSession = Depends(get_db)):
result = await db.execute(select(KeywordTheme).order_by(KeywordTheme.name))
themes = result.scalars().all()
return ThemeListResponse(themes=[_theme_to_out(t) for t in themes], total=len(themes))
@router.post("/themes", response_model=ThemeOut, status_code=201)
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
exists = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == body.name))
if exists:
raise HTTPException(409, f"Theme '{body.name}' already exists")
theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
db.add(theme)
await db.flush()
await db.refresh(theme)
keyword_scan_cache.clear()
return _theme_to_out(theme)
@router.put("/themes/{theme_id}", response_model=ThemeOut)
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
theme = await db.get(KeywordTheme, theme_id)
if not theme:
raise HTTPException(404, "Theme not found")
if body.name is not None:
dup = await db.scalar(
select(KeywordTheme.id).where(KeywordTheme.name == body.name, KeywordTheme.id != theme_id)
)
if dup:
raise HTTPException(409, f"Theme '{body.name}' already exists")
theme.name = body.name
if body.color is not None:
theme.color = body.color
if body.enabled is not None:
theme.enabled = body.enabled
await db.flush()
await db.refresh(theme)
keyword_scan_cache.clear()
return _theme_to_out(theme)
@router.delete("/themes/{theme_id}", status_code=204)
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
theme = await db.get(KeywordTheme, theme_id)
if not theme:
raise HTTPException(404, "Theme not found")
await db.delete(theme)
keyword_scan_cache.clear()
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
theme = await db.get(KeywordTheme, theme_id)
if not theme:
raise HTTPException(404, "Theme not found")
kw = Keyword(theme_id=theme_id, value=body.value, is_regex=body.is_regex)
db.add(kw)
await db.flush()
await db.refresh(kw)
keyword_scan_cache.clear()
return KeywordOut(
id=kw.id, theme_id=kw.theme_id, value=kw.value,
is_regex=kw.is_regex, created_at=kw.created_at.isoformat(),
)
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
theme = await db.get(KeywordTheme, theme_id)
if not theme:
raise HTTPException(404, "Theme not found")
added = 0
for val in body.values:
val = val.strip()
if not val:
continue
db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex))
added += 1
await db.flush()
keyword_scan_cache.clear()
return {"added": added, "theme_id": theme_id}
@router.delete("/keywords/{keyword_id}", status_code=204)
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
kw = await db.get(Keyword, keyword_id)
if not kw:
raise HTTPException(404, "Keyword not found")
await db.delete(kw)
keyword_scan_cache.clear()
@router.post("/scan", response_model=ScanResponse)
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
scanner = KeywordScanner(db)
can_use_cache = (
body.prefer_cache
and not body.force_rescan
and bool(body.dataset_ids)
and not body.scan_hunts
and not body.scan_annotations
and not body.scan_messages
)
if can_use_cache:
themes = await scanner._load_themes(body.theme_ids)
allowed_theme_names = {t.name for t in themes}
keywords_scanned = sum(len(theme.keywords) for theme in themes)
cached_entries: list[dict] = []
missing: list[str] = []
for dataset_id in (body.dataset_ids or []):
entry = keyword_scan_cache.get(dataset_id)
if not entry:
missing.append(dataset_id)
continue
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
if not missing and cached_entries:
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": True,
"cache_status": "hit",
"cached_at": merged["cached_at"],
}
result = await scanner.scan(
dataset_ids=body.dataset_ids,
theme_ids=body.theme_ids,
scan_hunts=body.scan_hunts,
scan_annotations=body.scan_annotations,
scan_messages=body.scan_messages,
)
return {
**result,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}
@router.get("/scan/quick", response_model=ScanResponse)
async def quick_scan(
dataset_id: str = Query(..., description="Dataset to scan"),
db: AsyncSession = Depends(get_db),
):
entry = keyword_scan_cache.get(dataset_id)
if entry is not None:
result = entry.result
return {
**result,
"cache_used": True,
"cache_status": "hit",
"cached_at": entry.built_at,
}
scanner = KeywordScanner(db)
result = await scanner.scan(dataset_ids=[dataset_id])
keyword_scan_cache.put(dataset_id, result)
return {
**result,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}
'''
p.write_text(new_text,encoding='utf-8')
print('updated keywords.py')

31
_edit_main_reconcile.py Normal file
View File

@@ -0,0 +1,31 @@
"""Wire startup reconciliation of stale processing tasks into app.main.

Rewrites the lifespan "Start job queue" section so that, when
STARTUP_RECONCILE_STALE_TASKS is enabled, reconcile_stale_processing_tasks()
runs before the queue starts. Exits non-zero if the anchor text is missing.
NOTE(review): anchor indentation is reconstructed -- it must byte-match the
current main.py for the replacement to apply.
"""
from pathlib import Path

MAIN_PY = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/main.py')

# Anchor: the existing job-queue startup section inside the lifespan handler.
OLD_BLOCK = '''    # Start job queue
    from app.services.job_queue import job_queue, register_all_handlers, JobType
    register_all_handlers()
    await job_queue.start()
    logger.info("Job queue started (%d workers)", job_queue._max_workers)
'''
# Replacement: same section with the optional reconciliation pass prepended.
NEW_BLOCK = '''    # Start job queue
    from app.services.job_queue import (
        job_queue,
        register_all_handlers,
        reconcile_stale_processing_tasks,
        JobType,
    )
    if settings.STARTUP_RECONCILE_STALE_TASKS:
        reconciled = await reconcile_stale_processing_tasks()
        if reconciled:
            logger.info("Startup reconciliation marked %d stale tasks", reconciled)
    register_all_handlers()
    await job_queue.start()
    logger.info("Job queue started (%d workers)", job_queue._max_workers)
'''

source = MAIN_PY.read_text(encoding='utf-8')
if OLD_BLOCK not in source:
    raise SystemExit('startup queue block not found')
MAIN_PY.write_text(source.replace(OLD_BLOCK, NEW_BLOCK), encoding='utf-8')
print('wired startup reconciliation in main lifespan')

View File

@@ -0,0 +1,45 @@
# Patch script: add the ProcessingTask ORM model to db/models.py, inserted
# just above the Playbook section. Exits 0 early when the model already
# exists. NOTE(review): indentation is reconstructed (the dump lost
# whitespace).
from pathlib import Path

p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/db/models.py')
t=p.read_text(encoding='utf-8')
# Idempotency guard: never insert the class twice.
if 'class ProcessingTask(Base):' in t:
    print('processing task model already exists')
    raise SystemExit(0)
# Model source. Relies on the _new_id/_utcnow helpers and the SQLAlchemy
# names (Mapped, mapped_column, String, Float, Text, DateTime, ForeignKey,
# Index, Optional, datetime) already imported in models.py.
insert='''
# -- Persistent Processing Tasks (Phase 2) ---
class ProcessingTask(Base):
    __tablename__ = "processing_tasks"
    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True, index=True
    )
    dataset_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True, index=True
    )
    job_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
    stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    status: Mapped[str] = mapped_column(String(20), default="queued", index=True)
    progress: Mapped[float] = mapped_column(Float, default=0.0)
    message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )
    __table_args__ = (
        Index("ix_processing_tasks_hunt_stage", "hunt_id", "stage"),
        Index("ix_processing_tasks_dataset_stage", "dataset_id", "stage"),
    )
'''
# insert before Playbook section
marker='\n\n# -- Playbook / Investigation Templates (Feature 3) ---\n'
if marker not in t:
    raise SystemExit('marker not found for insertion')
t=t.replace(marker, insert+marker)
p.write_text(t,encoding='utf-8')
print('added ProcessingTask model')

59
_edit_networkmap_hit.py Normal file
View File

@@ -0,0 +1,59 @@
# Patch script: make NetworkMap node labels clickable by adding a label
# bounding-box helper and extending hitTest() to consult it. The TypeScript
# source is embedded as strings. NOTE(review): indentation is reconstructed
# (the dump lost whitespace) and must byte-match the current NetworkMap.tsx
# for the anchored replacement to apply.
from pathlib import Path

p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
t=p.read_text(encoding='utf-8')
# Helper: approximates the label's world-space rectangle (font size scales
# inversely with zoom; width is estimated from character count) and tests the
# world-space point against it.
insert='''
function isPointOnNodeLabel(node: GNode, wx: number, wy: number, vp: Viewport): boolean {
  const fontSize = Math.max(9, Math.round(12 / vp.scale));
  const approxCharW = Math.max(5, fontSize * 0.58);
  const line1 = node.label || '';
  const line2 = node.meta.ips.length > 0 ? node.meta.ips[0] : '';
  const tw = Math.max(line1.length * approxCharW, line2 ? line2.length * approxCharW : 0);
  const px = 5, py = 2;
  const totalH = line2 ? fontSize * 2 + py * 2 : fontSize + py * 2;
  const lx = node.x, ly = node.y - node.radius - 6;
  const rx = lx - tw / 2 - px;
  const ry = ly - totalH;
  const rw = tw + px * 2;
  const rh = totalH;
  return wx >= rx && wx <= (rx + rw) && wy >= ry && wy <= (ry + rh);
}
'''
# Idempotent: insert the helper once, anchored under the hit-test banner.
if 'function isPointOnNodeLabel' not in t:
    t=t.replace('// == Hit-test =============================================================\n', '// == Hit-test =============================================================\n'+insert)
old='''function hitTest(
  graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport,
): GNode | null {
  const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp);
  for (const n of graph.nodes) {
    const dx = n.x - wx, dy = n.y - wy;
    if (dx * dx + dy * dy < (n.radius + 5) ** 2) return n;
  }
  return null;
}
'''
# Replacement keeps circle hits first so labels never shadow their nodes.
new='''function hitTest(
  graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport,
): GNode | null {
  const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp);
  // Node-circle hit has priority
  for (const n of graph.nodes) {
    const dx = n.x - wx, dy = n.y - wy;
    if (dx * dx + dy * dy < (n.radius + 5) ** 2) return n;
  }
  // Then label hit (so clicking text works too)
  for (const n of graph.nodes) {
    if (isPointOnNodeLabel(n, wx, wy, vp)) return n;
  }
  return null;
}
'''
if old not in t:
    raise SystemExit('hitTest block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated NetworkMap hit-test for labels')

272
_edit_scanner.py Normal file
View File

@@ -0,0 +1,272 @@
# Patch script: overwrite backend/app/services/scanner.py with the module
# below (keyword/regex scanner plus an in-memory per-dataset scan cache).
# Fixes vs. the checkpointed version:
#   * _match_text's truncation marker was an empty string in BOTH branches of
#     the conditional (the ellipsis literal was lost), making the expression a
#     no-op; matches longer than 200 chars are now suffixed with "...".
#   * Dropped the dead read of the old file contents -- the file is
#     overwritten unconditionally, so the read served no purpose.
# NOTE(review): indentation of the embedded module is reconstructed (the dump
# lost whitespace).
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
new_text = '''"""AUP Keyword Scanner searches dataset rows, hunts, annotations, and
messages for keyword matches.
Scanning is done in Python (not SQL LIKE on JSON columns) for portability
across SQLite / PostgreSQL and to provide per-cell match context.
"""
import logging
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import (
    KeywordTheme,
    DatasetRow,
    Dataset,
    Hunt,
    Annotation,
    Message,
)
logger = logging.getLogger(__name__)
BATCH_SIZE = 200
@dataclass
class ScanHit:
    theme_name: str
    theme_color: str
    keyword: str
    source_type: str  # dataset_row | hunt | annotation | message
    source_id: str | int
    field: str
    matched_value: str
    row_index: int | None = None
    dataset_name: str | None = None
@dataclass
class ScanResult:
    total_hits: int = 0
    hits: list[ScanHit] = field(default_factory=list)
    themes_scanned: int = 0
    keywords_scanned: int = 0
    rows_scanned: int = 0
@dataclass
class KeywordScanCacheEntry:
    dataset_id: str
    result: dict
    built_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
class KeywordScanCache:
    """In-memory per-dataset cache for dataset-only keyword scans.
    This enables fast-path reads when users run AUP scans against datasets that
    were already scanned during upload pipeline processing.
    """
    def __init__(self):
        self._entries: dict[str, KeywordScanCacheEntry] = {}
    def put(self, dataset_id: str, result: dict):
        self._entries[dataset_id] = KeywordScanCacheEntry(dataset_id=dataset_id, result=result)
    def get(self, dataset_id: str) -> KeywordScanCacheEntry | None:
        return self._entries.get(dataset_id)
    def invalidate_dataset(self, dataset_id: str):
        self._entries.pop(dataset_id, None)
    def clear(self):
        self._entries.clear()
keyword_scan_cache = KeywordScanCache()
class KeywordScanner:
    """Scans multiple data sources for keyword/regex matches."""
    def __init__(self, db: AsyncSession):
        self.db = db
    # Public API
    async def scan(
        self,
        dataset_ids: list[str] | None = None,
        theme_ids: list[str] | None = None,
        scan_hunts: bool = False,
        scan_annotations: bool = False,
        scan_messages: bool = False,
    ) -> dict:
        """Run a full AUP scan and return dict matching ScanResponse."""
        # Load themes + keywords
        themes = await self._load_themes(theme_ids)
        if not themes:
            return ScanResult().__dict__
        # Pre-compile patterns per theme
        patterns = self._compile_patterns(themes)
        result = ScanResult(
            themes_scanned=len(themes),
            keywords_scanned=sum(len(kws) for kws in patterns.values()),
        )
        # Scan dataset rows
        await self._scan_datasets(patterns, result, dataset_ids)
        # Scan hunts
        if scan_hunts:
            await self._scan_hunts(patterns, result)
        # Scan annotations
        if scan_annotations:
            await self._scan_annotations(patterns, result)
        # Scan messages
        if scan_messages:
            await self._scan_messages(patterns, result)
        result.total_hits = len(result.hits)
        return {
            "total_hits": result.total_hits,
            "hits": [h.__dict__ for h in result.hits],
            "themes_scanned": result.themes_scanned,
            "keywords_scanned": result.keywords_scanned,
            "rows_scanned": result.rows_scanned,
        }
    # Internal
    async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
        q = select(KeywordTheme).where(KeywordTheme.enabled == True)  # noqa: E712
        if theme_ids:
            q = q.where(KeywordTheme.id.in_(theme_ids))
        result = await self.db.execute(q)
        return list(result.scalars().all())
    def _compile_patterns(
        self, themes: list[KeywordTheme]
    ) -> dict[tuple[str, str, str], list[tuple[str, re.Pattern]]]:
        """Returns {(theme_id, theme_name, theme_color): [(keyword_value, compiled_pattern), ...]}"""
        patterns: dict[tuple[str, str, str], list[tuple[str, re.Pattern]]] = {}
        for theme in themes:
            key = (theme.id, theme.name, theme.color)
            compiled = []
            for kw in theme.keywords:
                try:
                    if kw.is_regex:
                        pat = re.compile(kw.value, re.IGNORECASE)
                    else:
                        pat = re.compile(re.escape(kw.value), re.IGNORECASE)
                    compiled.append((kw.value, pat))
                except re.error:
                    logger.warning("Invalid regex pattern '%s' in theme '%s', skipping",
                                   kw.value, theme.name)
            patterns[key] = compiled
        return patterns
    def _match_text(
        self,
        text: str,
        patterns: dict,
        source_type: str,
        source_id: str | int,
        field_name: str,
        hits: list[ScanHit],
        row_index: int | None = None,
        dataset_name: str | None = None,
    ) -> None:
        """Check text against all compiled patterns, append hits."""
        if not text:
            return
        for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
            for kw_value, pat in keyword_patterns:
                if pat.search(text):
                    matched_preview = text[:200] + ("..." if len(text) > 200 else "")
                    hits.append(ScanHit(
                        theme_name=theme_name,
                        theme_color=theme_color,
                        keyword=kw_value,
                        source_type=source_type,
                        source_id=source_id,
                        field=field_name,
                        matched_value=matched_preview,
                        row_index=row_index,
                        dataset_name=dataset_name,
                    ))
    async def _scan_datasets(
        self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
    ) -> None:
        """Scan dataset rows in batches."""
        ds_q = select(Dataset.id, Dataset.name)
        if dataset_ids:
            ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
        ds_result = await self.db.execute(ds_q)
        ds_map = {r[0]: r[1] for r in ds_result.fetchall()}
        if not ds_map:
            return
        offset = 0
        row_q_base = select(DatasetRow).where(
            DatasetRow.dataset_id.in_(list(ds_map.keys()))
        ).order_by(DatasetRow.id)
        while True:
            rows_result = await self.db.execute(
                row_q_base.offset(offset).limit(BATCH_SIZE)
            )
            rows = rows_result.scalars().all()
            if not rows:
                break
            for row in rows:
                result.rows_scanned += 1
                data = row.data or {}
                for col_name, cell_value in data.items():
                    if cell_value is None:
                        continue
                    text = str(cell_value)
                    self._match_text(
                        text, patterns, "dataset_row", row.id,
                        col_name, result.hits,
                        row_index=row.row_index,
                        dataset_name=ds_map.get(row.dataset_id),
                    )
            offset += BATCH_SIZE
            import asyncio
            await asyncio.sleep(0)
            if len(rows) < BATCH_SIZE:
                break
    async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
        """Scan hunt names and descriptions."""
        hunts_result = await self.db.execute(select(Hunt))
        for hunt in hunts_result.scalars().all():
            self._match_text(hunt.name, patterns, "hunt", hunt.id, "name", result.hits)
            if hunt.description:
                self._match_text(hunt.description, patterns, "hunt", hunt.id, "description", result.hits)
    async def _scan_annotations(self, patterns: dict, result: ScanResult) -> None:
        """Scan annotation text."""
        ann_result = await self.db.execute(select(Annotation))
        for ann in ann_result.scalars().all():
            self._match_text(ann.text, patterns, "annotation", ann.id, "text", result.hits)
    async def _scan_messages(self, patterns: dict, result: ScanResult) -> None:
        """Scan conversation messages (user messages only)."""
        msg_result = await self.db.execute(
            select(Message).where(Message.role == "user")
        )
        for msg in msg_result.scalars().all():
            self._match_text(msg.content, patterns, "message", msg.id, "content", result.hits)
'''
p.write_text(new_text, encoding='utf-8')
print('updated scanner.py')

31
_edit_test_api.py Normal file
View File

@@ -0,0 +1,31 @@
# Patch script: append a hunt-progress endpoint test to backend/tests/
# test_api.py, anchored after test_get_nonexistent_hunt.
# Fixes vs. the checkpointed version: the original printed
# 'updated test_api.py' even when the anchor was missing and nothing was
# inserted, and rewrote the file even on the no-op path. Now it exits 0
# without rewriting when the test already exists, and fails loudly on a
# missing anchor, consistent with the sibling patch scripts.
from pathlib import Path

p=Path(r'd:/Projects/Dev/ThreatHunt/backend/tests/test_api.py')
t=p.read_text(encoding='utf-8')

# New test: create a hunt, upload one dataset into it, then assert the
# /api/hunts/{id}/progress payload shape.
insert='''
    async def test_hunt_progress(self, client):
        create = await client.post("/api/hunts", json={"name": "Progress Hunt"})
        hunt_id = create.json()["id"]
        # attach one dataset so progress has scope
        from tests.conftest import SAMPLE_CSV
        import io
        files = {"file": ("progress.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
        up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
        assert up.status_code == 200
        res = await client.get(f"/api/hunts/{hunt_id}/progress")
        assert res.status_code == 200
        body = res.json()
        assert body["hunt_id"] == hunt_id
        assert "progress_percent" in body
        assert "dataset_total" in body
        assert "network_status" in body
'''
# Anchor: insert directly after this existing test method.
needle='''    async def test_get_nonexistent_hunt(self, client):
        resp = await client.get("/api/hunts/nonexistent-id")
        assert resp.status_code == 404
'''

if 'test_hunt_progress' in t:
    # Idempotent no-op: don't rewrite the file or claim an update happened.
    print('test_hunt_progress already present; test_api.py unchanged')
    raise SystemExit(0)
if needle not in t:
    raise SystemExit('anchor test_get_nonexistent_hunt not found in test_api.py')
t=t.replace(needle, needle+'\n'+insert)
p.write_text(t,encoding='utf-8')
print('updated test_api.py')

32
_edit_test_keywords.py Normal file
View File

@@ -0,0 +1,32 @@
"""Append the quick-scan cache-hit regression test to test_keywords.py.

The snippet is appended only when not already present; the file is
(re)written unconditionally, matching the checkpointed behavior.
"""
from pathlib import Path

target = Path(r'd:/Projects/Dev/ThreatHunt/backend/tests/test_keywords.py')
content = target.read_text(encoding='utf-8')

# Regression test: a second /scan/quick call for the same dataset must be
# served from the keyword-scan cache.
snippet = '''
@pytest.mark.asyncio
async def test_quick_scan_cache_hit(client: AsyncClient):
    """Second quick scan should return cache hit metadata."""
    theme_res = await client.post("/api/keywords/themes", json={"name": "Quick Cache Theme", "color": "#00aa00"})
    tid = theme_res.json()["id"]
    await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "chrome.exe"})
    from tests.conftest import SAMPLE_CSV
    import io
    files = {"file": ("cache_quick.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
    upload = await client.post("/api/datasets/upload", files=files)
    ds_id = upload.json()["id"]
    first = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
    assert first.status_code == 200
    assert first.json().get("cache_status") in ("miss", "hit")
    second = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
    assert second.status_code == 200
    body = second.json()
    assert body.get("cache_used") is True
    assert body.get("cache_status") == "hit"
'''

if 'test_quick_scan_cache_hit' not in content:
    content = content + snippet
target.write_text(content, encoding='utf-8')
print('updated test_keywords.py')

26
_edit_upload.py Normal file
View File

@@ -0,0 +1,26 @@
# Patch script: add live hunt-progress polling plus a "Master Processing
# Progress" panel to FileUpload.tsx. Each edit is a guarded best-effort
# string replacement so the script can be re-run.
# NOTE(review): whitespace in the anchor strings is reconstructed (the dump
# lost indentation) -- a non-matching anchor makes its replace() a silent
# no-op (only the final insert warns on miss). Also note the sibling
# _edit_upload2.py inserts the same panel at a different anchor; confirm only
# one of the two insertions ends up in the component.
from pathlib import Path

p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/FileUpload.tsx')
t=p.read_text(encoding='utf-8')
# import useEffect
t=t.replace("import React, { useState, useCallback, useRef } from 'react';","import React, { useState, useCallback, useRef, useEffect } from 'react';")
# import HuntProgress type
t=t.replace("import { datasets, hunts, type UploadResult, type Hunt } from '../api/client';","import { datasets, hunts, type UploadResult, type Hunt, type HuntProgress } from '../api/client';")
# add state
if 'const [huntProgress, setHuntProgress]' not in t:
    t=t.replace("  const [huntList, setHuntList] = useState<Hunt[]>([]);\n  const [huntId, setHuntId] = useState('');","  const [huntList, setHuntList] = useState<Hunt[]>([]);\n  const [huntId, setHuntId] = useState('');\n  const [huntProgress, setHuntProgress] = useState<HuntProgress | null>(null);")
# add polling effect after hunts list effect
marker="  React.useEffect(() => {\n    hunts.list(0, 100).then(r => setHuntList(r.hunts)).catch(() => {});\n  }, []);\n"
if marker in t and 'setInterval' not in t.split(marker,1)[1][:500]:
    # Polling effect: refetch hunt progress every 2s while a hunt is
    # selected; the `cancelled` flag and clearInterval cleanup stop updates on
    # unmount or when huntId / job count changes.
    add='''\n  useEffect(() => {\n    let timer: any = null;\n    let cancelled = false;\n\n    const pull = async () => {\n      if (!huntId) {\n        if (!cancelled) setHuntProgress(null);\n        return;\n      }\n      try {\n        const p = await hunts.progress(huntId);\n        if (!cancelled) setHuntProgress(p);\n      } catch {\n        if (!cancelled) setHuntProgress(null);\n      }\n    };\n\n    pull();\n    if (huntId) timer = setInterval(pull, 2000);\n    return () => { cancelled = true; if (timer) clearInterval(timer); };\n  }, [huntId, jobs.length]);\n'''
    t=t.replace(marker, marker+add)
# insert master progress UI after overall summary
insert_after='''      {overallTotal > 0 && (\n        <Stack direction="row" alignItems="center" spacing={1} sx={{ mt: 2 }}>\n          <Typography variant="body2" color="text.secondary">\n            {overallDone + overallErr} / {overallTotal} files processed\n            {overallErr > 0 && ` ({overallErr} failed)`}\n          </Typography>\n          <Box sx={{ flexGrow: 1 }} />\n          {overallDone + overallErr === overallTotal && overallTotal > 0 && (\n            <Tooltip title="Clear completed">\n              <IconButton size="small" onClick={clearCompleted}><ClearIcon fontSize="small" /></IconButton>\n            </Tooltip>\n          )}\n        </Stack>\n      )}\n'''
# NOTE(review): this insertion is not idempotency-guarded -- re-running the
# script when the anchor matches would duplicate the panel.
add_block='''\n      {huntId && huntProgress && (\n        <Paper sx={{ p: 1.5, mt: 1.5 }}>\n          <Stack direction="row" alignItems="center" spacing={1} sx={{ mb: 0.8 }}>\n            <Typography variant="body2" sx={{ fontWeight: 600 }}>\n              Master Processing Progress\n            </Typography>\n            <Chip\n              size="small"\n              label={huntProgress.status.toUpperCase()}\n              color={huntProgress.status === 'ready' ? 'success' : huntProgress.status === 'processing' ? 'warning' : 'default'}\n              variant="outlined"\n            />\n            <Box sx={{ flexGrow: 1 }} />\n            <Typography variant="caption" color="text.secondary">\n              {huntProgress.progress_percent.toFixed(1)}%\n            </Typography>\n          </Stack>\n          <LinearProgress\n            variant="determinate"\n            value={Math.max(0, Math.min(100, huntProgress.progress_percent))}\n            sx={{ height: 8, borderRadius: 4 }}\n          />\n          <Stack direction="row" spacing={1} sx={{ mt: 1 }} flexWrap="wrap" useFlexGap>\n            <Chip size="small" label={`Datasets ${huntProgress.dataset_completed}/${huntProgress.dataset_total}`} variant="outlined" />\n            <Chip size="small" label={`Active jobs ${huntProgress.active_jobs}`} variant="outlined" />\n            <Chip size="small" label={`Queued jobs ${huntProgress.queued_jobs}`} variant="outlined" />\n            <Chip size="small" label={`Network ${huntProgress.network_status}`} variant="outlined" />\n          </Stack>\n        </Paper>\n      )}\n'''
if insert_after in t:
    t=t.replace(insert_after, insert_after+add_block)
else:
    print('warning: summary block not found')
p.write_text(t,encoding='utf-8')
print('updated FileUpload.tsx')

42
_edit_upload2.py Normal file
View File

@@ -0,0 +1,42 @@
# Patch script: insert the "Master Processing Progress" panel into
# FileUpload.tsx, directly before the per-file progress list.
# Fixes vs. the checkpointed version: the insertion is now idempotency-
# guarded. The original unconditionally prepended the block, so re-running
# the script (or running it after sibling _edit_upload.py already inserted
# the identical panel elsewhere in the component) duplicated the JSX block.
# NOTE(review): whitespace in the strings is reconstructed (the dump lost
# indentation) and must byte-match FileUpload.tsx.
from pathlib import Path

p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/FileUpload.tsx')
t=p.read_text(encoding='utf-8')

# Anchor: the panel is inserted immediately before this comment.
marker='''      {/* Per-file progress list */}
'''
# Panel: status chip + determinate bar + dataset/job/network summary chips,
# rendered only while a hunt is selected and progress data is available.
add='''      {huntId && huntProgress && (
        <Paper sx={{ p: 1.5, mt: 1.5 }}>
          <Stack direction="row" alignItems="center" spacing={1} sx={{ mb: 0.8 }}>
            <Typography variant="body2" sx={{ fontWeight: 600 }}>
              Master Processing Progress
            </Typography>
            <Chip
              size="small"
              label={huntProgress.status.toUpperCase()}
              color={huntProgress.status === 'ready' ? 'success' : huntProgress.status === 'processing' ? 'warning' : 'default'}
              variant="outlined"
            />
            <Box sx={{ flexGrow: 1 }} />
            <Typography variant="caption" color="text.secondary">
              {huntProgress.progress_percent.toFixed(1)}%
            </Typography>
          </Stack>
          <LinearProgress
            variant="determinate"
            value={Math.max(0, Math.min(100, huntProgress.progress_percent))}
            sx={{ height: 8, borderRadius: 4 }}
          />
          <Stack direction="row" spacing={1} sx={{ mt: 1 }} flexWrap="wrap" useFlexGap>
            <Chip size="small" label={`Datasets ${huntProgress.dataset_completed}/${huntProgress.dataset_total}`} variant="outlined" />
            <Chip size="small" label={`Active jobs ${huntProgress.active_jobs}`} variant="outlined" />
            <Chip size="small" label={`Queued jobs ${huntProgress.queued_jobs}`} variant="outlined" />
            <Chip size="small" label={`Network ${huntProgress.network_status}`} variant="outlined" />
          </Stack>
        </Paper>
      )}
'''

if 'Master Processing Progress' in t:
    # Idempotent no-op: the panel already exists (from a previous run or from
    # the sibling _edit_upload.py script); don't duplicate it.
    print('master progress block already present; FileUpload.tsx unchanged')
    raise SystemExit(0)
if marker not in t:
    raise SystemExit('marker not found')
t=t.replace(marker, add+marker)
p.write_text(t,encoding='utf-8')
print('inserted master progress block')

View File

@@ -0,0 +1,55 @@
# One-shot source patch: add a global row budget (SCANNER_MAX_ROWS_PER_SCAN)
# to scanner.py so AUP scans can bail out early with partial results.
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
t=p.read_text(encoding='utf-8')
# Ensure settings is imported (inserted after the AsyncSession import, once).
if 'from app.config import settings' not in t:
    t=t.replace('from sqlalchemy.ext.asyncio import AsyncSession\n','from sqlalchemy.ext.asyncio import AsyncSession\n\nfrom app.config import settings\n')
# Wrap the dataset scan loop head with budget checks.
old=''' import asyncio
for ds_id, ds_name in ds_map.items():
last_id = 0
while True:
'''
new=''' import asyncio
max_rows = max(0, int(settings.SCANNER_MAX_ROWS_PER_SCAN))
budget_reached = False
for ds_id, ds_name in ds_map.items():
if max_rows and result.rows_scanned >= max_rows:
budget_reached = True
break
last_id = 0
while True:
if max_rows and result.rows_scanned >= max_rows:
budget_reached = True
break
'''
if old not in t:
    raise SystemExit('scanner loop block not found')
t=t.replace(old,new)
# Extend the loop exit with budget-reached break + warning log.
old2=''' if len(rows) < BATCH_SIZE:
break
'''
new2=''' if len(rows) < BATCH_SIZE:
break
if budget_reached:
break
if budget_reached:
logger.warning(
"AUP scan row budget reached (%d rows). Returning partial results.",
result.rows_scanned,
)
'''
if old2 not in t:
    raise SystemExit('scanner break block not found')
# count=1: only the first occurrence is the dataset loop's exit.
t=t.replace(old2,new2,1)
p.write_text(t,encoding='utf-8')
print('added scanner global row budget enforcement')

12
_fix_aup_dep.py Normal file
View File

@@ -0,0 +1,12 @@
"""Source patch: add `selectedHuntId` to AUPScanner's runScan useCallback deps."""
from pathlib import Path

target = Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
source = target.read_text(encoding='utf-8')

# Exact dependency-array line as it currently appears in the component.
needle = ''' }, [selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);
'''
# Same line with selectedHuntId prepended to the dependency list.
replacement = ''' }, [selectedHuntId, selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);
'''

# Abort (exit code 1) without touching the file if the anchor has drifted.
if needle not in source:
    raise SystemExit('runScan deps block not found')

target.write_text(source.replace(needle, replacement), encoding='utf-8')
print('fixed AUPScanner runScan dependency list')

7
_fix_import_datasets.py Normal file
View File

@@ -0,0 +1,7 @@
"""Source patch: ensure datasets.py imports ProcessingTask (idempotent).

Fix over the original: the old script printed 'added ProcessingTask import'
even when the `get_db` anchor was missing and the replace was a silent no-op.
It now fails loudly (matching the sibling patch scripts) and skips the
rewrite entirely when the import is already present.
"""
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/datasets.py')
t = p.read_text(encoding='utf-8')
if 'from app.db.models import ProcessingTask' in t:
    # Already patched: leave the file (and its mtime) untouched.
    print('ProcessingTask import already present')
else:
    anchor = 'from app.db import get_db\n'
    if anchor not in t:
        # Anchor drifted: abort with a non-zero exit instead of lying about success.
        raise SystemExit('get_db import anchor not found')
    t = t.replace(anchor, anchor + 'from app.db.models import ProcessingTask\n')
    p.write_text(t, encoding='utf-8')
    print('added ProcessingTask import')

View File

@@ -0,0 +1,25 @@
# One-shot source patch for keywords.py: an empty scan scope (no datasets, no
# extra sources) now returns a fast empty 200 payload instead of raising 400.
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
t=p.read_text(encoding='utf-8')
# Current guard that raises HTTP 400 on an empty scope.
old=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
raise HTTPException(400, "Select at least one dataset or enable additional sources (hunts/annotations/messages)")
'''
# Replacement: empty result mirroring the normal scan response shape.
new=''' if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
return {
"total_hits": 0,
"hits": [],
"themes_scanned": 0,
"keywords_scanned": 0,
"rows_scanned": 0,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}
'''
if old not in t:
    raise SystemExit('scope guard block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('adjusted empty scan guard to return fast empty result (200)')

View File

@@ -0,0 +1,47 @@
# One-shot source patch for NetworkMap.tsx: add a label-mode Select to the
# toolbar (after the search field) and add labelMode to a useCallback dep list.
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
t=p.read_text(encoding='utf-8')
# Add label selector in toolbar before refresh button
insert_after=""" <TextField
size=\"small\"
placeholder=\"Search hosts, IPs, users\\u2026\"
value={search}
onChange={e => setSearch(e.target.value)}
sx={{ width: 220, '& .MuiInputBase-input': { py: 0.8 } }}
slotProps={{
input: {
startAdornment: <SearchIcon sx={{ mr: 0.5, fontSize: 18, color: 'text.secondary' }} />,
},
}}
/>
"""
# New FormControl/Select bound to the component's labelMode state.
label_ctrl="""
<FormControl size=\"small\" sx={{ minWidth: 150 }}>
<InputLabel id=\"label-mode-selector\">Labels</InputLabel>
<Select
labelId=\"label-mode-selector\"
value={labelMode}
label=\"Labels\"
onChange={e => setLabelMode(e.target.value as LabelMode)}
sx={{ '& .MuiSelect-select': { py: 0.8 } }}
>
<MenuItem value=\"none\">None</MenuItem>
<MenuItem value=\"highlight\">Selected/Search</MenuItem>
<MenuItem value=\"all\">All</MenuItem>
</Select>
</FormControl>
"""
# Idempotence guard: skip insertion when the selector already exists.
if 'label-mode-selector' not in t:
    if insert_after not in t:
        raise SystemExit('search block not found for label selector insertion')
    t=t.replace(insert_after, insert_after+label_ctrl)
# Fix useCallback dependency for startAnimLoop
old=' }, [canvasSize]);'
new=' }, [canvasSize, labelMode]);'
# Best-effort: patch only the first occurrence if present (no hard failure).
if old in t:
    t=t.replace(old,new,1)
p.write_text(t,encoding='utf-8')
print('inserted label selector UI and fixed callback dependency')

View File

@@ -0,0 +1,10 @@
"""Source patch: rewrite any remaining `[canvasSize]` useCallback dependency
arrays in NetworkMap.tsx to `[canvasSize, labelMode]` (including a
formatter-spaced variant).

Fixes over the original: the printed count ignored the spaced variant that
the script also replaces (so the report could understate or show 0 while the
file changed), and the file was rewritten even when nothing matched. The
count now covers both variants and the file is only written when patched.
"""
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
t = p.read_text(encoding='utf-8')
replacement = '}, [canvasSize, labelMode]);'
# Count every variant up front so the report reflects all replacements made.
count = t.count('}, [canvasSize]);') + t.count('}, [canvasSize ]);')
if count:
    t = t.replace('}, [canvasSize]);', replacement)
    # In case formatter created spaced variant
    t = t.replace('}, [canvasSize ]);', replacement)
    p.write_text(t, encoding='utf-8')  # only touch the file when something changed
print('patched remaining canvasSize callback deps:', count)

71
_harden_aup_scope_ui.py Normal file
View File

@@ -0,0 +1,71 @@
# One-shot source patch for AUPScanner.tsx: require an explicit hunt/dataset
# scope before scanning (auto-select a hunt, guard runScan, disable the button).
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/AUPScanner.tsx')
t=p.read_text(encoding='utf-8')
# Auto-select first hunt with datasets after load
old=''' const [tRes, hRes] = await Promise.all([
keywords.listThemes(),
hunts.list(0, 200),
]);
setThemes(tRes.themes);
setHuntList(hRes.hunts);
'''
new=''' const [tRes, hRes] = await Promise.all([
keywords.listThemes(),
hunts.list(0, 200),
]);
setThemes(tRes.themes);
setHuntList(hRes.hunts);
if (!selectedHuntId && hRes.hunts.length > 0) {
const best = hRes.hunts.find(h => h.dataset_count > 0) || hRes.hunts[0];
setSelectedHuntId(best.id);
}
'''
if old not in t:
    raise SystemExit('loadData block not found')
t=t.replace(old,new)
# Guard runScan
old2=''' const runScan = useCallback(async () => {
setScanning(true);
setScanResult(null);
try {
'''
new2=''' const runScan = useCallback(async () => {
if (!selectedHuntId) {
enqueueSnackbar('Please select a hunt before running AUP scan', { variant: 'warning' });
return;
}
if (selectedDs.size === 0) {
enqueueSnackbar('No datasets selected for this hunt', { variant: 'warning' });
return;
}
setScanning(true);
setScanResult(null);
try {
'''
if old2 not in t:
    raise SystemExit('runScan header not found')
t=t.replace(old2,new2)
# update loadData deps
old3=''' }, [enqueueSnackbar]);
'''
new3=''' }, [enqueueSnackbar, selectedHuntId]);
'''
if old3 not in t:
    raise SystemExit('loadData deps not found')
# count=1: only the first matching dep list belongs to loadData.
t=t.replace(old3,new3,1)
# disable button if no hunt or no datasets
old4=''' onClick={runScan} disabled={scanning}
'''
new4=''' onClick={runScan} disabled={scanning || !selectedHuntId || selectedDs.size === 0}
'''
if old4 not in t:
    raise SystemExit('scan button props not found')
t=t.replace(old4,new4)
p.write_text(t,encoding='utf-8')
print('hardened AUPScanner to require explicit hunt/dataset scope')

View File

@@ -0,0 +1,84 @@
# One-shot source patch for keywords.py /scan: on a partial cache hit, scan
# only the missing datasets, cache them, and merge with the cached entries.
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/api/routes/keywords.py')
t=p.read_text(encoding='utf-8')
# Existing all-or-nothing cache path (returns only on a full cache hit).
old=''' if can_use_cache:
themes = await scanner._load_themes(body.theme_ids)
allowed_theme_names = {t.name for t in themes}
keywords_scanned = sum(len(theme.keywords) for theme in themes)
cached_entries: list[dict] = []
missing: list[str] = []
for dataset_id in (body.dataset_ids or []):
entry = keyword_scan_cache.get(dataset_id)
if not entry:
missing.append(dataset_id)
continue
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
if not missing and cached_entries:
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": True,
"cache_status": "hit",
"cached_at": merged["cached_at"],
}
'''
# Same block plus a `missing` branch: per-dataset scan + cache fill + merge,
# reporting cache_status "partial" when some entries came from cache.
new=''' if can_use_cache:
themes = await scanner._load_themes(body.theme_ids)
allowed_theme_names = {t.name for t in themes}
keywords_scanned = sum(len(theme.keywords) for theme in themes)
cached_entries: list[dict] = []
missing: list[str] = []
for dataset_id in (body.dataset_ids or []):
entry = keyword_scan_cache.get(dataset_id)
if not entry:
missing.append(dataset_id)
continue
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
if not missing and cached_entries:
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": True,
"cache_status": "hit",
"cached_at": merged["cached_at"],
}
if missing:
missing_entries: list[dict] = []
for dataset_id in missing:
partial = await scanner.scan(dataset_ids=[dataset_id], theme_ids=body.theme_ids)
keyword_scan_cache.put(dataset_id, partial)
missing_entries.append({"result": partial, "built_at": None})
merged = _merge_cached_results(
cached_entries + missing_entries,
allowed_theme_names if body.theme_ids else None,
)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": len(cached_entries) > 0,
"cache_status": "partial" if cached_entries else "miss",
"cached_at": merged["cached_at"],
}
'''
if old not in t:
    raise SystemExit('cache block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated keyword /scan to use partial cache + scan missing datasets only')

View File

@@ -0,0 +1,61 @@
# One-shot source patch: replace the whole _scan_datasets method in scanner.py
# with a keyset-paginated version (WHERE id > last_id instead of OFFSET).
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/scanner.py')
t=p.read_text(encoding='utf-8')
# Slice boundaries: from the _scan_datasets def up to (not including) _scan_hunts.
# str.index raises ValueError if either marker is missing, aborting the patch.
start=t.index(' async def _scan_datasets(')
end=t.index(' async def _scan_hunts', start)
# Full replacement body for the method.
new_func=''' async def _scan_datasets(
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
) -> None:
"""Scan dataset rows in batches using keyset pagination (no OFFSET)."""
ds_q = select(Dataset.id, Dataset.name)
if dataset_ids:
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
ds_result = await self.db.execute(ds_q)
ds_map = {r[0]: r[1] for r in ds_result.fetchall()}
if not ds_map:
return
import asyncio
for ds_id, ds_name in ds_map.items():
last_id = 0
while True:
rows_result = await self.db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == ds_id)
.where(DatasetRow.id > last_id)
.order_by(DatasetRow.id)
.limit(BATCH_SIZE)
)
rows = rows_result.scalars().all()
if not rows:
break
for row in rows:
result.rows_scanned += 1
data = row.data or {}
for col_name, cell_value in data.items():
if cell_value is None:
continue
text = str(cell_value)
self._match_text(
text,
patterns,
"dataset_row",
row.id,
col_name,
result.hits,
row_index=row.row_index,
dataset_name=ds_name,
)
last_id = rows[-1].id
await asyncio.sleep(0)
if len(rows) < BATCH_SIZE:
break
'''
# Splice the new method between the slice boundaries and write back.
out=t[:start]+new_func+t[end:]
p.write_text(out,encoding='utf-8')
print('optimized scanner _scan_datasets to keyset pagination')

36
_patch_inventory_stats.py Normal file
View File

@@ -0,0 +1,36 @@
# One-shot source patch: extend host_inventory.py's stats payload with the
# configured per-dataset row budget and a sampled_mode flag.
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
t=p.read_text(encoding='utf-8')
# Current return-dict tail of the inventory builder.
old=''' return {
"hosts": host_list,
"connections": conn_list,
"stats": {
"total_hosts": len(host_list),
"total_datasets_scanned": len(all_datasets),
"datasets_with_hosts": ds_with_hosts,
"total_rows_scanned": total_rows,
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
"hosts_with_users": sum(1 for h in host_list if h['users']),
},
}
'''
# Same block with the two budget-metadata keys appended.
new=''' return {
"hosts": host_list,
"connections": conn_list,
"stats": {
"total_hosts": len(host_list),
"total_datasets_scanned": len(all_datasets),
"datasets_with_hosts": ds_with_hosts,
"total_rows_scanned": total_rows,
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
"hosts_with_users": sum(1 for h in host_list if h['users']),
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0,
},
}
'''
if old not in t:
    raise SystemExit('return block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('patched inventory stats metadata')

View File

@@ -0,0 +1,10 @@
"""Source patch: append row-budget stats keys after the hosts_with_users line
in host_inventory.py (idempotent).

Fix over the original: when the file was already patched, the old script
still rewrote identical bytes and printed 'inserted ...' misleadingly. It now
skips the write and reports accurately in that case.
"""
from pathlib import Path

p = Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
t = p.read_text(encoding='utf-8')
# Exact stats line after which the two new keys are inserted.
needle = ' "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n'
if '"row_budget_per_dataset"' in t:
    # Already patched: leave the file (and its mtime) untouched.
    print('inventory budget stats lines already present')
else:
    if needle not in t:
        raise SystemExit('needle not found')
    t = t.replace(needle, needle + ' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,\n "sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0,\n')
    p.write_text(t, encoding='utf-8')
    print('inserted inventory budget stats lines')

14
_patch_network_sleep.py Normal file
View File

@@ -0,0 +1,14 @@
"""Source patch: add a shared `sleep` helper to NetworkMap.tsx and route the
polling awaits through it.

Fix over the original: if the helper was neither already present nor
insertable (anchor drifted), the old script still rewrote the awaits to call
an undefined `sleep` — breaking the TSX — and printed success regardless. It
now verifies the helper exists before touching the awaits and fails loudly
otherwise, leaving the file unwritten.
"""
from pathlib import Path

p = Path(r"d:\Projects\Dev\ThreatHunt\frontend\src\components\NetworkMap.tsx")
text = p.read_text(encoding="utf-8")
anchor = " useEffect(() => { canvasSizeRef.current = canvasSize; }, [canvasSize]);\n"
insert = anchor + "\n const sleep = (ms: number) => new Promise<void>(resolve => setTimeout(resolve, ms));\n"
# Insert the helper once, right after the canvasSizeRef effect.
if "const sleep = (ms: number)" not in text and anchor in text:
    text = text.replace(anchor, insert)
# Safety gate: without the helper, rewriting the awaits would reference an
# undefined `sleep` and break the component.
if "const sleep = (ms: number)" not in text:
    raise SystemExit("sleep helper anchor not found; awaits left untouched")
text = text.replace("await new Promise(r => setTimeout(r, delayMs + jitter));", "await sleep(delayMs + jitter);")
p.write_text(text, encoding="utf-8")
print("Patched sleep helper + polling awaits")

37
_patch_network_wait.py Normal file
View File

@@ -0,0 +1,37 @@
# One-shot source patch: replace the whole waitUntilReady arrow function in
# NetworkMap.tsx with an exponential-backoff poller (jitter + 5-min timeout).
from pathlib import Path
import re
p = Path(r"d:\Projects\Dev\ThreatHunt\frontend\src\components\NetworkMap.tsx")
text = p.read_text(encoding="utf-8")
# Non-greedy match from the function header to the first `};` at line start
# indentation — spans the entire current implementation.
pattern = re.compile(r"const waitUntilReady = async \(\): Promise<boolean> => \{[\s\S]*?\n\s*\};", re.M)
# Full replacement body for the function.
replacement = '''const waitUntilReady = async (): Promise<boolean> => {
// Poll inventory-status with exponential backoff until 'ready' (or cancelled)
setProgress('Host inventory is being prepared in the background');
setLoading(true);
let delayMs = 1500;
const startedAt = Date.now();
for (;;) {
const jitter = Math.floor(Math.random() * 250);
await new Promise(r => setTimeout(r, delayMs + jitter));
if (cancelled) return false;
try {
const st = await network.inventoryStatus(selectedHuntId);
if (cancelled) return false;
if (st.status === 'ready') return true;
if (Date.now() - startedAt > 5 * 60 * 1000) {
setError('Host inventory build timed out. Please retry.');
return false;
}
delayMs = Math.min(10000, Math.floor(delayMs * 1.5));
// still building or none (job may not have started yet) - keep polling
} catch {
if (cancelled) return false;
delayMs = Math.min(10000, Math.floor(delayMs * 1.5));
}
}
};'''
# Exactly one substitution is expected; anything else aborts without writing.
new_text, n = pattern.subn(replacement, text, count=1)
if n != 1:
    raise SystemExit(f"Failed to patch waitUntilReady, matches={n}")
p.write_text(new_text, encoding="utf-8")
print("Patched waitUntilReady")

View File

@@ -0,0 +1,26 @@
# One-shot source patch for config.py: lower the per-dataset inventory row
# budget to 5000 and add global row + connection budget settings.
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
t=p.read_text(encoding='utf-8')
# Existing field (default 25000) to be replaced.
old=''' NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
default=25000,
description="Row budget per dataset when building host inventory (0 = unlimited)",
)
'''
# New default plus two new budget fields.
new=''' NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
default=5000,
description="Row budget per dataset when building host inventory (0 = unlimited)",
)
NETWORK_INVENTORY_MAX_TOTAL_ROWS: int = Field(
default=120000,
description="Global row budget across all datasets for host inventory build (0 = unlimited)",
)
NETWORK_INVENTORY_MAX_CONNECTIONS: int = Field(
default=120000,
description="Max unique connection tuples retained during host inventory build",
)
'''
if old not in t:
    raise SystemExit('network inventory block not found')
t=t.replace(old,new)
p.write_text(t,encoding='utf-8')
print('updated network inventory budgets in config')

View File

@@ -0,0 +1,164 @@
# One-shot source patch for host_inventory.py: enforce global row and
# connection budgets during the inventory build and expose them in stats.
# Applied as seven ordered literal replacements; any miss aborts unwritten.
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
t=p.read_text(encoding='utf-8')
# insert budget vars near existing counters
old=''' connections: dict[tuple, int] = defaultdict(int)
total_rows = 0
ds_with_hosts = 0
'''
new=''' connections: dict[tuple, int] = defaultdict(int)
total_rows = 0
ds_with_hosts = 0
sampled_dataset_count = 0
total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
global_budget_reached = False
dropped_connections = 0
'''
if old not in t:
    raise SystemExit('counter block not found')
t=t.replace(old,new)
# update batch size and sampled count increments + global budget checks
old2=''' batch_size = 10000
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
rows_scanned_this_dataset = 0
sampled_dataset = False
last_row_index = -1
while True:
'''
new2=''' batch_size = 5000
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
rows_scanned_this_dataset = 0
sampled_dataset = False
last_row_index = -1
while True:
if total_row_budget and total_rows >= total_row_budget:
global_budget_reached = True
break
'''
if old2 not in t:
    raise SystemExit('batch block not found')
t=t.replace(old2,new2)
# Per-row check: also stop when the global budget is hit mid-batch.
old3=''' if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
sampled_dataset = True
break
data = ro.data or {}
total_rows += 1
rows_scanned_this_dataset += 1
'''
new3=''' if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
sampled_dataset = True
break
if total_row_budget and total_rows >= total_row_budget:
sampled_dataset = True
global_budget_reached = True
break
data = ro.data or {}
total_rows += 1
rows_scanned_this_dataset += 1
'''
if old3 not in t:
    raise SystemExit('row scan block not found')
t=t.replace(old3,new3)
# cap connection map growth
old4=''' for c in cols['remote_ip']:
rip = _clean(data.get(c))
if _is_valid_ip(rip):
rport = ''
for pc in cols['remote_port']:
rport = _clean(data.get(pc))
if rport:
break
connections[(host_key, rip, rport)] += 1
'''
new4=''' for c in cols['remote_ip']:
rip = _clean(data.get(c))
if _is_valid_ip(rip):
rport = ''
for pc in cols['remote_port']:
rport = _clean(data.get(pc))
if rport:
break
conn_key = (host_key, rip, rport)
if max_connections and len(connections) >= max_connections and conn_key not in connections:
dropped_connections += 1
continue
connections[conn_key] += 1
'''
if old4 not in t:
    raise SystemExit('connection block not found')
t=t.replace(old4,new4)
# sampled_dataset counter
old5=''' if sampled_dataset:
logger.info(
"Host inventory row budget reached for dataset %s (%d rows)",
ds.id,
rows_scanned_this_dataset,
)
break
'''
new5=''' if sampled_dataset:
sampled_dataset_count += 1
logger.info(
"Host inventory row budget reached for dataset %s (%d rows)",
ds.id,
rows_scanned_this_dataset,
)
break
'''
if old5 not in t:
    raise SystemExit('sampled block not found')
t=t.replace(old5,new5)
# break dataset loop if global budget reached
old6=''' if len(rows) < batch_size:
break
# Post-process hosts
'''
new6=''' if len(rows) < batch_size:
break
if global_budget_reached:
logger.info(
"Host inventory global row budget reached for hunt %s at %d rows",
hunt_id,
total_rows,
)
break
# Post-process hosts
'''
if old6 not in t:
    raise SystemExit('post-process boundary block not found')
t=t.replace(old6,new6)
# add stats
old7=''' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0,
},
}
'''
new7=''' "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
"row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
"connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
"sampled_datasets": sampled_dataset_count,
"global_budget_reached": global_budget_reached,
"dropped_connections": dropped_connections,
},
}
'''
if old7 not in t:
    raise SystemExit('stats block not found')
t=t.replace(old7,new7)
p.write_text(t,encoding='utf-8')
print('updated host inventory with global row and connection budgets')

View File

@@ -0,0 +1,39 @@
# One-shot source patch for NetworkMap.tsx: tighten render/load constants,
# gate label hit-testing behind a node-count cap, and reduce simulation steps.
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
t=p.read_text(encoding='utf-8')
# Constant downgrades (old literal -> new literal); all must exist.
repls={
"const LARGE_HUNT_SUBGRAPH_HOSTS = 350;":"const LARGE_HUNT_SUBGRAPH_HOSTS = 220;",
"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;":"const LARGE_HUNT_SUBGRAPH_EDGES = 1200;",
"const RENDER_SIMPLIFY_NODE_THRESHOLD = 220;":"const RENDER_SIMPLIFY_NODE_THRESHOLD = 120;",
"const RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;":"const RENDER_SIMPLIFY_EDGE_THRESHOLD = 500;",
"const EDGE_DRAW_TARGET = 1000;":"const EDGE_DRAW_TARGET = 600;"
}
for a,b in repls.items():
    if a not in t:
        raise SystemExit(f'missing constant: {a}')
    t=t.replace(a,b)
# Only run the per-node label hit-test on small graphs (<= 220 nodes).
old=''' // Then label hit (so clicking text works too)
for (const n of graph.nodes) {
if (isPointOnNodeLabel(n, wx, wy, vp)) return n;
}
'''
new=''' // Then label hit (so clicking text works too on manageable graph sizes)
if (graph.nodes.length <= 220) {
for (const n of graph.nodes) {
if (isPointOnNodeLabel(n, wx, wy, vp)) return n;
}
}
'''
if old not in t:
    raise SystemExit('label hit block not found')
t=t.replace(old,new)
# Reduce the two force-simulation iteration counts (60 -> 20 and 60 -> 30),
# replacing one occurrence at a time in order.
old2='simulate(g, w / 2, h / 2, 60);'
if t.count(old2) < 2:
    raise SystemExit('expected two simulate calls')
t=t.replace(old2,'simulate(g, w / 2, h / 2, 20);',1)
t=t.replace(old2,'simulate(g, w / 2, h / 2, 30);',1)
p.write_text(t,encoding='utf-8')
print('tightened network map rendering + load limits')

107
_perf_patch_backend.py Normal file
View File

@@ -0,0 +1,107 @@
# One-shot source patch (two files): add NETWORK_INVENTORY_MAX_ROWS_PER_DATASET
# to config.py, then enforce that per-dataset row budget in host_inventory.py.
from pathlib import Path
# config updates
cfg=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
t=cfg.read_text(encoding='utf-8')
# Insert the new Field after the existing subgraph-edge cap.
anchor=''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
)
'''
ins=''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
)
NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
default=200000,
description="Row budget per dataset when building host inventory (0 = unlimited)",
)
'''
# Idempotence guard: only insert when the setting is absent.
if 'NETWORK_INVENTORY_MAX_ROWS_PER_DATASET' not in t:
    if anchor not in t:
        raise SystemExit('config network anchor not found')
    t=t.replace(anchor,ins)
cfg.write_text(t,encoding='utf-8')
# host inventory updates
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
t=p.read_text(encoding='utf-8')
# Ensure settings import exists.
if 'from app.config import settings' not in t:
    t=t.replace('from app.db.models import Dataset, DatasetRow\n', 'from app.db.models import Dataset, DatasetRow\nfrom app.config import settings\n')
# Bump batch size and introduce per-dataset budget counters (best-effort, no check).
t=t.replace(' batch_size = 5000\n last_row_index = -1\n while True:\n', ' batch_size = 10000\n max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))\n rows_scanned_this_dataset = 0\n sampled_dataset = False\n last_row_index = -1\n while True:\n')
# Budget check at the top of the per-row loop.
old=''' for ro in rows:
data = ro.data or {}
total_rows += 1
fqdn = ''
'''
new=''' for ro in rows:
if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
sampled_dataset = True
break
data = ro.data or {}
total_rows += 1
rows_scanned_this_dataset += 1
fqdn = ''
'''
if old not in t:
    raise SystemExit('row loop anchor not found')
t=t.replace(old,new)
# Break out of the batch loop (with a log) once a dataset is sampled.
old2=''' last_row_index = rows[-1].row_index
if len(rows) < batch_size:
break
'''
new2=''' if sampled_dataset:
logger.info(
"Host inventory row budget reached for dataset %s (%d rows)",
ds.id,
rows_scanned_this_dataset,
)
break
last_row_index = rows[-1].row_index
if len(rows) < batch_size:
break
'''
if old2 not in t:
    raise SystemExit('batch loop end anchor not found')
t=t.replace(old2,new2)
# Extend the stats payload with the budget metadata.
old3=''' return {
"hosts": host_list,
"connections": conn_list,
"stats": {
"total_hosts": len(host_list),
"total_datasets_scanned": len(all_datasets),
"datasets_with_hosts": ds_with_hosts,
"total_rows_scanned": total_rows,
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
"hosts_with_users": sum(1 for h in host_list if h['users']),
},
}
'''
new3=''' sampled = settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0
return {
"hosts": host_list,
"connections": conn_list,
"stats": {
"total_hosts": len(host_list),
"total_datasets_scanned": len(all_datasets),
"datasets_with_hosts": ds_with_hosts,
"total_rows_scanned": total_rows,
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
"hosts_with_users": sum(1 for h in host_list if h['users']),
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
"sampled_mode": sampled,
},
}
'''
if old3 not in t:
    raise SystemExit('return stats anchor not found')
t=t.replace(old3,new3)
p.write_text(t,encoding='utf-8')
print('patched config + host inventory row budget')

38
_perf_patch_backend2.py Normal file
View File

@@ -0,0 +1,38 @@
# Retry variant of _perf_patch_backend.py: same two-file patch, but every
# replacement is best-effort (no anchor checks) using \n-escaped one-line
# literals, so a miss silently leaves that spot unchanged.
from pathlib import Path
cfg=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/config.py')
t=cfg.read_text(encoding='utf-8')
# Insert the per-dataset row budget setting if absent.
if 'NETWORK_INVENTORY_MAX_ROWS_PER_DATASET' not in t:
    t=t.replace(
        ''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
)
''',
        ''' NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
)
NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
default=200000,
description="Row budget per dataset when building host inventory (0 = unlimited)",
)
''')
cfg.write_text(t,encoding='utf-8')
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
t=p.read_text(encoding='utf-8')
# Ensure the settings import exists.
if 'from app.config import settings' not in t:
    t=t.replace('from app.db.models import Dataset, DatasetRow\n','from app.db.models import Dataset, DatasetRow\nfrom app.config import settings\n')
# Bump batch size and add per-dataset budget counters.
t=t.replace(' batch_size = 5000\n last_row_index = -1\n while True:\n',
    ' batch_size = 10000\n max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))\n rows_scanned_this_dataset = 0\n sampled_dataset = False\n last_row_index = -1\n while True:\n')
# Budget check at the top of the per-row loop.
t=t.replace(' for ro in rows:\n data = ro.data or {}\n total_rows += 1\n\n',
    ' for ro in rows:\n if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:\n sampled_dataset = True\n break\n\n data = ro.data or {}\n total_rows += 1\n rows_scanned_this_dataset += 1\n\n')
# Break out of the batch loop (with a log) once the dataset is sampled.
t=t.replace(' last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n',
    ' if sampled_dataset:\n logger.info(\n "Host inventory row budget reached for dataset %s (%d rows)",\n ds.id,\n rows_scanned_this_dataset,\n )\n break\n\n last_row_index = rows[-1].row_index\n if len(rows) < batch_size:\n break\n')
# Extend the stats payload with budget metadata.
t=t.replace(' return {\n "hosts": host_list,\n "connections": conn_list,\n "stats": {\n "total_hosts": len(host_list),\n "total_datasets_scanned": len(all_datasets),\n "datasets_with_hosts": ds_with_hosts,\n "total_rows_scanned": total_rows,\n "hosts_with_ips": sum(1 for h in host_list if h[\'ips\']),\n "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n },\n }\n',
    ' sampled = settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0\n\n return {\n "hosts": host_list,\n "connections": conn_list,\n "stats": {\n "total_hosts": len(host_list),\n "total_datasets_scanned": len(all_datasets),\n "datasets_with_hosts": ds_with_hosts,\n "total_rows_scanned": total_rows,\n "hosts_with_ips": sum(1 for h in host_list if h[\'ips\']),\n "hosts_with_users": sum(1 for h in host_list if h[\'users\']),\n "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,\n "sampled_mode": sampled,\n },\n }\n')
p.write_text(t,encoding='utf-8')
print('patched backend inventory performance settings')

220
_perf_patch_networkmap.py Normal file
View File

@@ -0,0 +1,220 @@
from pathlib import Path
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
t=p.read_text(encoding='utf-8')
# constants
if 'RENDER_SIMPLIFY_NODE_THRESHOLD' not in t:
t=t.replace(
"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n",
"const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\nconst RENDER_SIMPLIFY_NODE_THRESHOLD = 220;\nconst RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;\nconst EDGE_DRAW_TARGET = 1000;\n")
# drawBackground signature
t_old='''function drawBackground(
ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,
) {
'''
if t_old in t:
t=t.replace(t_old,
'''function drawBackground(
ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,
simplify: boolean,
) {
''')
# skip grid when simplify
if 'if (!simplify) {' not in t:
t=t.replace(
''' ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
const startX = -vp.x / vp.scale - GRID_SPACING;
const startY = -vp.y / vp.scale - GRID_SPACING;
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
ctx.fillStyle = GRID_DOT_COLOR;
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
}
}
ctx.restore();
''',
''' if (!simplify) {
ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
const startX = -vp.x / vp.scale - GRID_SPACING;
const startY = -vp.y / vp.scale - GRID_SPACING;
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
ctx.fillStyle = GRID_DOT_COLOR;
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
}
}
ctx.restore();
}
''')
# drawEdges signature
t=t.replace('''function drawEdges(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
nodeMap: Map<string, GNode>, animTime: number,
) {
for (const e of graph.edges) {
''',
'''function drawEdges(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
nodeMap: Map<string, GNode>, animTime: number,
simplify: boolean,
) {
const edgeStep = simplify ? Math.max(1, Math.ceil(graph.edges.length / EDGE_DRAW_TARGET)) : 1;
for (let ei = 0; ei < graph.edges.length; ei += edgeStep) {
const e = graph.edges[ei];
''')
# simplify edge path
t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);',
'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }')
t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);',
'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }')
# reduce glow when simplify
t=t.replace(''' ctx.save();
ctx.shadowColor = 'rgba(96,165,250,0.5)'; ctx.shadowBlur = 8;
ctx.strokeStyle = 'rgba(96,165,250,0.3)';
ctx.lineWidth = Math.min(5, 2 + e.weight * 0.2);
ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }
ctx.stroke(); ctx.restore();
''',
''' if (!simplify) {
ctx.save();
ctx.shadowColor = 'rgba(96,165,250,0.5)'; ctx.shadowBlur = 8;
ctx.strokeStyle = 'rgba(96,165,250,0.3)';
ctx.lineWidth = Math.min(5, 2 + e.weight * 0.2);
ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }
ctx.stroke(); ctx.restore();
}
''')
# drawLabels signature and early return
t=t.replace('''function drawLabels(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
search: string, matchSet: Set<string>, vp: Viewport,
) {
''',
'''function drawLabels(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
search: string, matchSet: Set<string>, vp: Viewport,
simplify: boolean,
) {
''')
if 'if (simplify && !search && !hovered && !selected) {' not in t:
t=t.replace(' const dimmed = search.length > 0;\n',
' const dimmed = search.length > 0;\n if (simplify && !search && !hovered && !selected) {\n return;\n }\n')
# drawGraph adapt
t=t.replace(''' drawBackground(ctx, w, h, vp, dpr);
ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime);
drawNodes(ctx, graph, hovered, selected, search, matchSet);
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp);
ctx.restore();
''',
''' const simplify = graph.nodes.length > RENDER_SIMPLIFY_NODE_THRESHOLD || graph.edges.length > RENDER_SIMPLIFY_EDGE_THRESHOLD;
drawBackground(ctx, w, h, vp, dpr, simplify);
ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime, simplify);
drawNodes(ctx, graph, hovered, selected, search, matchSet);
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify);
ctx.restore();
''')
# hover RAF ref
if 'const hoverRafRef = useRef<number>(0);' not in t:
t=t.replace(' const graphRef = useRef<Graph | null>(null);\n', ' const graphRef = useRef<Graph | null>(null);\n const hoverRafRef = useRef<number>(0);\n')
# throttle hover hit test on mousemove
old_mm=''' const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
setHovered(node?.id ?? null);
}, [graph, redraw, startAnimLoop]);
'''
new_mm=''' cancelAnimationFrame(hoverRafRef.current);
const clientX = e.clientX;
const clientY = e.clientY;
hoverRafRef.current = requestAnimationFrame(() => {
const node = hitTest(graph, canvasRef.current as HTMLCanvasElement, clientX, clientY, vpRef.current);
setHovered(prev => (prev === (node?.id ?? null) ? prev : (node?.id ?? null)));
});
}, [graph, redraw, startAnimLoop]);
'''
if old_mm in t:
t=t.replace(old_mm,new_mm)
# cleanup hover raf on unmount in existing animation cleanup effect
if 'cancelAnimationFrame(hoverRafRef.current);' not in t:
t=t.replace(''' useEffect(() => {
if (graph) startAnimLoop();
return () => { cancelAnimationFrame(animFrameRef.current); isAnimatingRef.current = false; };
}, [graph, startAnimLoop]);
''',
''' useEffect(() => {
if (graph) startAnimLoop();
return () => {
cancelAnimationFrame(animFrameRef.current);
cancelAnimationFrame(hoverRafRef.current);
isAnimatingRef.current = false;
};
}, [graph, startAnimLoop]);
''')
# connectedNodes optimization map
if 'const nodeById = useMemo(() => {' not in t:
t=t.replace(''' const connectionCount = selectedNode && graph
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
: 0;
const connectedNodes = useMemo(() => {
''',
''' const connectionCount = selectedNode && graph
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
: 0;
const nodeById = useMemo(() => {
const m = new Map<string, GNode>();
if (!graph) return m;
for (const n of graph.nodes) m.set(n.id, n);
return m;
}, [graph]);
const connectedNodes = useMemo(() => {
''')
t=t.replace(''' const n = graph.nodes.find(x => x.id === e.target);
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
} else if (e.target === selectedNode.id) {
const n = graph.nodes.find(x => x.id === e.source);
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
''',
''' const n = nodeById.get(e.target);
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
} else if (e.target === selectedNode.id) {
const n = nodeById.get(e.source);
if (n) neighbors.push({ id: n.id, type: n.meta.type, weight: e.weight });
''')
t=t.replace(' }, [selectedNode, graph]);\n', ' }, [selectedNode, graph, nodeById]);\n')
p.write_text(t,encoding='utf-8')
print('patched NetworkMap adaptive render + hover throttle')

153
_perf_patch_networkmap2.py Normal file
View File

@@ -0,0 +1,153 @@
from pathlib import Path
# One-shot patch script for NetworkMap.tsx: adds adaptive "simplify" rendering
# (skip the dot grid, straight-line edges, edge sub-sampling toward
# EDGE_DRAW_TARGET, label culling), rAF-throttled hover hit-testing, and a
# memoized id->node map to replace linear graph.nodes.find() scans.
# Each str.replace() is a no-op when its needle is absent, and the structural
# insertions are guarded by substring probes, so re-running the script against
# an already-patched file leaves it unchanged.
# NOTE(review): hard-coded absolute Windows path — assumes this exact checkout.
p=Path(r'd:/Projects/Dev/ThreatHunt/frontend/src/components/NetworkMap.tsx')
t=p.read_text(encoding='utf-8')
# Inject the render-simplification tuning constants once, right after the
# existing LARGE_HUNT_SUBGRAPH_EDGES constant.
if 'RENDER_SIMPLIFY_NODE_THRESHOLD' not in t:
    t=t.replace('const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\n', 'const LARGE_HUNT_SUBGRAPH_EDGES = 2500;\nconst RENDER_SIMPLIFY_NODE_THRESHOLD = 220;\nconst RENDER_SIMPLIFY_EDGE_THRESHOLD = 1200;\nconst EDGE_DRAW_TARGET = 1000;\n')
# drawBackground: extend the signature with a `simplify` flag.
t=t.replace('function drawBackground(\n ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,\n) {', 'function drawBackground(\n ctx: CanvasRenderingContext2D, w: number, h: number, vp: Viewport, dpr: number,\n simplify: boolean,\n) {')
# Wrap the dot-grid pass in `if (!simplify)` so large graphs skip it entirely.
t=t.replace(''' ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
const startX = -vp.x / vp.scale - GRID_SPACING;
const startY = -vp.y / vp.scale - GRID_SPACING;
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
ctx.fillStyle = GRID_DOT_COLOR;
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
}
}
ctx.restore();
''',''' if (!simplify) {
ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
const startX = -vp.x / vp.scale - GRID_SPACING;
const startY = -vp.y / vp.scale - GRID_SPACING;
const endX = startX + w / (vp.scale * dpr) + GRID_SPACING * 2;
const endY = startY + h / (vp.scale * dpr) + GRID_SPACING * 2;
ctx.fillStyle = GRID_DOT_COLOR;
for (let gx = Math.floor(startX / GRID_SPACING) * GRID_SPACING; gx < endX; gx += GRID_SPACING) {
for (let gy = Math.floor(startY / GRID_SPACING) * GRID_SPACING; gy < endY; gy += GRID_SPACING) {
ctx.beginPath(); ctx.arc(gx, gy, 1, 0, Math.PI * 2); ctx.fill();
}
}
ctx.restore();
}
''')
# drawEdges: add `simplify` and sub-sample edges down to ~EDGE_DRAW_TARGET by
# stepping through graph.edges with a computed stride.
t=t.replace('''function drawEdges(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
nodeMap: Map<string, GNode>, animTime: number,
) {
for (const e of graph.edges) {
''','''function drawEdges(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
nodeMap: Map<string, GNode>, animTime: number,
simplify: boolean,
) {
const edgeStep = simplify ? Math.max(1, Math.ceil(graph.edges.length / EDGE_DRAW_TARGET)) : 1;
for (let ei = 0; ei < graph.edges.length; ei += edgeStep) {
const e = graph.edges[ei];
''')
# Edge path: straight line when simplified, quadratic curve otherwise.
# (Single replace here patches the first occurrence only — the earlier patch
# script issued it twice; verify how many occurrences NetworkMap.tsx has.)
t=t.replace('ctx.beginPath(); ctx.moveTo(a.x, a.y); ctx.quadraticCurveTo(cpx, cpy, b.x, b.y);', 'ctx.beginPath(); ctx.moveTo(a.x, a.y); if (simplify) { ctx.lineTo(b.x, b.y); } else { ctx.quadraticCurveTo(cpx, cpy, b.x, b.y); }')
# drawLabels: add `simplify` and an early return that culls all labels when
# nothing is searched, hovered, or selected.
t=t.replace('''function drawLabels(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
search: string, matchSet: Set<string>, vp: Viewport,
) {
const dimmed = search.length > 0;
''','''function drawLabels(
ctx: CanvasRenderingContext2D, graph: Graph,
hovered: string | null, selected: string | null,
search: string, matchSet: Set<string>, vp: Viewport,
simplify: boolean,
) {
const dimmed = search.length > 0;
if (simplify && !search && !hovered && !selected) {
return;
}
''')
# drawGraph: derive `simplify` from graph size and thread it through all three
# draw passes.
t=t.replace(''' drawBackground(ctx, w, h, vp, dpr);
ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime);
drawNodes(ctx, graph, hovered, selected, search, matchSet);
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp);
ctx.restore();
''',''' const simplify = graph.nodes.length > RENDER_SIMPLIFY_NODE_THRESHOLD || graph.edges.length > RENDER_SIMPLIFY_EDGE_THRESHOLD;
drawBackground(ctx, w, h, vp, dpr, simplify);
ctx.save();
ctx.translate(vp.x * dpr, vp.y * dpr);
ctx.scale(vp.scale * dpr, vp.scale * dpr);
drawEdges(ctx, graph, hovered, selected, nodeMap, animTime, simplify);
drawNodes(ctx, graph, hovered, selected, search, matchSet);
drawLabels(ctx, graph, hovered, selected, search, matchSet, vp, simplify);
ctx.restore();
''')
# Add a ref holding the pending hover hit-test rAF handle.
if 'const hoverRafRef = useRef<number>(0);' not in t:
    t=t.replace('const graphRef = useRef<Graph | null>(null);\n', 'const graphRef = useRef<Graph | null>(null);\n const hoverRafRef = useRef<number>(0);\n')
# Throttle hover hit-testing to at most once per animation frame; setHovered
# keeps the previous state object when the hovered id is unchanged.
t=t.replace(''' const node = hitTest(graph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
setHovered(node?.id ?? null);
}, [graph, redraw, startAnimLoop]);
''',''' cancelAnimationFrame(hoverRafRef.current);
const clientX = e.clientX;
const clientY = e.clientY;
hoverRafRef.current = requestAnimationFrame(() => {
const node = hitTest(graph, canvasRef.current as HTMLCanvasElement, clientX, clientY, vpRef.current);
setHovered(prev => (prev === (node?.id ?? null) ? prev : (node?.id ?? null)));
});
}, [graph, redraw, startAnimLoop]);
''')
# Cancel the pending hover rAF in the existing animation cleanup effect.
t=t.replace(''' useEffect(() => {
if (graph) startAnimLoop();
return () => { cancelAnimationFrame(animFrameRef.current); isAnimatingRef.current = false; };
}, [graph, startAnimLoop]);
''',''' useEffect(() => {
if (graph) startAnimLoop();
return () => {
cancelAnimationFrame(animFrameRef.current);
cancelAnimationFrame(hoverRafRef.current);
isAnimatingRef.current = false;
};
}, [graph, startAnimLoop]);
''')
# Memoize an id -> node Map so the connected-nodes panel avoids O(n) scans.
if 'const nodeById = useMemo(() => {' not in t:
    t=t.replace(''' const connectionCount = selectedNode && graph
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
: 0;
const connectedNodes = useMemo(() => {
''',''' const connectionCount = selectedNode && graph
? graph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
: 0;
const nodeById = useMemo(() => {
const m = new Map<string, GNode>();
if (!graph) return m;
for (const n of graph.nodes) m.set(n.id, n);
return m;
}, [graph]);
const connectedNodes = useMemo(() => {
''')
# Swap the linear find() lookups for the memoized Map, and extend the memo's
# dependency list accordingly.
t=t.replace('const n = graph.nodes.find(x => x.id === e.target);','const n = nodeById.get(e.target);')
t=t.replace('const n = graph.nodes.find(x => x.id === e.source);','const n = nodeById.get(e.source);')
t=t.replace(' }, [selectedNode, graph]);',' }, [selectedNode, graph, nodeById]);')
p.write_text(t,encoding='utf-8')
print('patched NetworkMap performance')

View File

@@ -0,0 +1,227 @@
from pathlib import Path
# Patch script: replaces build_host_inventory() in host_inventory.py with a
# "hard budget" variant that caps rows scanned per dataset
# (NETWORK_INVENTORY_MAX_ROWS_PER_DATASET), total rows scanned
# (NETWORK_INVENTORY_MAX_TOTAL_ROWS), and tracked remote connections
# (NETWORK_INVENTORY_MAX_CONNECTIONS). Re-running simply re-replaces the span.
# NOTE(review): hard-coded absolute Windows path — assumes this exact checkout.
p=Path(r'd:/Projects/Dev/ThreatHunt/backend/app/services/host_inventory.py')
t=p.read_text(encoding='utf-8')
# The replaced span starts at the function's `async def` line...
start=t.index('async def build_host_inventory(')
# ...and ends at the first blank-line gap after the last ' }' in the file — a
# heuristic that assumes build_host_inventory is the final function and its
# body ends with a dict literal.
# Fix: the previous version also computed `end = t.index('\n\n', start)` here,
# a dead assignment (immediately overwritten below) that, unlike t.find(),
# raised ValueError when no blank line followed the function. It is removed.
ret_idx=t.rfind(' }')
end=t.find('\n\n', ret_idx)
if end==-1:
    end=len(t)
# Replacement body, written verbatim into the module. It relies on names the
# module already imports (select, Dataset, DatasetRow, settings, logger,
# defaultdict, _identify_columns, _clean, _extract_username, _is_valid_ip,
# _infer_os, _IGNORE_IPS) — TODO confirm all are present in host_inventory.py.
new_func='''async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
    """Build a deduplicated host inventory from all datasets in a hunt.

    Returns dict with 'hosts', 'connections', and 'stats'.
    Each host has: id, hostname, fqdn, client_id, ips, os, users, datasets, row_count.
    """
    ds_result = await db.execute(
        select(Dataset).where(Dataset.hunt_id == hunt_id)
    )
    all_datasets = ds_result.scalars().all()
    if not all_datasets:
        return {"hosts": [], "connections": [], "stats": {
            "total_hosts": 0, "total_datasets_scanned": 0,
            "total_rows_scanned": 0,
        }}
    hosts: dict[str, dict] = {}  # fqdn -> host record
    ip_to_host: dict[str, str] = {}  # local-ip -> fqdn
    connections: dict[tuple, int] = defaultdict(int)
    total_rows = 0
    ds_with_hosts = 0
    sampled_dataset_count = 0
    total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
    max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
    global_budget_reached = False
    dropped_connections = 0
    for ds in all_datasets:
        if total_row_budget and total_rows >= total_row_budget:
            global_budget_reached = True
            break
        cols = _identify_columns(ds)
        if not cols['fqdn'] and not cols['host_id']:
            continue
        ds_with_hosts += 1
        batch_size = 5000
        max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
        rows_scanned_this_dataset = 0
        sampled_dataset = False
        last_row_index = -1
        while True:
            if total_row_budget and total_rows >= total_row_budget:
                sampled_dataset = True
                global_budget_reached = True
                break
            rr = await db.execute(
                select(DatasetRow)
                .where(DatasetRow.dataset_id == ds.id)
                .where(DatasetRow.row_index > last_row_index)
                .order_by(DatasetRow.row_index)
                .limit(batch_size)
            )
            rows = rr.scalars().all()
            if not rows:
                break
            for ro in rows:
                if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
                    sampled_dataset = True
                    break
                if total_row_budget and total_rows >= total_row_budget:
                    sampled_dataset = True
                    global_budget_reached = True
                    break
                data = ro.data or {}
                total_rows += 1
                rows_scanned_this_dataset += 1
                fqdn = ''
                for c in cols['fqdn']:
                    fqdn = _clean(data.get(c))
                    if fqdn:
                        break
                client_id = ''
                for c in cols['host_id']:
                    client_id = _clean(data.get(c))
                    if client_id:
                        break
                if not fqdn and not client_id:
                    continue
                host_key = fqdn or client_id
                if host_key not in hosts:
                    short = fqdn.split('.')[0] if fqdn and '.' in fqdn else fqdn
                    hosts[host_key] = {
                        'id': host_key,
                        'hostname': short or client_id,
                        'fqdn': fqdn,
                        'client_id': client_id,
                        'ips': set(),
                        'os': '',
                        'users': set(),
                        'datasets': set(),
                        'row_count': 0,
                    }
                h = hosts[host_key]
                h['datasets'].add(ds.name)
                h['row_count'] += 1
                if client_id and not h['client_id']:
                    h['client_id'] = client_id
                for c in cols['username']:
                    u = _extract_username(_clean(data.get(c)))
                    if u:
                        h['users'].add(u)
                for c in cols['local_ip']:
                    ip = _clean(data.get(c))
                    if _is_valid_ip(ip):
                        h['ips'].add(ip)
                        ip_to_host[ip] = host_key
                for c in cols['os']:
                    ov = _clean(data.get(c))
                    if ov and not h['os']:
                        h['os'] = ov
                for c in cols['remote_ip']:
                    rip = _clean(data.get(c))
                    if _is_valid_ip(rip):
                        rport = ''
                        for pc in cols['remote_port']:
                            rport = _clean(data.get(pc))
                            if rport:
                                break
                        conn_key = (host_key, rip, rport)
                        if max_connections and len(connections) >= max_connections and conn_key not in connections:
                            dropped_connections += 1
                            continue
                        connections[conn_key] += 1
            if sampled_dataset:
                sampled_dataset_count += 1
                logger.info(
                    "Host inventory sampling for dataset %s (%d rows scanned)",
                    ds.id,
                    rows_scanned_this_dataset,
                )
                break
            last_row_index = rows[-1].row_index
            if len(rows) < batch_size:
                break
        if global_budget_reached:
            logger.info(
                "Host inventory global row budget reached for hunt %s at %d rows",
                hunt_id,
                total_rows,
            )
            break
    # Post-process hosts
    for h in hosts.values():
        if not h['os'] and h['fqdn']:
            h['os'] = _infer_os(h['fqdn'])
        h['ips'] = sorted(h['ips'])
        h['users'] = sorted(h['users'])
        h['datasets'] = sorted(h['datasets'])
    # Build connections, resolving IPs to host keys
    conn_list = []
    seen = set()
    for (src, dst_ip, dst_port), cnt in connections.items():
        if dst_ip in _IGNORE_IPS:
            continue
        dst_host = ip_to_host.get(dst_ip, '')
        if dst_host == src:
            continue
        key = tuple(sorted([src, dst_host or dst_ip]))
        if key in seen:
            continue
        seen.add(key)
        conn_list.append({
            'source': src,
            'target': dst_host or dst_ip,
            'target_ip': dst_ip,
            'port': dst_port,
            'count': cnt,
        })
    host_list = sorted(hosts.values(), key=lambda x: x['row_count'], reverse=True)
    return {
        "hosts": host_list,
        "connections": conn_list,
        "stats": {
            "total_hosts": len(host_list),
            "total_datasets_scanned": len(all_datasets),
            "datasets_with_hosts": ds_with_hosts,
            "total_rows_scanned": total_rows,
            "hosts_with_ips": sum(1 for h in host_list if h['ips']),
            "hosts_with_users": sum(1 for h in host_list if h['users']),
            "row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
            "row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
            "connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
            "sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
            "sampled_datasets": sampled_dataset_count,
            "global_budget_reached": global_budget_reached,
            "dropped_connections": dropped_connections,
        },
    }
'''
# Splice: keep everything before the old def and everything after its end gap.
out=t[:start]+new_func+t[end:]
p.write_text(out,encoding='utf-8')
print('replaced build_host_inventory with hard-budget fast mode')

View File

@@ -0,0 +1,112 @@
"""add processing_status and AI analysis tables
Revision ID: a1b2c3d4e5f6
Revises: 98ab619418bc
Create Date: 2026-02-19 18:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "a1b2c3d4e5f6"
down_revision: Union[str, Sequence[str], None] = "98ab619418bc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add dataset processing-status columns and the four AI-analysis result tables."""
    # NOTE(review): this revision and a3b1c2d4e5f6 both declare down_revision
    # 98ab619418bc, leaving the migration history with two heads — confirm a
    # merge revision exists (or that `alembic upgrade heads` is intended).
    # Add columns to datasets table (batch mode keeps this SQLite-compatible).
    with op.batch_alter_table("datasets") as batch_op:
        batch_op.add_column(sa.Column("processing_status", sa.String(20), server_default="ready"))
        batch_op.add_column(sa.Column("artifact_type", sa.String(128), nullable=True))
        batch_op.add_column(sa.Column("error_message", sa.Text(), nullable=True))
        batch_op.add_column(sa.Column("file_path", sa.String(512), nullable=True))
        batch_op.create_index("ix_datasets_status", ["processing_status"])
    # Create triage_results table: per-row-range LLM triage verdicts for a dataset.
    op.create_table(
        "triage_results",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("row_start", sa.Integer(), nullable=False),
        sa.Column("row_end", sa.Integer(), nullable=False),
        sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column("verdict", sa.String(20), nullable=False, server_default="pending"),
        sa.Column("findings", sa.JSON(), nullable=True),
        sa.Column("suspicious_indicators", sa.JSON(), nullable=True),
        sa.Column("mitre_techniques", sa.JSON(), nullable=True),
        sa.Column("model_used", sa.String(128), nullable=True),
        sa.Column("node_used", sa.String(64), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
    # Create host_profiles table: per-host risk rollup within a hunt.
    op.create_table(
        "host_profiles",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("hostname", sa.String(256), nullable=False),
        sa.Column("fqdn", sa.String(512), nullable=True),
        sa.Column("client_id", sa.String(64), nullable=True),
        sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column("risk_level", sa.String(20), nullable=False, server_default="unknown"),
        sa.Column("artifact_summary", sa.JSON(), nullable=True),
        sa.Column("timeline_summary", sa.Text(), nullable=True),
        sa.Column("suspicious_findings", sa.JSON(), nullable=True),
        sa.Column("mitre_techniques", sa.JSON(), nullable=True),
        sa.Column("llm_analysis", sa.Text(), nullable=True),
        sa.Column("model_used", sa.String(128), nullable=True),
        sa.Column("node_used", sa.String(64), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
    # Create hunt_reports table: generated report per hunt.
    op.create_table(
        "hunt_reports",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("status", sa.String(20), nullable=False, server_default="pending"),
        sa.Column("exec_summary", sa.Text(), nullable=True),
        sa.Column("full_report", sa.Text(), nullable=True),
        sa.Column("findings", sa.JSON(), nullable=True),
        sa.Column("recommendations", sa.JSON(), nullable=True),
        sa.Column("mitre_mapping", sa.JSON(), nullable=True),
        sa.Column("ioc_table", sa.JSON(), nullable=True),
        sa.Column("host_risk_summary", sa.JSON(), nullable=True),
        sa.Column("models_used", sa.JSON(), nullable=True),
        sa.Column("generation_time_ms", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
    # Create anomaly_results table: per-row anomaly/clustering scores.
    op.create_table(
        "anomaly_results",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("row_id", sa.String(32), sa.ForeignKey("dataset_rows.id", ondelete="CASCADE"), nullable=True),
        sa.Column("anomaly_score", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column("distance_from_centroid", sa.Float(), nullable=True),
        sa.Column("cluster_id", sa.Integer(), nullable=True),
        sa.Column("is_outlier", sa.Boolean(), nullable=False, server_default="0"),
        sa.Column("explanation", sa.Text(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
def downgrade() -> None:
    """Reverse the upgrade: drop the AI tables, then the dataset columns."""
    # Tables are dropped in reverse creation order (they FK onto datasets/hunts).
    op.drop_table("anomaly_results")
    op.drop_table("hunt_reports")
    op.drop_table("host_profiles")
    op.drop_table("triage_results")
    # Drop the index before its column; batch mode for SQLite compatibility.
    with op.batch_alter_table("datasets") as batch_op:
        batch_op.drop_index("ix_datasets_status")
        batch_op.drop_column("file_path")
        batch_op.drop_column("error_message")
        batch_op.drop_column("artifact_type")
        batch_op.drop_column("processing_status")

View File

@@ -0,0 +1,72 @@
"""add cases and activity logs
Revision ID: a3b1c2d4e5f6
Revises: 98ab619418bc
Create Date: 2025-01-01 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "a3b1c2d4e5f6"
down_revision: Union[str, None] = "98ab619418bc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the case-management tables: cases, case_tasks, activity_logs."""
    # NOTE(review): this revision and a1b2c3d4e5f6 both declare down_revision
    # 98ab619418bc — two alembic heads; confirm a merge revision exists.
    # cases: top-level investigation records (TLP/PAP markings, MITRE, IOCs).
    op.create_table(
        "cases",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("title", sa.String(512), nullable=False),
        sa.Column("description", sa.Text, nullable=True),
        sa.Column("severity", sa.String(16), server_default="medium"),
        sa.Column("tlp", sa.String(16), server_default="amber"),
        sa.Column("pap", sa.String(16), server_default="amber"),
        sa.Column("status", sa.String(24), server_default="open"),
        sa.Column("priority", sa.Integer, server_default="2"),
        sa.Column("assignee", sa.String(128), nullable=True),
        sa.Column("tags", sa.JSON, nullable=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
        sa.Column("owner_id", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
        sa.Column("mitre_techniques", sa.JSON, nullable=True),
        sa.Column("iocs", sa.JSON, nullable=True),
        sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True),
        # NOTE(review): created_at/updated_at are NOT NULL with no
        # server_default here, unlike sibling migrations — the app layer must
        # populate them; confirm the model does.
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_index("ix_cases_hunt", "cases", ["hunt_id"])
    op.create_index("ix_cases_status", "cases", ["status"])
    # case_tasks: ordered checklist items attached to a case.
    op.create_table(
        "case_tasks",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("case_id", sa.String(32), sa.ForeignKey("cases.id", ondelete="CASCADE"), nullable=False),
        sa.Column("title", sa.String(512), nullable=False),
        sa.Column("description", sa.Text, nullable=True),
        sa.Column("status", sa.String(24), server_default="todo"),
        sa.Column("assignee", sa.String(128), nullable=True),
        sa.Column("order", sa.Integer, server_default="0"),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_index("ix_case_tasks_case", "case_tasks", ["case_id"])
    # activity_logs: generic audit trail keyed by (entity_type, entity_id).
    op.create_table(
        "activity_logs",
        sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
        sa.Column("entity_type", sa.String(32), nullable=False),
        sa.Column("entity_id", sa.String(32), nullable=False),
        sa.Column("action", sa.String(64), nullable=False),
        sa.Column("details", sa.JSON, nullable=True),
        sa.Column("user_id", sa.String(32), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_index("ix_activity_entity", "activity_logs", ["entity_type", "entity_id"])
def downgrade() -> None:
    """Drop the case-management tables in reverse creation order (FK-safe)."""
    op.drop_table("activity_logs")
    op.drop_table("case_tasks")
    op.drop_table("cases")

View File

@@ -0,0 +1,78 @@
"""add playbooks, playbook_steps, saved_searches tables
Revision ID: b2c3d4e5f6a7
Revises: a1b2c3d4e5f6
Create Date: 2026-02-21 10:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "b2c3d4e5f6a7"
down_revision: Union[str, Sequence[str], None] = "a1b2c3d4e5f6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add users.display_name plus playbooks, playbook_steps, and saved_searches."""
    # Add display_name to users table (batch mode for SQLite compatibility).
    with op.batch_alter_table("users") as batch_op:
        batch_op.add_column(sa.Column("display_name", sa.String(128), nullable=True))
    # Create playbooks table: investigation playbooks, optionally templates,
    # optionally scoped to a hunt.
    op.create_table(
        "playbooks",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("name", sa.String(256), nullable=False, index=True),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
        sa.Column("is_template", sa.Boolean(), server_default="0"),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
        sa.Column("status", sa.String(20), server_default="active"),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
    # Create playbook_steps table: ordered steps; target_route can deep-link
    # into the UI, step completion is tracked per step.
    op.create_table(
        "playbook_steps",
        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
        sa.Column("playbook_id", sa.String(32), sa.ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False),
        sa.Column("order_index", sa.Integer(), nullable=False),
        sa.Column("title", sa.String(256), nullable=False),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("step_type", sa.String(32), server_default="manual"),
        sa.Column("target_route", sa.String(256), nullable=True),
        sa.Column("is_completed", sa.Boolean(), server_default="0"),
        sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("notes", sa.Text(), nullable=True),
    )
    op.create_index("ix_playbook_steps_playbook", "playbook_steps", ["playbook_id"])
    # Create saved_searches table: bookmarked queries / recurring scans with
    # last-run bookkeeping.
    op.create_table(
        "saved_searches",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("name", sa.String(256), nullable=False, index=True),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("search_type", sa.String(32), nullable=False),
        sa.Column("query_params", sa.JSON(), nullable=False),
        sa.Column("threshold", sa.Float(), nullable=True),
        sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
        sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("last_result_count", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
    op.create_index("ix_saved_searches_type", "saved_searches", ["search_type"])
def downgrade() -> None:
    """Drop the playbook/saved-search tables (reverse order), then users.display_name."""
    op.drop_table("saved_searches")
    op.drop_table("playbook_steps")
    op.drop_table("playbooks")
    with op.batch_alter_table("users") as batch_op:
        batch_op.drop_column("display_name")

View File

@@ -0,0 +1,63 @@
"""add alerts and alert_rules tables
Revision ID: b4c2d3e5f6a7
Revises: a3b1c2d4e5f6
Create Date: 2025-01-01 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers
revision: str = "b4c2d3e5f6a7"
down_revision: Union[str, None] = "a3b1c2d4e5f6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the alerting tables: alerts and alert_rules."""
    # alerts: analyzer-generated detections, optionally linked to a hunt,
    # dataset, and/or case.
    op.create_table(
        "alerts",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("title", sa.String(512), nullable=False),
        sa.Column("description", sa.Text, nullable=True),
        sa.Column("severity", sa.String(16), server_default="medium"),
        sa.Column("status", sa.String(24), server_default="new"),
        sa.Column("analyzer", sa.String(64), nullable=False),
        sa.Column("score", sa.Float, server_default="0"),
        sa.Column("evidence", sa.JSON, nullable=True),
        sa.Column("mitre_technique", sa.String(32), nullable=True),
        sa.Column("tags", sa.JSON, nullable=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id"), nullable=True),
        sa.Column("case_id", sa.String(32), sa.ForeignKey("cases.id"), nullable=True),
        sa.Column("assignee", sa.String(128), nullable=True),
        sa.Column("acknowledged_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
    op.create_index("ix_alerts_severity", "alerts", ["severity"])
    op.create_index("ix_alerts_status", "alerts", ["status"])
    op.create_index("ix_alerts_hunt", "alerts", ["hunt_id"])
    op.create_index("ix_alerts_dataset", "alerts", ["dataset_id"])
    # alert_rules: per-analyzer configuration that drives alert generation.
    op.create_table(
        "alert_rules",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("name", sa.String(256), nullable=False),
        sa.Column("description", sa.Text, nullable=True),
        sa.Column("analyzer", sa.String(64), nullable=False),
        sa.Column("config", sa.JSON, nullable=True),
        sa.Column("severity_override", sa.String(16), nullable=True),
        # NOTE(review): sa.text("1") for a boolean default, where sibling
        # migrations use the string "0" — both work on SQLite, but confirm
        # consistency if another backend is ever targeted.
        sa.Column("enabled", sa.Boolean, server_default=sa.text("1")),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
    op.create_index("ix_alert_rules_analyzer", "alert_rules", ["analyzer"])
def downgrade() -> None:
    """Drop the alerting tables in reverse creation order."""
    op.drop_table("alert_rules")
    op.drop_table("alerts")

View File

@@ -0,0 +1,48 @@
"""add processing_tasks table
Revision ID: c3d4e5f6a7b8
Revises: b2c3d4e5f6a7
Create Date: 2026-02-22 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "c3d4e5f6a7b8"
down_revision: Union[str, Sequence[str], None] = "b2c3d4e5f6a7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the processing_tasks table plus its lookup indexes."""
    op.create_table(
        "processing_tasks",
        # Application-generated 32-char string primary key.
        sa.Column("id", sa.String(32), primary_key=True),
        # Owning hunt/dataset; rows are removed when the parent is deleted.
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True),
        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True),
        # External job-queue identifier, when the task was enqueued.
        sa.Column("job_id", sa.String(64), nullable=True),
        sa.Column("stage", sa.String(64), nullable=False),
        # Lifecycle status; new rows default to "queued".
        sa.Column("status", sa.String(20), nullable=False, server_default="queued"),
        # Progress value, defaulting to 0.0 — presumably 0.0-1.0; confirm against writer code.
        sa.Column("progress", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column("message", sa.Text(), nullable=True),
        sa.Column("error", sa.Text(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
    # Composite indexes cover "tasks for hunt/dataset at stage" lookups.
    op.create_index("ix_processing_tasks_hunt_stage", "processing_tasks", ["hunt_id", "stage"])
    op.create_index("ix_processing_tasks_dataset_stage", "processing_tasks", ["dataset_id", "stage"])
    op.create_index("ix_processing_tasks_job_id", "processing_tasks", ["job_id"])
    op.create_index("ix_processing_tasks_status", "processing_tasks", ["status"])
def downgrade() -> None:
    """Reverse the migration: drop processing_tasks and its indexes.

    Indexes are dropped explicitly before the table, mirroring upgrade()
    in reverse order.
    """
    op.drop_index("ix_processing_tasks_status", table_name="processing_tasks")
    op.drop_index("ix_processing_tasks_job_id", table_name="processing_tasks")
    op.drop_index("ix_processing_tasks_dataset_stage", table_name="processing_tasks")
    op.drop_index("ix_processing_tasks_hunt_stage", table_name="processing_tasks")
    op.drop_table("processing_tasks")

View File

@@ -0,0 +1,54 @@
"""add notebooks and playbook_runs tables
Revision ID: c5d3e4f6a7b8
Revises: b4c2d3e5f6a7
Create Date: 2025-01-01 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "c5d3e4f6a7b8"
down_revision: Union[str, None] = "b4c2d3e5f6a7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the notebooks and playbook_runs tables plus hot-path indexes."""
    op.create_table(
        "notebooks",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("title", sa.String(512), nullable=False),
        sa.Column("description", sa.Text, nullable=True),
        # Notebook body stored as JSON (presumably a list of cells; confirm with model).
        sa.Column("cells", sa.JSON, nullable=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
        sa.Column("case_id", sa.String(32), sa.ForeignKey("cases.id"), nullable=True),
        sa.Column("owner_id", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
        sa.Column("tags", sa.JSON, nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
    op.create_index("ix_notebooks_hunt", "notebooks", ["hunt_id"])
    op.create_table(
        "playbook_runs",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("playbook_name", sa.String(256), nullable=False),
        # Run lifecycle; new runs default to "in-progress".
        sa.Column("status", sa.String(24), server_default="in-progress"),
        # Step tracking: current step defaults to 1, total to 0.
        sa.Column("current_step", sa.Integer, server_default="1"),
        sa.Column("total_steps", sa.Integer, server_default="0"),
        sa.Column("step_results", sa.JSON, nullable=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
        sa.Column("case_id", sa.String(32), sa.ForeignKey("cases.id"), nullable=True),
        sa.Column("started_by", sa.String(128), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
    )
    op.create_index("ix_playbook_runs_hunt", "playbook_runs", ["hunt_id"])
    op.create_index("ix_playbook_runs_status", "playbook_runs", ["status"])
def downgrade() -> None:
    """Reverse the migration: drop playbook_runs then notebooks (reverse creation order)."""
    op.drop_table("playbook_runs")
    op.drop_table("notebooks")

View File

@@ -1,16 +1,18 @@
"""Analyst-assist agent module for ThreatHunt. """Analyst-assist agent module for ThreatHunt.
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses. Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
Agents are advisory only and do not execute actions or modify data. Agents are advisory only and do not execute actions or modify data.
""" """
from .core import ThreatHuntAgent from .core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
from .providers import LLMProvider, LocalProvider, NetworkedProvider, OnlineProvider from .providers_v2 import OllamaProvider, OpenWebUIProvider, EmbeddingProvider
__all__ = [ __all__ = [
"ThreatHuntAgent", "ThreatHuntAgent",
"LLMProvider", "AgentContext",
"LocalProvider", "AgentResponse",
"NetworkedProvider", "Perspective",
"OnlineProvider", "OllamaProvider",
"OpenWebUIProvider",
"EmbeddingProvider",
] ]

View File

@@ -1,208 +0,0 @@
"""Core ThreatHunt analyst-assist agent.
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
Agents are advisory only - no execution, no alerts, no data modifications.
"""
import logging
from typing import Optional
from pydantic import BaseModel, Field
from .providers import LLMProvider, get_provider
logger = logging.getLogger(__name__)
class AgentContext(BaseModel):
    """Context for agent guidance requests."""
    # Only `query` is required; the remaining fields are optional hints that
    # enrich the prompt built by ThreatHuntAgent._build_prompt.
    query: str = Field(
        ..., description="Analyst question or request for guidance"
    )
    dataset_name: Optional[str] = Field(None, description="Name of CSV dataset")
    artifact_type: Optional[str] = Field(None, description="Artifact type (e.g., file, process, network)")
    host_identifier: Optional[str] = Field(
        None, description="Host name, IP, or identifier"
    )
    data_summary: Optional[str] = Field(
        None, description="Brief description of uploaded data"
    )
    # Defaults to an empty list so callers can omit history entirely.
    conversation_history: Optional[list[dict]] = Field(
        default_factory=list, description="Previous messages in conversation"
    )
class AgentResponse(BaseModel):
    """Response from analyst-assist agent."""
    guidance: str = Field(..., description="Advisory guidance for analyst")
    # ge/le validation rejects values outside the closed interval [0, 1].
    confidence: float = Field(
        ..., ge=0.0, le=1.0, description="Confidence in guidance (0-1)"
    )
    suggested_pivots: list[str] = Field(
        default_factory=list, description="Suggested analytical directions"
    )
    suggested_filters: list[str] = Field(
        default_factory=list, description="Suggested data filters or queries"
    )
    caveats: Optional[str] = Field(
        None, description="Assumptions, limitations, or caveats"
    )
    reasoning: Optional[str] = Field(
        None, description="Explanation of how guidance was generated"
    )
class ThreatHuntAgent:
    """Analyst-assist agent for ThreatHunt.

    Provides guidance on:
    - Interpreting CSV artifact data
    - Suggesting analytical pivots and filters
    - Forming and testing hypotheses

    Policy:
    - Advisory guidance only (no execution)
    - No database or schema changes
    - No alert escalation
    - Transparent reasoning
    """

    def __init__(self, provider: Optional[LLMProvider] = None):
        """Initialize agent with LLM provider.

        Args:
            provider: LLM provider instance. If None, uses get_provider() with auto mode.
        """
        if provider is None:
            try:
                provider = get_provider("auto")
            except RuntimeError as e:
                # Degrade gracefully: keep the agent constructible so the error
                # surfaces at request time (assist()) with a clear message.
                logger.warning("Could not initialize default provider: %s", e)
                provider = None
        self.provider = provider
        self.system_prompt = self._build_system_prompt()

    def _build_system_prompt(self) -> str:
        """Build the system prompt that governs agent behavior."""
        return """You are an analyst-assist agent for ThreatHunt, a threat hunting platform.
Your role:
- Interpret and explain CSV artifact data from Velociraptor
- Suggest analytical pivots, filters, and hypotheses
- Highlight anomalies, patterns, or points of interest
- Guide analysts without replacing their judgment
Your constraints:
- You ONLY provide guidance and suggestions
- You do NOT execute actions or tools
- You do NOT modify data or escalate alerts
- You do NOT make autonomous decisions
- You ONLY analyze data presented to you
- You explain your reasoning transparently
- You acknowledge limitations and assumptions
- You suggest next investigative steps
When responding:
1. Start with a clear, direct answer to the query
2. Explain your reasoning based on the data context provided
3. Suggest 2-4 analytical pivots the analyst might explore
4. Suggest 2-4 data filters or queries that might be useful
5. Include relevant caveats or assumptions
6. Be honest about what you cannot determine from the data
Remember: The analyst is the decision-maker. You are an assistant."""

    async def assist(self, context: AgentContext) -> AgentResponse:
        """Provide guidance on artifact data and analysis.

        Args:
            context: Request context including query and data context.

        Returns:
            Guidance response with suggestions and reasoning.

        Raises:
            RuntimeError: If no provider is available.
        """
        if not self.provider:
            raise RuntimeError(
                "No LLM provider available. Configure at least one of: "
                "THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
                "or THREAT_HUNT_ONLINE_API_KEY"
            )
        # Build prompt with context
        prompt = self._build_prompt(context)
        try:
            # Get guidance from LLM provider
            guidance = await self.provider.generate(prompt, max_tokens=1024)
            # Parse response into structured format
            response = self._parse_response(guidance, context)
            # Lazy %-style args: the message is only formatted when INFO is enabled.
            logger.info(
                "Agent assisted with query: %s... (dataset: %s)",
                context.query[:50],
                context.dataset_name,
            )
            return response
        except Exception as e:
            logger.error("Error generating guidance: %s", e)
            raise

    def _build_prompt(self, context: AgentContext) -> str:
        """Build the prompt for the LLM.

        Only fields that are set are included; the last five conversation
        messages are appended for continuity.
        """
        prompt_parts = [
            f"Analyst query: {context.query}",
        ]
        if context.dataset_name:
            prompt_parts.append(f"Dataset: {context.dataset_name}")
        if context.artifact_type:
            prompt_parts.append(f"Artifact type: {context.artifact_type}")
        if context.host_identifier:
            prompt_parts.append(f"Host: {context.host_identifier}")
        if context.data_summary:
            prompt_parts.append(f"Data summary: {context.data_summary}")
        if context.conversation_history:
            prompt_parts.append("\nConversation history:")
            for msg in context.conversation_history[-5:]:  # Last 5 messages for context
                prompt_parts.append(f"  {msg.get('role', 'unknown')}: {msg.get('content', '')}")
        return "\n".join(prompt_parts)

    def _parse_response(self, response_text: str, context: AgentContext) -> AgentResponse:
        """Parse LLM response into structured format.

        Note: This is a simplified parser. In production, use structured output
        from the LLM (JSON mode, function calling, etc.) for better reliability.
        """
        # For now, return a structured response based on the raw guidance;
        # pivots/filters are static suggestions, not derived from the data.
        return AgentResponse(
            guidance=response_text,
            confidence=0.8,  # Placeholder
            suggested_pivots=[
                "Analyze temporal patterns",
                "Cross-reference with known indicators",
                "Examine outliers in the dataset",
                "Compare with baseline behavior",
            ],
            suggested_filters=[
                "Filter by high-risk indicators",
                "Sort by timestamp for timeline analysis",
                "Group by host or user",
                "Filter by anomaly score",
            ],
            caveats="Guidance is based on available data context. "
            "Analysts should verify findings with additional sources.",
            reasoning="Analysis generated based on artifact data patterns and analyst query.",
        )

View File

@@ -1,190 +0,0 @@
"""Pluggable LLM provider interface for analyst-assist agents.
Supports three provider types:
- Local: On-device or on-prem models
- Networked: Shared internal inference services
- Online: External hosted APIs
"""
import os
from abc import ABC, abstractmethod
from typing import Optional
class LLMProvider(ABC):
    """Common interface every LLM backend must implement."""

    @abstractmethod
    async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
        """Produce a completion for *prompt*.

        Args:
            prompt: The input prompt.
            max_tokens: Maximum tokens in response.

        Returns:
            Generated text response.
        """
        ...

    @abstractmethod
    def is_available(self) -> bool:
        """Return True when the backing model/service can be used."""
        ...
class LocalProvider(LLMProvider):
    """LLM provider backed by a model file on local disk (or on-prem)."""

    def __init__(self, model_path: Optional[str] = None):
        """Store the model location.

        Args:
            model_path: Path to local model. If None, uses THREAT_HUNT_LOCAL_MODEL_PATH env var.
        """
        env_path = os.getenv("THREAT_HUNT_LOCAL_MODEL_PATH")
        self.model_path = model_path if model_path else env_path
        self.model = None

    def is_available(self) -> bool:
        """A local model counts as available when its file exists on disk."""
        path = self.model_path
        if not path:
            return False
        # In production, would verify model file exists and can be loaded
        return os.path.exists(str(path))

    async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
        """Return a canned response standing in for local inference.

        NOTE(review): placeholder — production code should integrate a real
        local engine (llama-cpp-python, Ollama API, vLLM, ...).
        """
        if not self.is_available():
            raise RuntimeError("Local model not available")
        return f"[Local model response to: {prompt[:50]}...]"
class NetworkedProvider(LLMProvider):
    """Provider that talks to a shared internal inference service."""

    def __init__(
        self,
        api_endpoint: Optional[str] = None,
        api_key: Optional[str] = None,
        model_name: str = "default",
    ):
        """Record connection settings for the service.

        Args:
            api_endpoint: URL to inference service. Defaults to env var THREAT_HUNT_NETWORKED_ENDPOINT.
            api_key: API key for service. Defaults to env var THREAT_HUNT_NETWORKED_KEY.
            model_name: Model name/ID on the service.
        """
        self.api_endpoint = api_endpoint if api_endpoint else os.getenv("THREAT_HUNT_NETWORKED_ENDPOINT")
        self.api_key = api_key if api_key else os.getenv("THREAT_HUNT_NETWORKED_KEY")
        self.model_name = model_name

    def is_available(self) -> bool:
        """Available whenever an endpoint URL is configured (no probe is made)."""
        return bool(self.api_endpoint)

    async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
        """Return a canned response standing in for a real service call.

        NOTE(review): placeholder — production code should call the actual
        internal inference API, container cluster, or gateway.
        """
        if not self.is_available():
            raise RuntimeError("Networked service not available")
        return f"[Networked response from {self.model_name}: {prompt[:50]}...]"
class OnlineProvider(LLMProvider):
    """Provider for external hosted LLM APIs."""

    def __init__(
        self,
        api_provider: str = "openai",
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
    ):
        """Record credentials and model selection for a hosted API.

        Args:
            api_provider: Provider name (openai, anthropic, google, etc.)
            api_key: API key. Defaults to env var THREAT_HUNT_ONLINE_API_KEY.
            model_name: Model name. Defaults to env var THREAT_HUNT_ONLINE_MODEL.
        """
        self.api_provider = api_provider
        self.api_key = api_key if api_key else os.getenv("THREAT_HUNT_ONLINE_API_KEY")
        fallback_model = f"{api_provider}-default"
        self.model_name = model_name or os.getenv("THREAT_HUNT_ONLINE_MODEL", fallback_model)

    def is_available(self) -> bool:
        """Available whenever an API key is configured."""
        return bool(self.api_key)

    async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
        """Return a canned response standing in for a hosted-API call.

        NOTE(review): placeholder — production code should call a real hosted
        service (OpenAI, Anthropic Claude, Google Gemini, ...).
        """
        if not self.is_available():
            raise RuntimeError("Online API not available or API key not set")
        return f"[Online {self.api_provider} response: {prompt[:50]}...]"
def get_provider(provider_type: str = "auto") -> LLMProvider:
    """Get an LLM provider based on configuration.

    Args:
        provider_type: 'local', 'networked', 'online', or 'auto'. With 'auto',
            backends are tried in order local -> networked -> online and the
            first available one wins.

    Returns:
        Configured LLM provider instance.

    Raises:
        RuntimeError: If no provider is available.
        ValueError: If provider_type is not recognized.
    """
    if provider_type == "auto":
        # Probe each backend in order of preference; first available wins.
        for factory in (LocalProvider, NetworkedProvider, OnlineProvider):
            candidate = factory()
            if candidate.is_available():
                return candidate
        raise RuntimeError(
            "No LLM provider available. Configure at least one of: "
            "THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
            "or THREAT_HUNT_ONLINE_API_KEY"
        )
    if provider_type not in ("local", "networked", "online"):
        raise ValueError(f"Unknown provider type: {provider_type}")
    factory = {
        "local": LocalProvider,
        "networked": NetworkedProvider,
        "online": OnlineProvider,
    }[provider_type]
    provider = factory()
    if not provider.is_available():
        raise RuntimeError(f"{provider_type} provider not available")
    return provider

View File

@@ -1,170 +0,0 @@
"""API routes for analyst-assist agent."""
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from app.agents.core import ThreatHuntAgent, AgentContext, AgentResponse
from app.agents.config import AgentConfig
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/agent", tags=["agent"])
# Global agent instance (lazy-loaded)
_agent: ThreatHuntAgent | None = None
def get_agent() -> ThreatHuntAgent:
    """Return the module-level agent, creating it lazily on first use.

    Raises:
        HTTPException: 503 when no LLM provider has been configured.
    """
    global _agent
    if _agent is not None:
        return _agent
    if not AgentConfig.is_agent_enabled():
        raise HTTPException(
            status_code=503,
            detail="Analyst-assist agent is not configured. "
            "Please configure an LLM provider.",
        )
    _agent = ThreatHuntAgent()
    return _agent
class AssistRequest(BaseModel):
    """Request for agent assistance."""
    # Only `query` is required; the remaining fields are optional context
    # forwarded into the agent's prompt-building context.
    query: str = Field(
        ..., description="Analyst question or request for guidance"
    )
    dataset_name: str | None = Field(
        None, description="Name of CSV dataset being analyzed"
    )
    artifact_type: str | None = Field(
        None, description="Type of artifact (e.g., FileList, ProcessList, NetworkConnections)"
    )
    host_identifier: str | None = Field(
        None, description="Host name, IP address, or identifier"
    )
    data_summary: str | None = Field(
        None, description="Brief summary or context about the uploaded data"
    )
    conversation_history: list[dict] | None = Field(
        None, description="Previous messages for context"
    )
class AssistResponse(BaseModel):
    """Response with agent guidance."""
    # Mirrors the agent core's AgentResponse fields; `confidence` is the
    # agent-reported value in [0, 1].
    guidance: str
    confidence: float
    suggested_pivots: list[str]
    suggested_filters: list[str]
    caveats: str | None = None
    reasoning: str | None = None
@router.post(
    "/assist",
    response_model=AssistResponse,
    summary="Get analyst-assist guidance",
    description="Request guidance on CSV artifact data, analytical pivots, and hypotheses. "
    "Agent provides advisory guidance only - no execution.",
)
async def agent_assist(request: AssistRequest) -> AssistResponse:
    """Return advisory guidance for the supplied artifact context.

    The agent only interprets data and suggests pivots/filters; it never
    executes tools, escalates alerts, modifies data, or decides autonomously.

    Args:
        request: Assistance request with query and context.

    Returns:
        Guidance response with suggestions and reasoning.

    Raises:
        HTTPException: 503 when the agent is unavailable, 500 on other failures.
    """
    try:
        agent = get_agent()
        context = AgentContext(
            query=request.query,
            dataset_name=request.dataset_name,
            artifact_type=request.artifact_type,
            host_identifier=request.host_identifier,
            data_summary=request.data_summary,
            conversation_history=request.conversation_history or [],
        )
        result = await agent.assist(context)
        logger.info(
            f"Agent assisted analyst with query: {request.query[:50]}... "
            f"(host: {request.host_identifier}, artifact: {request.artifact_type})"
        )
        return AssistResponse(
            guidance=result.guidance,
            confidence=result.confidence,
            suggested_pivots=result.suggested_pivots,
            suggested_filters=result.suggested_filters,
            caveats=result.caveats,
            reasoning=result.reasoning,
        )
    except RuntimeError as e:
        logger.error(f"Agent error: {e}")
        raise HTTPException(
            status_code=503,
            detail=f"Agent unavailable: {str(e)}",
        )
    except Exception as e:
        logger.exception(f"Unexpected error in agent_assist: {e}")
        raise HTTPException(
            status_code=500,
            detail="Error generating guidance. Please try again.",
        )
@router.get(
    "/health",
    summary="Check agent health",
    description="Check if agent is configured and ready to assist.",
)
async def agent_health() -> dict:
    """Report whether the agent is configured and able to serve requests.

    Returns:
        Health status with configuration details.
    """
    try:
        agent = get_agent()
    except HTTPException:
        # get_agent raises 503 when no provider is configured; report which
        # backends look configured so the operator can fix the environment.
        return {
            "status": "unavailable",
            "reason": "No LLM provider configured",
            "configured_providers": {
                "local": bool(AgentConfig.LOCAL_MODEL_PATH),
                "networked": bool(AgentConfig.NETWORKED_ENDPOINT),
                "online": bool(AgentConfig.ONLINE_API_KEY),
            },
        }
    provider_name = agent.provider.__class__.__name__ if agent.provider else "None"
    return {
        "status": "healthy",
        "provider": provider_name,
        "max_tokens": AgentConfig.MAX_RESPONSE_TOKENS,
        "reasoning_enabled": AgentConfig.ENABLE_REASONING,
    }

View File

@@ -1,4 +1,4 @@
"""API routes for analyst-assist agent v2. """API routes for analyst-assist agent v2.
Supports quick, deep, and debate modes with streaming. Supports quick, deep, and debate modes with streaming.
Conversations are persisted to the database. Conversations are persisted to the database.
@@ -6,19 +6,25 @@ Conversations are persisted to the database.
import json import json
import logging import logging
import re
import time
from collections import Counter
from urllib.parse import urlparse
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings from app.config import settings
from app.db import get_db from app.db import get_db
from app.db.models import Conversation, Message from app.db.models import Conversation, Message, Dataset, KeywordTheme
from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
from app.agents.providers_v2 import check_all_nodes from app.agents.providers_v2 import check_all_nodes
from app.agents.registry import registry from app.agents.registry import registry
from app.services.sans_rag import sans_rag from app.services.sans_rag import sans_rag
from app.services.scanner import KeywordScanner
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,7 +41,7 @@ def get_agent() -> ThreatHuntAgent:
return _agent return _agent
# ── Request / Response models ───────────────────────────────────────── # Request / Response models
class AssistRequest(BaseModel): class AssistRequest(BaseModel):
@@ -52,6 +58,8 @@ class AssistRequest(BaseModel):
model_override: str | None = None model_override: str | None = None
conversation_id: str | None = Field(None, description="Persist messages to this conversation") conversation_id: str | None = Field(None, description="Persist messages to this conversation")
hunt_id: str | None = None hunt_id: str | None = None
execution_preference: str = Field(default="auto", description="auto | force | off")
learning_mode: bool = False
class AssistResponseModel(BaseModel): class AssistResponseModel(BaseModel):
@@ -66,10 +74,170 @@ class AssistResponseModel(BaseModel):
node_used: str = "" node_used: str = ""
latency_ms: int = 0 latency_ms: int = 0
perspectives: list[dict] | None = None perspectives: list[dict] | None = None
execution: dict | None = None
conversation_id: str | None = None conversation_id: str | None = None
# ── Routes ──────────────────────────────────────────────────────────── POLICY_THEME_NAMES = {"Adult Content", "Gambling", "Downloads / Piracy"}
POLICY_QUERY_TERMS = {
"policy", "violating", "violation", "browser history", "web history",
"domain", "domains", "adult", "gambling", "piracy", "aup",
}
WEB_DATASET_HINTS = {
"web", "history", "browser", "url", "visited_url", "domain", "title",
}
def _is_policy_domain_query(query: str) -> bool:
    """Heuristically decide whether *query* asks about web-policy violations.

    Requires at least two POLICY_QUERY_TERMS matches plus one anchor word
    ("domain", "history", or "policy") to avoid firing on casual mentions.
    """
    lowered = (query or "").lower()
    if not lowered:
        return False
    matches = sum(term in lowered for term in POLICY_QUERY_TERMS)
    if matches < 2:
        return False
    return any(anchor in lowered for anchor in ("domain", "history", "policy"))
def _should_execute_policy_scan(request: AssistRequest) -> bool:
    """Map execution_preference onto a go/no-go scan decision.

    "force" always scans, "off" never scans, and "auto" (the default)
    defers to the intent heuristic on the query text.
    """
    preference = (request.execution_preference or "auto").strip().lower()
    if preference == "force":
        return True
    if preference == "off":
        return False
    return _is_policy_domain_query(request.query)
def _extract_domain(value: str | None) -> str | None:
if not value:
return None
text = value.strip()
if not text:
return None
try:
parsed = urlparse(text)
if parsed.netloc:
return parsed.netloc.lower()
except Exception:
pass
m = re.search(r"([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}", text)
return m.group(0).lower() if m else None
def _dataset_score(ds: Dataset) -> int:
    """Rank how likely a dataset is to contain browser/web-history rows.

    Matches the dataset name, column names, and normalized column values
    against WEB_DATASET_HINTS, with extra weight for URL and user/host
    identity columns; non-empty datasets get a small bonus. Higher scores
    mark better scan candidates.
    """
    name_l = (ds.name or "").lower()
    column_names = {c.lower() for c in (ds.column_schema or {}).keys()}
    normalized_vals = {str(v).lower() for v in (ds.normalized_columns or {}).values()}
    total = 0
    for hint in WEB_DATASET_HINTS:
        if hint in name_l:
            total += 2
        if hint in column_names:
            total += 3
        if hint in normalized_vals:
            total += 3
    if "visited_url" in column_names or "url" in column_names:
        total += 8
    if "user" in column_names or "username" in column_names:
        total += 2
    if "clientid" in column_names or "fqdn" in column_names:
        total += 2
    if (ds.row_count or 0) > 0:
        total += 1
    return total
async def _run_policy_domain_execution(request: AssistRequest, db: AsyncSession) -> dict:
    """Run a deterministic keyword scan for policy-violation web activity.

    Selects the enabled policy themes, picks the most web-history-like
    datasets in scope (optionally narrowed by hunt id and dataset-name
    substring), scans them with KeywordScanner, and aggregates hits by
    user/host pair and by extracted domain.

    Returns:
        A "policy_scan" payload dict with counts, top offenders, and up to
        20 sample hits — or a zero-hit payload when no dataset qualifies.
    """
    scanner = KeywordScanner(db)
    # Restrict the scan to enabled policy themes; if none are enabled in the
    # DB, theme_names below falls back to the canonical names for display.
    theme_result = await db.execute(
        select(KeywordTheme).where(
            KeywordTheme.enabled == True, # noqa: E712
            KeywordTheme.name.in_(list(POLICY_THEME_NAMES)),
        )
    )
    themes = list(theme_result.scalars().all())
    theme_ids = [t.id for t in themes]
    theme_names = [t.name for t in themes] or sorted(POLICY_THEME_NAMES)
    # Candidate datasets: anything processed far enough to have rows.
    ds_query = select(Dataset).where(Dataset.processing_status.in_(["completed", "ready", "processing"]))
    if request.hunt_id:
        ds_query = ds_query.where(Dataset.hunt_id == request.hunt_id)
    ds_result = await db.execute(ds_query)
    candidates = list(ds_result.scalars().all())
    if request.dataset_name:
        # Optional case-insensitive substring filter on dataset name.
        needle = request.dataset_name.lower().strip()
        candidates = [d for d in candidates if needle in (d.name or "").lower()]
    # Rank datasets by how browser/web-history-like they look and keep the
    # top 8 positive scorers to bound scan cost.
    scored = sorted(
        ((d, _dataset_score(d)) for d in candidates),
        key=lambda x: x[1],
        reverse=True,
    )
    selected = [d for d, s in scored if s > 0][:8]
    dataset_ids = [d.id for d in selected]
    if not dataset_ids:
        # Nothing worth scanning — return an empty, well-formed payload.
        return {
            "mode": "policy_scan",
            "themes": theme_names,
            "datasets_scanned": 0,
            "dataset_names": [],
            "total_hits": 0,
            "policy_hits": 0,
            "top_user_hosts": [],
            "top_domains": [],
            "sample_hits": [],
            "note": "No suitable browser/web-history datasets found in current scope.",
        }
    # Dataset-only scan: hunts/annotations/messages are out of scope here.
    result = await scanner.scan(
        dataset_ids=dataset_ids,
        theme_ids=theme_ids or None,
        scan_hunts=False,
        scan_annotations=False,
        scan_messages=False,
    )
    hits = result.get("hits", [])
    # Aggregate hits for triage: who (user|host pair) and where (domain).
    user_host_counter = Counter()
    domain_counter = Counter()
    for h in hits:
        user = h.get("username") or "(unknown-user)"
        host = h.get("hostname") or "(unknown-host)"
        user_host_counter[f"{user}|{host}"] += 1
        dom = _extract_domain(h.get("matched_value"))
        if dom:
            domain_counter[dom] += 1
    top_user_hosts = [
        {"user_host": k, "count": v}
        for k, v in user_host_counter.most_common(10)
    ]
    top_domains = [
        {"domain": k, "count": v}
        for k, v in domain_counter.most_common(10)
    ]
    return {
        "mode": "policy_scan",
        "themes": theme_names,
        "datasets_scanned": len(dataset_ids),
        "dataset_names": [d.name for d in selected],
        "total_hits": int(result.get("total_hits", 0)),
        # policy_hits mirrors total_hits since only policy themes were scanned.
        "policy_hits": int(result.get("total_hits", 0)),
        "rows_scanned": int(result.get("rows_scanned", 0)),
        "top_user_hosts": top_user_hosts,
        "top_domains": top_domains,
        "sample_hits": hits[:20],
    }
# Routes
@router.post( @router.post(
@@ -84,6 +252,76 @@ async def agent_assist(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
) -> AssistResponseModel: ) -> AssistResponseModel:
try: try:
# Deterministic execution mode for policy-domain investigations.
if _should_execute_policy_scan(request):
t0 = time.monotonic()
exec_payload = await _run_policy_domain_execution(request, db)
latency_ms = int((time.monotonic() - t0) * 1000)
policy_hits = exec_payload.get("policy_hits", 0)
datasets_scanned = exec_payload.get("datasets_scanned", 0)
if policy_hits > 0:
guidance = (
f"Policy-violation scan complete: {policy_hits} hits across "
f"{datasets_scanned} dataset(s). Top user/host pairs and domains are included "
f"in execution results for triage."
)
confidence = 0.95
caveats = "Keyword-based matching can include false positives; validate with full URL context."
else:
guidance = (
f"No policy-violation hits found in current scope "
f"({datasets_scanned} dataset(s) scanned)."
)
confidence = 0.9
caveats = exec_payload.get("note") or "Try expanding scope to additional hunts/datasets."
response = AssistResponseModel(
guidance=guidance,
confidence=confidence,
suggested_pivots=["username", "hostname", "domain", "dataset_name"],
suggested_filters=[
"theme_name in ['Adult Content','Gambling','Downloads / Piracy']",
"username != null",
"hostname != null",
],
caveats=caveats,
reasoning=(
"Intent matched policy-domain investigation; executed local keyword scan pipeline."
if _is_policy_domain_query(request.query)
else "Execution mode was forced by user preference; ran policy-domain scan pipeline."
),
sans_references=["SANS FOR508", "SANS SEC504"],
model_used="execution:keyword_scanner",
node_used="local",
latency_ms=latency_ms,
execution=exec_payload,
)
conv_id = request.conversation_id
if conv_id or request.hunt_id:
conv_id = await _persist_conversation(
db,
conv_id,
request,
AgentResponse(
guidance=response.guidance,
confidence=response.confidence,
suggested_pivots=response.suggested_pivots,
suggested_filters=response.suggested_filters,
caveats=response.caveats,
reasoning=response.reasoning,
sans_references=response.sans_references,
model_used=response.model_used,
node_used=response.node_used,
latency_ms=response.latency_ms,
),
)
response.conversation_id = conv_id
return response
agent = get_agent() agent = get_agent()
context = AgentContext( context = AgentContext(
query=request.query, query=request.query,
@@ -97,6 +335,7 @@ async def agent_assist(
enrichment_summary=request.enrichment_summary, enrichment_summary=request.enrichment_summary,
mode=request.mode, mode=request.mode,
model_override=request.model_override, model_override=request.model_override,
learning_mode=request.learning_mode,
) )
response = await agent.assist(context) response = await agent.assist(context)
@@ -129,6 +368,7 @@ async def agent_assist(
} }
for p in response.perspectives for p in response.perspectives
] if response.perspectives else None, ] if response.perspectives else None,
execution=None,
conversation_id=conv_id, conversation_id=conv_id,
) )
@@ -208,7 +448,7 @@ async def list_models():
} }
# ── Conversation persistence ────────────────────────────────────────── # Conversation persistence
async def _persist_conversation( async def _persist_conversation(
@@ -263,3 +503,4 @@ async def _persist_conversation(
await db.flush() await db.flush()
return conv.id return conv.id

View File

@@ -0,0 +1,404 @@
"""API routes for alerts — CRUD, analyze triggers, and alert rules."""
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Alert, AlertRule, _new_id, _utcnow
from app.db.repositories.datasets import DatasetRepository
from app.services.analyzers import (
get_available_analyzers,
get_analyzer,
run_all_analyzers,
AlertCandidate,
)
from app.services.process_tree import _fetch_rows
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/alerts", tags=["alerts"])
# ── Pydantic models ──────────────────────────────────────────────────
class AlertUpdate(BaseModel):
    """Partial-update payload for an alert; only non-None fields are applied."""
    status: Optional[str] = None
    severity: Optional[str] = None
    assignee: Optional[str] = None
    case_id: Optional[str] = None
    tags: Optional[list[str]] = None
class RuleCreate(BaseModel):
    """Payload for creating an alert rule bound to a registered analyzer."""
    name: str
    description: Optional[str] = None
    analyzer: str
    config: Optional[dict] = None
    severity_override: Optional[str] = None
    enabled: bool = True
    hunt_id: Optional[str] = None
class RuleUpdate(BaseModel):
    """Partial-update payload for an alert rule; only non-None fields are applied."""
    name: Optional[str] = None
    description: Optional[str] = None
    config: Optional[dict] = None
    severity_override: Optional[str] = None
    enabled: Optional[bool] = None
class AnalyzeRequest(BaseModel):
    """Request body for running analyzers over a dataset or an entire hunt."""
    dataset_id: Optional[str] = None
    hunt_id: Optional[str] = None
    analyzers: Optional[list[str]] = None  # None = run all
    config: Optional[dict] = None
    auto_create: bool = True  # automatically persist alerts
# ── Helpers ───────────────────────────────────────────────────────────
def _alert_to_dict(a: Alert) -> dict:
    """Serialize an Alert ORM row into a JSON-friendly dict."""
    def _iso(ts):
        # Timestamp columns may be NULL; emit None rather than raising.
        return ts.isoformat() if ts else None

    return {
        "id": a.id,
        "title": a.title,
        "description": a.description,
        "severity": a.severity,
        "status": a.status,
        "analyzer": a.analyzer,
        "score": a.score,
        "evidence": a.evidence or [],
        "mitre_technique": a.mitre_technique,
        "tags": a.tags or [],
        "hunt_id": a.hunt_id,
        "dataset_id": a.dataset_id,
        "case_id": a.case_id,
        "assignee": a.assignee,
        "acknowledged_at": _iso(a.acknowledged_at),
        "resolved_at": _iso(a.resolved_at),
        "created_at": _iso(a.created_at),
        "updated_at": _iso(a.updated_at),
    }
def _rule_to_dict(r: AlertRule) -> dict:
    """Serialize an AlertRule ORM row into a JSON-friendly dict."""
    def _iso(ts):
        # Timestamp columns may be NULL; emit None rather than raising.
        return ts.isoformat() if ts else None

    return {
        "id": r.id,
        "name": r.name,
        "description": r.description,
        "analyzer": r.analyzer,
        "config": r.config,
        "severity_override": r.severity_override,
        "enabled": r.enabled,
        "hunt_id": r.hunt_id,
        "created_at": _iso(r.created_at),
        "updated_at": _iso(r.updated_at),
    }
# ── Alert CRUD ────────────────────────────────────────────────────────
@router.get("", summary="List alerts")
async def list_alerts(
    status: str | None = Query(None),
    severity: str | None = Query(None),
    analyzer: str | None = Query(None),
    hunt_id: str | None = Query(None),
    dataset_id: str | None = Query(None),
    limit: int = Query(100, ge=1, le=500),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List alerts, highest score first (then newest), with optional filters.

    Returns ``{"alerts": [...], "total": <count matching the filters>}``.
    """
    # Build the filter list ONCE and apply it to both the page query and the
    # count query — the original duplicated each condition, which invites the
    # two queries drifting out of sync when a filter is added.
    conditions = []
    if status:
        conditions.append(Alert.status == status)
    if severity:
        conditions.append(Alert.severity == severity)
    if analyzer:
        conditions.append(Alert.analyzer == analyzer)
    if hunt_id:
        conditions.append(Alert.hunt_id == hunt_id)
    if dataset_id:
        conditions.append(Alert.dataset_id == dataset_id)

    stmt = select(Alert).where(*conditions)
    count_stmt = select(func.count(Alert.id)).where(*conditions)

    total = (await db.execute(count_stmt)).scalar() or 0
    results = (await db.execute(
        stmt.order_by(desc(Alert.score), desc(Alert.created_at)).offset(offset).limit(limit)
    )).scalars().all()
    return {"alerts": [_alert_to_dict(a) for a in results], "total": total}
@router.get("/stats", summary="Alert statistics dashboard")
async def alert_stats(
    hunt_id: str | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Return aggregated alert statistics.

    Includes severity/status/analyzer breakdowns and the top-10 MITRE
    techniques, optionally scoped to a single hunt. (Removes the original's
    unused ``base`` select and factors the repeated hunt filter into one
    helper so every aggregate is scoped identically.)
    """
    def _scoped(stmt):
        # Apply the optional hunt filter uniformly to every aggregate query.
        return stmt.where(Alert.hunt_id == hunt_id) if hunt_id else stmt

    # Severity breakdown
    sev_rows = (await db.execute(
        _scoped(select(Alert.severity, func.count(Alert.id)).group_by(Alert.severity))
    )).all()
    severity_counts = {s: c for s, c in sev_rows}

    # Status breakdown
    status_rows = (await db.execute(
        _scoped(select(Alert.status, func.count(Alert.id)).group_by(Alert.status))
    )).all()
    status_counts = {s: c for s, c in status_rows}

    # Analyzer breakdown
    analyzer_rows = (await db.execute(
        _scoped(select(Alert.analyzer, func.count(Alert.id)).group_by(Alert.analyzer))
    )).all()
    analyzer_counts = {a: c for a, c in analyzer_rows}

    # Top-10 MITRE techniques by alert count; NULL techniques are excluded.
    mitre_stmt = (
        _scoped(
            select(Alert.mitre_technique, func.count(Alert.id))
            .where(Alert.mitre_technique.isnot(None))
            .group_by(Alert.mitre_technique)
        )
        .order_by(desc(func.count(Alert.id)))
        .limit(10)
    )
    mitre_rows = (await db.execute(mitre_stmt)).all()
    top_mitre = [{"technique": t, "count": c} for t, c in mitre_rows]

    # Every alert has a severity, so the severity breakdown sums to the total.
    total = sum(severity_counts.values())
    return {
        "total": total,
        "severity_counts": severity_counts,
        "status_counts": status_counts,
        "analyzer_counts": analyzer_counts,
        "top_mitre": top_mitre,
    }
@router.get("/{alert_id}", summary="Get alert detail")
async def get_alert(alert_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one alert by primary key; 404 when it does not exist."""
    alert = await db.get(Alert, alert_id)
    if alert is None:
        raise HTTPException(status_code=404, detail="Alert not found")
    return _alert_to_dict(alert)
@router.put("/{alert_id}", summary="Update alert (status, assignee, etc.)")
async def update_alert(
    alert_id: str, body: AlertUpdate, db: AsyncSession = Depends(get_db)
):
    """Apply a partial update to an alert; only fields sent in the body change."""
    alert = await db.get(Alert, alert_id)
    if alert is None:
        raise HTTPException(status_code=404, detail="Alert not found")

    new_status = body.status
    if new_status is not None:
        alert.status = new_status
        # Stamp lifecycle timestamps only on the first transition into the state.
        if new_status == "acknowledged" and not alert.acknowledged_at:
            alert.acknowledged_at = _utcnow()
        if new_status in ("resolved", "false-positive") and not alert.resolved_at:
            alert.resolved_at = _utcnow()

    # Remaining fields are plain assignments when provided.
    for attr in ("severity", "assignee", "case_id", "tags"):
        value = getattr(body, attr)
        if value is not None:
            setattr(alert, attr, value)

    await db.commit()
    await db.refresh(alert)
    return _alert_to_dict(alert)
@router.delete("/{alert_id}", summary="Delete alert")
async def delete_alert(alert_id: str, db: AsyncSession = Depends(get_db)):
    """Permanently remove an alert; 404 when it does not exist."""
    alert = await db.get(Alert, alert_id)
    if alert is None:
        raise HTTPException(status_code=404, detail="Alert not found")
    await db.delete(alert)
    await db.commit()
    return {"ok": True}
# ── Bulk operations ──────────────────────────────────────────────────
@router.post("/bulk-update", summary="Bulk update alert statuses")
async def bulk_update_alerts(
    alert_ids: list[str],
    status: str = Query(...),
    db: AsyncSession = Depends(get_db),
):
    """Set ``status`` on every alert in ``alert_ids``; unknown IDs are skipped.

    Returns ``{"updated": <number of alerts actually found and updated>}``.
    """
    updated = 0
    if alert_ids:
        # One IN query instead of a DB round-trip per ID (the original did
        # db.get() in a loop — O(n) round trips for n IDs).
        result = await db.execute(select(Alert).where(Alert.id.in_(alert_ids)))
        now = _utcnow()
        for alert in result.scalars().all():
            alert.status = status
            # Stamp lifecycle timestamps only on the first transition.
            if status == "acknowledged" and not alert.acknowledged_at:
                alert.acknowledged_at = now
            if status in ("resolved", "false-positive") and not alert.resolved_at:
                alert.resolved_at = now
            updated += 1
    await db.commit()
    return {"updated": updated}
# ── Run Analyzers ────────────────────────────────────────────────────
@router.get("/analyzers/list", summary="List available analyzers")
async def list_analyzers():
    """Return metadata for every analyzer registered in the analyzers service."""
    return {"analyzers": get_available_analyzers()}
@router.post("/analyze", summary="Run analyzers on a dataset/hunt and optionally create alerts")
async def run_analysis(
    request: AnalyzeRequest, db: AsyncSession = Depends(get_db)
):
    """Run the selected (or all) analyzers over up to 10k rows and
    optionally persist the resulting candidates as Alert rows.

    Returns counts plus the created alerts and a per-severity/analyzer summary.
    """
    if not request.dataset_id and not request.hunt_id:
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    # Load rows (capped at 10k to bound analyzer runtime).
    rows_objs = await _fetch_rows(
        db, dataset_id=request.dataset_id, hunt_id=request.hunt_id, limit=10000,
    )
    if not rows_objs:
        raise HTTPException(status_code=404, detail="No rows found")
    # Prefer the normalized form of each row, falling back to the raw data.
    rows = [r.normalized_data or r.data for r in rows_objs]
    # Run analyzers; request.analyzers=None means "run all".
    candidates = await run_all_analyzers(rows, enabled=request.analyzers, config=request.config)
    created_alerts: list[dict] = []
    if request.auto_create and candidates:
        for c in candidates:
            alert = Alert(
                id=_new_id(),
                title=c.title,
                description=c.description,
                severity=c.severity,
                analyzer=c.analyzer,
                score=c.score,
                evidence=c.evidence,
                mitre_technique=c.mitre_technique,
                tags=c.tags,
                hunt_id=request.hunt_id,
                dataset_id=request.dataset_id,
            )
            db.add(alert)
            # NOTE(review): serialized before commit — DB-side defaults such as
            # created_at/updated_at will be None in the response; confirm intended.
            created_alerts.append(_alert_to_dict(alert))
        await db.commit()
    return {
        "candidates_found": len(candidates),
        "alerts_created": len(created_alerts),
        "alerts": created_alerts,
        "summary": {
            "by_severity": _count_by(candidates, "severity"),
            "by_analyzer": _count_by(candidates, "analyzer"),
            "rows_analyzed": len(rows),
        },
    }
def _count_by(items: list[AlertCandidate], attr: str) -> dict[str, int]:
    """Tally how many candidates share each value of attribute *attr*.

    Candidates missing the attribute are bucketed under "unknown".
    """
    tally: dict[str, int] = {}
    for candidate in items:
        bucket = getattr(candidate, attr, "unknown")
        tally[bucket] = tally.get(bucket, 0) + 1
    return tally
# ── Alert Rules CRUD ─────────────────────────────────────────────────
@router.get("/rules/list", summary="List alert rules")
async def list_rules(
    enabled: bool | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """List alert rules oldest-first, optionally filtered by enabled flag."""
    query = select(AlertRule)
    if enabled is not None:
        query = query.where(AlertRule.enabled == enabled)
    rows = (await db.execute(query.order_by(AlertRule.created_at))).scalars().all()
    return {"rules": [_rule_to_dict(row) for row in rows]}
@router.post("/rules", summary="Create alert rule")
async def create_rule(body: RuleCreate, db: AsyncSession = Depends(get_db)):
    """Create an alert rule; rejects analyzers that are not registered."""
    if not get_analyzer(body.analyzer):
        raise HTTPException(status_code=400, detail=f"Unknown analyzer: {body.analyzer}")
    new_rule = AlertRule(
        id=_new_id(),
        name=body.name,
        description=body.description,
        analyzer=body.analyzer,
        config=body.config,
        severity_override=body.severity_override,
        enabled=body.enabled,
        hunt_id=body.hunt_id,
    )
    db.add(new_rule)
    await db.commit()
    await db.refresh(new_rule)
    return _rule_to_dict(new_rule)
@router.put("/rules/{rule_id}", summary="Update alert rule")
async def update_rule(
    rule_id: str, body: RuleUpdate, db: AsyncSession = Depends(get_db)
):
    """Apply a partial update to an alert rule; only non-None fields change."""
    rule = await db.get(AlertRule, rule_id)
    if rule is None:
        raise HTTPException(status_code=404, detail="Rule not found")
    # Copy over every field the caller actually supplied.
    for attr in ("name", "description", "config", "severity_override", "enabled"):
        value = getattr(body, attr)
        if value is not None:
            setattr(rule, attr, value)
    await db.commit()
    await db.refresh(rule)
    return _rule_to_dict(rule)
@router.delete("/rules/{rule_id}", summary="Delete alert rule")
async def delete_rule(rule_id: str, db: AsyncSession = Depends(get_db)):
    """Permanently remove an alert rule; 404 when it does not exist."""
    rule = await db.get(AlertRule, rule_id)
    if rule is None:
        raise HTTPException(status_code=404, detail="Rule not found")
    await db.delete(rule)
    await db.commit()
    return {"ok": True}

View File

@@ -0,0 +1,336 @@
"""API routes for process trees, storyline graphs, risk scoring, LLM analysis, timeline, and field stats."""
import logging
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.repositories.datasets import DatasetRepository
from app.services.process_tree import (
build_process_tree,
build_storyline,
compute_risk_scores,
_fetch_rows,
)
from app.services.llm_analysis import (
AnalysisRequest,
AnalysisResult,
run_llm_analysis,
)
from app.services.timeline import (
build_timeline_bins,
compute_field_stats,
search_rows,
)
from app.services.mitre import (
map_to_attack,
build_knowledge_graph,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/analysis", tags=["analysis"])
# ── Response models ───────────────────────────────────────────────────
class ProcessTreeResponse(BaseModel):
    """Forest of process-tree roots plus a total node count."""
    trees: list[dict] = Field(default_factory=list)
    total_processes: int = 0
class StorylineResponse(BaseModel):
    """Cytoscape-compatible storyline graph: nodes, edges, and a summary dict."""
    nodes: list[dict] = Field(default_factory=list)
    edges: list[dict] = Field(default_factory=list)
    summary: dict = Field(default_factory=dict)
class RiskHostEntry(BaseModel):
    """Per-host risk score with the signals and event counts behind it."""
    hostname: str
    score: int = 0
    signals: list[str] = Field(default_factory=list)
    event_count: int = 0
    process_count: int = 0
    network_count: int = 0
    file_count: int = 0
class RiskSummaryResponse(BaseModel):
    """Aggregate risk picture: per-host entries plus overall score and totals."""
    hosts: list[RiskHostEntry] = Field(default_factory=list)
    overall_score: int = 0
    total_events: int = 0
    severity_breakdown: dict[str, int] = Field(default_factory=dict)
# ── Routes ────────────────────────────────────────────────────────────
@router.get(
    "/process-tree",
    response_model=ProcessTreeResponse,
    summary="Build process tree from dataset rows",
    description=(
        "Extracts parent→child process relationships from dataset rows "
        "and returns a hierarchical forest of process nodes."
    ),
)
async def get_process_tree(
    dataset_id: str | None = Query(None, description="Dataset ID"),
    hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
    hostname: str | None = Query(None, description="Filter by hostname"),
    db: AsyncSession = Depends(get_db),
):
    """Return process tree(s) for a dataset or hunt, plus a total node count."""
    if not dataset_id and not hunt_id:
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    trees = await build_process_tree(
        db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
    )
    # Count every node in the forest with an explicit stack (no recursion).
    total = 0
    pending = list(trees)
    while pending:
        node = pending.pop()
        total += 1
        pending.extend(node.get("children", []))
    return ProcessTreeResponse(trees=trees, total_processes=total)
@router.get(
    "/storyline",
    response_model=StorylineResponse,
    summary="Build CrowdStrike-style storyline attack graph",
    description=(
        "Creates a Cytoscape-compatible graph of events connected by "
        "process lineage (spawned) and temporal sequence within each host."
    ),
)
async def get_storyline(
    dataset_id: str | None = Query(None, description="Dataset ID"),
    hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
    hostname: str | None = Query(None, description="Filter by hostname"),
    db: AsyncSession = Depends(get_db),
):
    """Return the storyline graph for a dataset or hunt."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    graph = await build_storyline(
        db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
    )
    return StorylineResponse(**graph)
@router.get(
    "/risk-summary",
    response_model=RiskSummaryResponse,
    summary="Compute risk scores per host",
    description=(
        "Analyzes dataset rows for suspicious patterns (encoded PowerShell, "
        "credential dumping, lateral movement) and produces per-host risk scores."
    ),
)
async def get_risk_summary(
    hunt_id: str | None = Query(None, description="Hunt ID"),
    db: AsyncSession = Depends(get_db),
):
    """Return per-host risk scores for a hunt."""
    summary = await compute_risk_scores(db, hunt_id=hunt_id)
    return RiskSummaryResponse(**summary)
# ── LLM Analysis ─────────────────────────────────────────────────────
@router.post(
    "/llm-analyze",
    response_model=AnalysisResult,
    summary="Run LLM-powered threat analysis on dataset",
    description=(
        "Loads dataset rows server-side, builds a summary, and sends to "
        "Wile (deep analysis) or Roadrunner (quick) for comprehensive "
        "threat analysis. Returns structured findings, IOCs, MITRE techniques."
    ),
)
async def llm_analyze(
    request: AnalysisRequest,
    db: AsyncSession = Depends(get_db),
):
    """Run LLM threat analysis over rows loaded from a dataset or hunt."""
    if not (request.dataset_id or request.hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    # Pull up to 2000 rows server-side to keep the prompt bounded.
    row_records = await _fetch_rows(
        db,
        dataset_id=request.dataset_id,
        hunt_id=request.hunt_id,
        limit=2000,
    )
    if not row_records:
        raise HTTPException(status_code=404, detail="No rows found for analysis")
    # Prefer the normalized form of each row, falling back to the raw data.
    payload = [rec.normalized_data or rec.data for rec in row_records]
    # Resolve a human-readable dataset name for the analysis, when available.
    dataset_label = "hunt datasets"
    if request.dataset_id:
        ds = await DatasetRepository(db).get_dataset(request.dataset_id)
        if ds:
            dataset_label = ds.name
    return await run_llm_analysis(payload, request, dataset_name=dataset_label)
# ── Timeline ──────────────────────────────────────────────────────────
@router.get(
    "/timeline",
    summary="Get event timeline histogram bins",
)
async def get_timeline(
    dataset_id: str | None = Query(None),
    hunt_id: str | None = Query(None),
    bins: int = Query(60, ge=10, le=200),
    db: AsyncSession = Depends(get_db),
):
    """Return histogram bins of event counts over time for a dataset or hunt."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    return await build_timeline_bins(db, dataset_id=dataset_id, hunt_id=hunt_id, bins=bins)
@router.get(
    "/field-stats",
    summary="Get per-field value distributions",
)
async def get_field_stats(
    dataset_id: str | None = Query(None),
    hunt_id: str | None = Query(None),
    fields: str | None = Query(None, description="Comma-separated field names"),
    top_n: int = Query(20, ge=5, le=100),
    db: AsyncSession = Depends(get_db),
):
    """Return top-N value distributions for the requested fields."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    # Empty/absent `fields` means "let the service pick" (passed as None).
    requested = None
    if fields:
        requested = [name.strip() for name in fields.split(",")]
    return await compute_field_stats(
        db, dataset_id=dataset_id, hunt_id=hunt_id,
        fields=requested, top_n=top_n,
    )
class SearchRequest(BaseModel):
    """Body for row search: free-text query, exact-match filters, time window, paging."""
    dataset_id: Optional[str] = None
    hunt_id: Optional[str] = None
    query: str = ""
    filters: dict[str, str] = Field(default_factory=dict)
    time_start: Optional[str] = None
    time_end: Optional[str] = None
    limit: int = 500
    offset: int = 0
@router.post(
    "/search",
    summary="Search and filter dataset rows",
)
async def search_dataset_rows(
    request: SearchRequest,
    db: AsyncSession = Depends(get_db),
):
    """Search rows of a dataset or hunt with text query, filters, and time window."""
    if not (request.dataset_id or request.hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    return await search_rows(
        db,
        dataset_id=request.dataset_id,
        hunt_id=request.hunt_id,
        query=request.query,
        filters=request.filters,
        time_start=request.time_start,
        time_end=request.time_end,
        limit=request.limit,
        offset=request.offset,
    )
# ── MITRE ATT&CK ─────────────────────────────────────────────────────
@router.get(
    "/mitre-map",
    summary="Map dataset events to MITRE ATT&CK techniques",
)
async def get_mitre_map(
    dataset_id: str | None = Query(None),
    hunt_id: str | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Map events of a dataset or hunt onto MITRE ATT&CK techniques."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    return await map_to_attack(db, dataset_id=dataset_id, hunt_id=hunt_id)
@router.get(
    "/knowledge-graph",
    summary="Build entity-technique knowledge graph",
)
async def get_knowledge_graph(
    dataset_id: str | None = Query(None),
    hunt_id: str | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Build an entity↔technique knowledge graph for a dataset or hunt.

    Resolves an unresolved merge conflict left in the body: the HEAD side was
    a job-queue submission handler (referencing undefined ``job_type``/
    ``params`` and unrelated load-balancer endpoints) that does not match this
    route's signature; the incoming side is the correct implementation and is
    kept. The stray load-balancer endpoints belong in the jobs router.
    """
    if not dataset_id and not hunt_id:
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    return await build_knowledge_graph(db, dataset_id=dataset_id, hunt_id=hunt_id)

View File

@@ -1,4 +1,4 @@
"""API routes for authentication register, login, refresh, profile.""" """API routes for authentication — register, login, refresh, profile."""
import logging import logging
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/auth", tags=["auth"]) router = APIRouter(prefix="/api/auth", tags=["auth"])
# ── Request / Response models ───────────────────────────────────────── # ── Request / Response models ─────────────────────────────────────────
class RegisterRequest(BaseModel): class RegisterRequest(BaseModel):
@@ -57,7 +57,7 @@ class AuthResponse(BaseModel):
tokens: TokenPair tokens: TokenPair
# ── Routes ──────────────────────────────────────────────────────────── # ── Routes ────────────────────────────────────────────────────────────
@router.post( @router.post(
@@ -86,7 +86,7 @@ async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
user = User( user = User(
username=body.username, username=body.username,
email=body.email, email=body.email,
password_hash=hash_password(body.password), hashed_password=hash_password(body.password),
display_name=body.display_name or body.username, display_name=body.display_name or body.username,
role="analyst", # Default role role="analyst", # Default role
) )
@@ -120,13 +120,13 @@ async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(User).where(User.username == body.username)) result = await db.execute(select(User).where(User.username == body.username))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user or not user.password_hash: if not user or not user.hashed_password:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password", detail="Invalid username or password",
) )
if not verify_password(body.password, user.password_hash): if not verify_password(body.password, user.hashed_password):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password", detail="Invalid username or password",
@@ -165,7 +165,7 @@ async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db)
if token_data.type != "refresh": if token_data.type != "refresh":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type use refresh token", detail="Invalid token type — use refresh token",
) )
result = await db.execute(select(User).where(User.id == token_data.sub)) result = await db.execute(select(User).where(User.id == token_data.sub))
@@ -195,3 +195,4 @@ async def get_profile(user: User = Depends(get_current_user)):
is_active=user.is_active, is_active=user.is_active,
created_at=user.created_at.isoformat() if hasattr(user.created_at, 'isoformat') else str(user.created_at), created_at=user.created_at.isoformat() if hasattr(user.created_at, 'isoformat') else str(user.created_at),
) )

View File

@@ -0,0 +1,296 @@
"""API routes for case management — CRUD for cases, tasks, and activity logs."""
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Case, CaseTask, ActivityLog, _new_id, _utcnow
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/cases", tags=["cases"])
# ── Pydantic models ──────────────────────────────────────────────────
class CaseCreate(BaseModel):
    """Payload for opening a new case (TLP/PAP markings, priority, IOCs)."""
    title: str
    description: Optional[str] = None
    severity: str = "medium"
    tlp: str = "amber"
    pap: str = "amber"
    priority: int = 2
    assignee: Optional[str] = None
    tags: Optional[list[str]] = None
    hunt_id: Optional[str] = None
    mitre_techniques: Optional[list[str]] = None
    iocs: Optional[list[dict]] = None
class CaseUpdate(BaseModel):
    """Partial-update payload for a case; only non-None fields are applied."""
    title: Optional[str] = None
    description: Optional[str] = None
    severity: Optional[str] = None
    tlp: Optional[str] = None
    pap: Optional[str] = None
    status: Optional[str] = None
    priority: Optional[int] = None
    assignee: Optional[str] = None
    tags: Optional[list[str]] = None
    mitre_techniques: Optional[list[str]] = None
    iocs: Optional[list[dict]] = None
class TaskCreate(BaseModel):
    """Payload for adding a task to a case."""
    title: str
    description: Optional[str] = None
    assignee: Optional[str] = None
class TaskUpdate(BaseModel):
    """Partial-update payload for a case task; only non-None fields are applied."""
    title: Optional[str] = None
    description: Optional[str] = None
    status: Optional[str] = None
    assignee: Optional[str] = None
    order: Optional[int] = None
# ── Helper: log activity ─────────────────────────────────────────────
async def _log_activity(
    db: AsyncSession,
    entity_type: str,
    entity_id: str,
    action: str,
    details: dict | None = None,
):
    """Queue an ActivityLog row on the session; the caller is responsible for commit."""
    db.add(
        ActivityLog(
            entity_type=entity_type,
            entity_id=entity_id,
            action=action,
            details=details,
            created_at=_utcnow(),
        )
    )
# ── Case CRUD ─────────────────────────────────────────────────────────
@router.post("", summary="Create a case")
async def create_case(body: CaseCreate, db: AsyncSession = Depends(get_db)):
    """Open a new case and record a 'created' activity entry."""
    ts = _utcnow()
    new_case = Case(
        id=_new_id(),
        title=body.title,
        description=body.description,
        severity=body.severity,
        tlp=body.tlp,
        pap=body.pap,
        priority=body.priority,
        assignee=body.assignee,
        tags=body.tags,
        hunt_id=body.hunt_id,
        mitre_techniques=body.mitre_techniques,
        iocs=body.iocs,
        created_at=ts,
        updated_at=ts,
    )
    db.add(new_case)
    await _log_activity(db, "case", new_case.id, "created", {"title": body.title})
    await db.commit()
    await db.refresh(new_case)
    return _case_to_dict(new_case)
@router.get("", summary="List cases")
async def list_cases(
    status: Optional[str] = Query(None),
    hunt_id: Optional[str] = Query(None),
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List cases newest-updated first, with optional status/hunt filters.

    Returns ``{"cases": [...], "total": <count matching the filters>}``.
    """
    # Build the filter conditions ONCE so the page query and the count query
    # can never disagree — the original duplicated each `if` for both queries.
    conditions = []
    if status:
        conditions.append(Case.status == status)
    if hunt_id:
        conditions.append(Case.hunt_id == hunt_id)

    q = (
        select(Case)
        .where(*conditions)
        .order_by(desc(Case.updated_at))
        .offset(offset)
        .limit(limit)
    )
    cases = (await db.execute(q)).scalars().all()

    count_q = select(func.count(Case.id)).where(*conditions)
    total = (await db.execute(count_q)).scalar() or 0
    return {"cases": [_case_to_dict(c) for c in cases], "total": total}
@router.get("/{case_id}", summary="Get case detail")
async def get_case(case_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one case by primary key; 404 when it does not exist."""
    record = await db.get(Case, case_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Case not found")
    return _case_to_dict(record)
@router.put("/{case_id}", summary="Update a case")
async def update_case(case_id: str, body: CaseUpdate, db: AsyncSession = Depends(get_db)):
    """Apply a partial update to a case, stamping lifecycle timestamps and
    recording an old→new change map in the activity log."""
    case = await db.get(Case, case_id)
    if not case:
        raise HTTPException(status_code=404, detail="Case not found")
    # Collect {field: {"old": ..., "new": ...}} for the activity log entry.
    changes = {}
    for field in ["title", "description", "severity", "tlp", "pap", "status",
                  "priority", "assignee", "tags", "mitre_techniques", "iocs"]:
        val = getattr(body, field)
        if val is not None:
            old = getattr(case, field)
            setattr(case, field, val)
            changes[field] = {"old": old, "new": val}
    # started_at is stamped only on the first move to in-progress.
    if "status" in changes and changes["status"]["new"] == "in-progress" and not case.started_at:
        case.started_at = _utcnow()
    # NOTE(review): resolved_at is restamped on every resolved/closed update
    # (unlike the alerts router, which guards with `not resolved_at`) — confirm
    # whether overwriting an earlier resolution time is intended.
    if "status" in changes and changes["status"]["new"] in ("resolved", "closed"):
        case.resolved_at = _utcnow()
    case.updated_at = _utcnow()
    await _log_activity(db, "case", case.id, "updated", changes)
    await db.commit()
    await db.refresh(case)
    return _case_to_dict(case)
@router.delete("/{case_id}", summary="Delete a case")
async def delete_case(case_id: str, db: AsyncSession = Depends(get_db)):
    """Permanently remove a case; 404 when it does not exist."""
    record = await db.get(Case, case_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Case not found")
    await db.delete(record)
    await db.commit()
    return {"deleted": True}
# ── Task CRUD ─────────────────────────────────────────────────────────
@router.post("/{case_id}/tasks", summary="Add task to case")
async def create_task(case_id: str, body: TaskCreate, db: AsyncSession = Depends(get_db)):
    """Attach a new task to an existing case and log the creation."""
    parent = await db.get(Case, case_id)
    if parent is None:
        raise HTTPException(status_code=404, detail="Case not found")
    ts = _utcnow()
    new_task = CaseTask(
        id=_new_id(),
        case_id=case_id,
        title=body.title,
        description=body.description,
        assignee=body.assignee,
        created_at=ts,
        updated_at=ts,
    )
    db.add(new_task)
    await _log_activity(db, "case", case_id, "task_created", {"title": body.title})
    await db.commit()
    await db.refresh(new_task)
    return _task_to_dict(new_task)
@router.put("/{case_id}/tasks/{task_id}", summary="Update a task")
async def update_task(case_id: str, task_id: str, body: TaskUpdate, db: AsyncSession = Depends(get_db)):
    """Apply a partial update to a task; 404 unless the task belongs to the case."""
    task = await db.get(CaseTask, task_id)
    if task is None or task.case_id != case_id:
        raise HTTPException(status_code=404, detail="Task not found")
    for attr in ("title", "description", "status", "assignee", "order"):
        if (value := getattr(body, attr)) is not None:
            setattr(task, attr, value)
    task.updated_at = _utcnow()
    await _log_activity(db, "case", case_id, "task_updated", {"task_id": task_id})
    await db.commit()
    await db.refresh(task)
    return _task_to_dict(task)
@router.delete("/{case_id}/tasks/{task_id}", summary="Delete a task")
async def delete_task(case_id: str, task_id: str, db: AsyncSession = Depends(get_db)):
    """Remove a task; 404 unless the task belongs to the case."""
    task = await db.get(CaseTask, task_id)
    if task is None or task.case_id != case_id:
        raise HTTPException(status_code=404, detail="Task not found")
    await db.delete(task)
    await db.commit()
    return {"deleted": True}
# ── Activity Log ──────────────────────────────────────────────────────
@router.get("/{case_id}/activity", summary="Get case activity log")
async def get_activity(
    case_id: str,
    limit: int = Query(50, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
):
    """Return the most recent activity-log entries for a case, newest first.

    (Renames the ambiguous single-letter loop variable ``l`` — PEP 8 / E741.)
    """
    q = (
        select(ActivityLog)
        .where(ActivityLog.entity_type == "case", ActivityLog.entity_id == case_id)
        .order_by(desc(ActivityLog.created_at))
        .limit(limit)
    )
    result = await db.execute(q)
    entries = result.scalars().all()
    return {
        "logs": [
            {
                "id": entry.id,
                "action": entry.action,
                "details": entry.details,
                "user_id": entry.user_id,
                "created_at": entry.created_at.isoformat() if entry.created_at else None,
            }
            for entry in entries
        ]
    }
# ── Helpers ───────────────────────────────────────────────────────────
def _case_to_dict(c: Case) -> dict:
    """Serialize a Case ORM row (including its tasks) into a JSON-friendly dict."""
    def _iso(ts):
        # Timestamp columns may be NULL; emit None rather than raising.
        return ts.isoformat() if ts else None

    return {
        "id": c.id,
        "title": c.title,
        "description": c.description,
        "severity": c.severity,
        "tlp": c.tlp,
        "pap": c.pap,
        "status": c.status,
        "priority": c.priority,
        "assignee": c.assignee,
        "tags": c.tags or [],
        "hunt_id": c.hunt_id,
        "owner_id": c.owner_id,
        "mitre_techniques": c.mitre_techniques or [],
        "iocs": c.iocs or [],
        "started_at": _iso(c.started_at),
        "resolved_at": _iso(c.resolved_at),
        "created_at": _iso(c.created_at),
        "updated_at": _iso(c.updated_at),
        "tasks": [_task_to_dict(t) for t in (c.tasks or [])],
    }
def _task_to_dict(t: CaseTask) -> dict:
    """Serialize a CaseTask ORM row into a JSON-friendly dict."""
    def _iso(ts):
        # Timestamp columns may be NULL; emit None rather than raising.
        return ts.isoformat() if ts else None

    return {
        "id": t.id,
        "case_id": t.case_id,
        "title": t.title,
        "description": t.description,
        "status": t.status,
        "assignee": t.assignee,
        "order": t.order,
        "created_at": _iso(t.created_at),
        "updated_at": _iso(t.updated_at),
    }

View File

@@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings from app.config import settings
from app.db import get_db from app.db import get_db
from app.db.models import ProcessingTask
from app.db.repositories.datasets import DatasetRepository from app.db.repositories.datasets import DatasetRepository
from app.services.csv_parser import parse_csv_bytes, infer_column_types from app.services.csv_parser import parse_csv_bytes, infer_column_types
from app.services.normalizer import ( from app.services.normalizer import (
@@ -18,15 +19,20 @@ from app.services.normalizer import (
detect_ioc_columns, detect_ioc_columns,
detect_time_range, detect_time_range,
) )
from app.services.artifact_classifier import classify_artifact, get_artifact_category
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from app.services.job_queue import job_queue, JobType
from app.services.host_inventory import inventory_cache
from app.services.scanner import keyword_scan_cache
router = APIRouter(prefix="/api/datasets", tags=["datasets"]) router = APIRouter(prefix="/api/datasets", tags=["datasets"])
ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"} ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"}
# ── Response models ─────────────────────────────────────────────────── # -- Response models --
class DatasetSummary(BaseModel): class DatasetSummary(BaseModel):
@@ -43,6 +49,8 @@ class DatasetSummary(BaseModel):
delimiter: str | None = None delimiter: str | None = None
time_range_start: str | None = None time_range_start: str | None = None
time_range_end: str | None = None time_range_end: str | None = None
artifact_type: str | None = None
processing_status: str | None = None
hunt_id: str | None = None hunt_id: str | None = None
created_at: str created_at: str
@@ -67,10 +75,13 @@ class UploadResponse(BaseModel):
column_types: dict column_types: dict
normalized_columns: dict normalized_columns: dict
ioc_columns: dict ioc_columns: dict
artifact_type: str | None = None
processing_status: str
jobs_queued: list[str]
message: str message: str
# ── Routes ──────────────────────────────────────────────────────────── # -- Routes --
@router.post( @router.post(
@@ -78,7 +89,7 @@ class UploadResponse(BaseModel):
response_model=UploadResponse, response_model=UploadResponse,
summary="Upload a CSV dataset", summary="Upload a CSV dataset",
description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, " description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, "
"IOCs auto-detected, and rows stored in the database.", "IOCs auto-detected, artifact type classified, and all processing jobs queued automatically.",
) )
async def upload_dataset( async def upload_dataset(
file: UploadFile = File(...), file: UploadFile = File(...),
@@ -87,7 +98,7 @@ async def upload_dataset(
hunt_id: str | None = Query(None, description="Hunt ID to associate with"), hunt_id: str | None = Query(None, description="Hunt ID to associate with"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""Upload and parse a CSV dataset.""" """Upload and parse a CSV dataset, then trigger full processing pipeline."""
# Validate file # Validate file
if not file.filename: if not file.filename:
raise HTTPException(status_code=400, detail="No filename provided") raise HTTPException(status_code=400, detail="No filename provided")
@@ -136,7 +147,12 @@ async def upload_dataset(
# Detect time range # Detect time range
time_start, time_end = detect_time_range(rows, column_mapping) time_start, time_end = detect_time_range(rows, column_mapping)
# Store in DB # Classify artifact type from column headers
artifact_type = classify_artifact(columns)
artifact_category = get_artifact_category(artifact_type)
logger.info(f"Artifact classification: {artifact_type} (category: {artifact_category})")
# Store in DB with processing_status = "processing"
repo = DatasetRepository(db) repo = DatasetRepository(db)
dataset = await repo.create_dataset( dataset = await repo.create_dataset(
name=name or Path(file.filename).stem, name=name or Path(file.filename).stem,
@@ -152,6 +168,8 @@ async def upload_dataset(
time_range_start=time_start, time_range_start=time_start,
time_range_end=time_end, time_range_end=time_end,
hunt_id=hunt_id, hunt_id=hunt_id,
artifact_type=artifact_type,
processing_status="processing",
) )
await repo.bulk_insert_rows( await repo.bulk_insert_rows(
@@ -162,9 +180,88 @@ async def upload_dataset(
logger.info( logger.info(
f"Uploaded dataset '{dataset.name}': {len(rows)} rows, " f"Uploaded dataset '{dataset.name}': {len(rows)} rows, "
f"{len(columns)} columns, {len(ioc_columns)} IOC columns detected" f"{len(columns)} columns, {len(ioc_columns)} IOC columns, "
f"artifact={artifact_type}"
) )
# -- Queue full processing pipeline --
jobs_queued = []
task_rows: list[ProcessingTask] = []
# 1. AI Triage (chains to HOST_PROFILE automatically on completion)
triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
jobs_queued.append("triage")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=triage_job.id,
stage="triage",
status="queued",
progress=0.0,
message="Queued",
))
# 2. Anomaly detection (embedding-based outlier detection)
anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
jobs_queued.append("anomaly")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=anomaly_job.id,
stage="anomaly",
status="queued",
progress=0.0,
message="Queued",
))
# 3. AUP keyword scan
kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
jobs_queued.append("keyword_scan")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=kw_job.id,
stage="keyword_scan",
status="queued",
progress=0.0,
message="Queued",
))
# 4. IOC extraction
ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
jobs_queued.append("ioc_extract")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=ioc_job.id,
stage="ioc_extract",
status="queued",
progress=0.0,
message="Queued",
))
# 5. Host inventory (network map) - requires hunt_id
if hunt_id:
inventory_cache.invalidate(hunt_id)
inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
jobs_queued.append("host_inventory")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=inv_job.id,
stage="host_inventory",
status="queued",
progress=0.0,
message="Queued",
))
if task_rows:
db.add_all(task_rows)
await db.flush()
logger.info(f"Queued {len(jobs_queued)} processing jobs for dataset {dataset.id}: {jobs_queued}")
return UploadResponse( return UploadResponse(
id=dataset.id, id=dataset.id,
name=dataset.name, name=dataset.name,
@@ -173,7 +270,10 @@ async def upload_dataset(
column_types=column_types, column_types=column_types,
normalized_columns=column_mapping, normalized_columns=column_mapping,
ioc_columns=ioc_columns, ioc_columns=ioc_columns,
message=f"Successfully uploaded {len(rows)} rows with {len(ioc_columns)} IOC columns detected", artifact_type=artifact_type,
processing_status="processing",
jobs_queued=jobs_queued,
message=f"Successfully uploaded {len(rows)} rows. {len(jobs_queued)} processing jobs queued.",
) )
@@ -208,6 +308,8 @@ async def list_datasets(
delimiter=ds.delimiter, delimiter=ds.delimiter,
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None, time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None, time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
artifact_type=ds.artifact_type,
processing_status=ds.processing_status,
hunt_id=ds.hunt_id, hunt_id=ds.hunt_id,
created_at=ds.created_at.isoformat(), created_at=ds.created_at.isoformat(),
) )
@@ -244,6 +346,8 @@ async def get_dataset(
delimiter=ds.delimiter, delimiter=ds.delimiter,
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None, time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None, time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
artifact_type=ds.artifact_type,
processing_status=ds.processing_status,
hunt_id=ds.hunt_id, hunt_id=ds.hunt_id,
created_at=ds.created_at.isoformat(), created_at=ds.created_at.isoformat(),
) )
@@ -292,4 +396,32 @@ async def delete_dataset(
deleted = await repo.delete_dataset(dataset_id) deleted = await repo.delete_dataset(dataset_id)
if not deleted: if not deleted:
raise HTTPException(status_code=404, detail="Dataset not found") raise HTTPException(status_code=404, detail="Dataset not found")
keyword_scan_cache.invalidate_dataset(dataset_id)
return {"message": "Dataset deleted", "id": dataset_id} return {"message": "Dataset deleted", "id": dataset_id}
@router.post(
    "/rescan-ioc",
    summary="Re-scan IOC columns for all datasets",
)
async def rescan_ioc_columns(
    db: AsyncSession = Depends(get_db),
):
    """Re-run detect_ioc_columns on every dataset using current detection logic."""
    repo = DatasetRepository(db)
    all_ds = await repo.list_datasets(limit=10000)
    updated = 0
    for ds in all_ds:
        schema = ds.column_schema or {}
        # Datasets with no recorded column schema have nothing to re-detect.
        if not schema:
            continue
        fresh = detect_ioc_columns(
            list(schema.keys()),
            schema,
            ds.normalized_columns or {},
        )
        # Only touch rows whose detection result actually changed.
        if fresh != (ds.ioc_columns or {}):
            ds.ioc_columns = fresh
            updated += 1
    await db.commit()
    return {"message": f"Rescanned {len(all_ds)} datasets, updated {updated}"}

View File

@@ -8,16 +8,15 @@ from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db from app.db import get_db
from app.db.models import Hunt, Conversation, Message from app.db.models import Hunt, Dataset, ProcessingTask
from app.services.job_queue import job_queue
from app.services.host_inventory import inventory_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/hunts", tags=["hunts"]) router = APIRouter(prefix="/api/hunts", tags=["hunts"])
# ── Models ────────────────────────────────────────────────────────────
class HuntCreate(BaseModel): class HuntCreate(BaseModel):
name: str = Field(..., max_length=256) name: str = Field(..., max_length=256)
description: str | None = None description: str | None = None
@@ -26,7 +25,7 @@ class HuntCreate(BaseModel):
class HuntUpdate(BaseModel): class HuntUpdate(BaseModel):
name: str | None = None name: str | None = None
description: str | None = None description: str | None = None
status: str | None = None # active | closed | archived status: str | None = None
class HuntResponse(BaseModel): class HuntResponse(BaseModel):
@@ -46,7 +45,18 @@ class HuntListResponse(BaseModel):
total: int total: int
# ── Routes ──────────────────────────────────────────────────────────── class HuntProgressResponse(BaseModel):
hunt_id: str
status: str
progress_percent: float
dataset_total: int
dataset_completed: int
dataset_processing: int
dataset_errors: int
active_jobs: int
queued_jobs: int
network_status: str
stages: dict
@router.post("", response_model=HuntResponse, summary="Create a new hunt") @router.post("", response_model=HuntResponse, summary="Create a new hunt")
@@ -122,6 +132,125 @@ async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
) )
@router.get("/{hunt_id}/progress", response_model=HuntProgressResponse, summary="Get hunt processing progress")
async def get_hunt_progress(hunt_id: str, db: AsyncSession = Depends(get_db)):
    """Aggregate a hunt's processing state into a single progress snapshot.

    Blends three signals into one percentage:
      * per-dataset ``processing_status`` values,
      * the in-memory job queue (NOTE(review): presumably lost on restart —
        the persisted ProcessingTask rows below look like the durable copy; confirm),
      * ``ProcessingTask`` rows grouped per pipeline stage,
    plus the host-inventory cache state for the network map.

    Raises:
        HTTPException: 404 if the hunt does not exist.
    """
    hunt = await db.get(Hunt, hunt_id)
    if not hunt:
        raise HTTPException(status_code=404, detail="Hunt not found")
    # --- Dataset-level progress: count terminal vs in-flight datasets ---
    ds_rows = await db.execute(
        select(Dataset.id, Dataset.processing_status)
        .where(Dataset.hunt_id == hunt_id)
    )
    datasets = ds_rows.all()
    dataset_ids = {row[0] for row in datasets}
    dataset_total = len(datasets)
    dataset_completed = sum(1 for _, st in datasets if st == "completed")
    dataset_errors = sum(1 for _, st in datasets if st == "completed_with_errors")
    # Anything not in a terminal state is still processing (clamped at 0).
    dataset_processing = max(0, dataset_total - dataset_completed - dataset_errors)
    # --- In-memory job queue: jobs tied to this hunt directly or via its datasets ---
    jobs = job_queue.list_jobs(limit=5000)
    relevant_jobs = [
        j for j in jobs
        if j.get("params", {}).get("hunt_id") == hunt_id
        or j.get("params", {}).get("dataset_id") in dataset_ids
    ]
    active_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "running")
    queued_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "queued")
    # --- Persisted task rows: per-stage status and fractional progress ---
    task_rows = await db.execute(
        select(ProcessingTask.stage, ProcessingTask.status, ProcessingTask.progress)
        .where(ProcessingTask.hunt_id == hunt_id)
    )
    tasks = task_rows.all()
    task_total = len(tasks)
    # "failed" and "cancelled" count as done: they will never progress further.
    task_done = sum(1 for _, st, _ in tasks if st in ("completed", "failed", "cancelled"))
    task_running = sum(1 for _, st, _ in tasks if st == "running")
    task_queued = sum(1 for _, st, _ in tasks if st == "queued")
    # None (not 0.0) when there are no task rows, so the weighting below can
    # fall back to the dataset-only formula.
    task_ratio = (task_done / task_total) if task_total > 0 else None
    # Take the max of the two sources: whichever view is fresher wins.
    active_jobs = max(active_jobs_mem, task_running)
    queued_jobs = max(queued_jobs_mem, task_queued)
    # --- Per-stage rollup (triage / anomaly / keyword_scan / ...) ---
    stage_rollup: dict[str, dict] = {}
    for stage, status, progress in tasks:
        bucket = stage_rollup.setdefault(stage, {"total": 0, "done": 0, "running": 0, "queued": 0, "progress_sum": 0.0})
        bucket["total"] += 1
        if status in ("completed", "failed", "cancelled"):
            bucket["done"] += 1
        elif status == "running":
            bucket["running"] += 1
        elif status == "queued":
            bucket["queued"] += 1
        bucket["progress_sum"] += float(progress or 0.0)
    for stage_name, bucket in stage_rollup.items():
        # Guard against division by zero; each bucket has total >= 1 anyway.
        total = max(1, bucket["total"])
        bucket["percent"] = round(bucket["progress_sum"] / total, 1)
    # --- Network map (host inventory) readiness ---
    if inventory_cache.get(hunt_id) is not None:
        network_status = "ready"
        network_ratio = 1.0
    elif inventory_cache.is_building(hunt_id):
        network_status = "building"
        network_ratio = 0.5  # in-flight build counts as half done
    else:
        network_status = "none"
        network_ratio = 0.0
    # No datasets means nothing to wait on: treat as fully done (ratio 1.0).
    dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
    # Weighted blend; weights redistribute when there are no task rows.
    if task_ratio is None:
        overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
    else:
        overall_ratio = min(1.0, (dataset_ratio * 0.50) + (task_ratio * 0.35) + (network_ratio * 0.15))
    progress_percent = round(overall_ratio * 100.0, 1)
    status = "ready"
    if dataset_total == 0:
        status = "idle"
    elif progress_percent < 100:
        status = "processing"
    # Detailed breakdown for the UI; mirrors the aggregates computed above.
    stages = {
        "datasets": {
            "total": dataset_total,
            "completed": dataset_completed,
            "processing": dataset_processing,
            "errors": dataset_errors,
            "percent": round(dataset_ratio * 100.0, 1),
        },
        "network": {
            "status": network_status,
            "percent": round(network_ratio * 100.0, 1),
        },
        "jobs": {
            "active": active_jobs,
            "queued": queued_jobs,
            "total_seen": len(relevant_jobs),
            "task_total": task_total,
            "task_done": task_done,
            "task_percent": round((task_ratio or 0.0) * 100.0, 1) if task_total else None,
        },
        "task_stages": stage_rollup,
    }
    return HuntProgressResponse(
        hunt_id=hunt_id,
        status=status,
        progress_percent=progress_percent,
        dataset_total=dataset_total,
        dataset_completed=dataset_completed,
        dataset_processing=dataset_processing,
        dataset_errors=dataset_errors,
        active_jobs=active_jobs,
        queued_jobs=queued_jobs,
        network_status=network_status,
        stages=stages,
    )
@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt") @router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
async def update_hunt( async def update_hunt(
hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db) hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)

View File

@@ -1,25 +1,21 @@
"""API routes for AUP keyword themes, keyword CRUD, and scanning.""" """API routes for AUP keyword themes, keyword CRUD, and scanning."""
import logging import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select, func, delete from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db from app.db import get_db
from app.db.models import KeywordTheme, Keyword from app.db.models import KeywordTheme, Keyword
from app.services.scanner import KeywordScanner from app.services.scanner import KeywordScanner, keyword_scan_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/keywords", tags=["keywords"]) router = APIRouter(prefix="/api/keywords", tags=["keywords"])
# ── Pydantic schemas ──────────────────────────────────────────────────
class ThemeCreate(BaseModel): class ThemeCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=128) name: str = Field(..., min_length=1, max_length=128)
color: str = Field(default="#9e9e9e", max_length=16) color: str = Field(default="#9e9e9e", max_length=16)
@@ -67,23 +63,27 @@ class KeywordBulkCreate(BaseModel):
class ScanRequest(BaseModel): class ScanRequest(BaseModel):
dataset_ids: list[str] | None = None # None → all datasets dataset_ids: list[str] | None = None
theme_ids: list[str] | None = None # None → all enabled themes theme_ids: list[str] | None = None
scan_hunts: bool = True scan_hunts: bool = False
scan_annotations: bool = True scan_annotations: bool = False
scan_messages: bool = True scan_messages: bool = False
prefer_cache: bool = True
force_rescan: bool = False
class ScanHit(BaseModel): class ScanHit(BaseModel):
theme_name: str theme_name: str
theme_color: str theme_color: str
keyword: str keyword: str
source_type: str # dataset_row | hunt | annotation | message source_type: str
source_id: str | int source_id: str | int
field: str field: str
matched_value: str matched_value: str
row_index: int | None = None row_index: int | None = None
dataset_name: str | None = None dataset_name: str | None = None
hostname: str | None = None
username: str | None = None
class ScanResponse(BaseModel): class ScanResponse(BaseModel):
@@ -92,9 +92,9 @@ class ScanResponse(BaseModel):
themes_scanned: int themes_scanned: int
keywords_scanned: int keywords_scanned: int
rows_scanned: int rows_scanned: int
cache_used: bool = False
cache_status: str = "miss"
# ── Helpers ─────────────────────────────────────────────────────────── cached_at: str | None = None
def _theme_to_out(t: KeywordTheme) -> ThemeOut: def _theme_to_out(t: KeywordTheme) -> ThemeOut:
@@ -119,49 +119,58 @@ def _theme_to_out(t: KeywordTheme) -> ThemeOut:
) )
# ── Theme CRUD ──────────────────────────────────────────────────────── def _merge_cached_results(entries: list[dict], allowed_theme_names: set[str] | None = None) -> dict:
hits: list[dict] = []
total_rows = 0
cached_at: str | None = None
for entry in entries:
result = entry["result"]
total_rows += int(result.get("rows_scanned", 0) or 0)
if entry.get("built_at"):
if not cached_at or entry["built_at"] > cached_at:
cached_at = entry["built_at"]
for h in result.get("hits", []):
if allowed_theme_names is not None and h.get("theme_name") not in allowed_theme_names:
continue
hits.append(h)
return {
"total_hits": len(hits),
"hits": hits,
"rows_scanned": total_rows,
"cached_at": cached_at,
}
@router.get("/themes", response_model=ThemeListResponse) @router.get("/themes", response_model=ThemeListResponse)
async def list_themes(db: AsyncSession = Depends(get_db)): async def list_themes(db: AsyncSession = Depends(get_db)):
"""List all keyword themes with their keywords.""" result = await db.execute(select(KeywordTheme).order_by(KeywordTheme.name))
result = await db.execute(
select(KeywordTheme).order_by(KeywordTheme.name)
)
themes = result.scalars().all() themes = result.scalars().all()
return ThemeListResponse( return ThemeListResponse(themes=[_theme_to_out(t) for t in themes], total=len(themes))
themes=[_theme_to_out(t) for t in themes],
total=len(themes),
)
@router.post("/themes", response_model=ThemeOut, status_code=201) @router.post("/themes", response_model=ThemeOut, status_code=201)
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)): async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
"""Create a new keyword theme.""" exists = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == body.name))
exists = await db.scalar(
select(KeywordTheme.id).where(KeywordTheme.name == body.name)
)
if exists: if exists:
raise HTTPException(409, f"Theme '{body.name}' already exists") raise HTTPException(409, f"Theme '{body.name}' already exists")
theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled) theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
db.add(theme) db.add(theme)
await db.flush() await db.flush()
await db.refresh(theme) await db.refresh(theme)
keyword_scan_cache.clear()
return _theme_to_out(theme) return _theme_to_out(theme)
@router.put("/themes/{theme_id}", response_model=ThemeOut) @router.put("/themes/{theme_id}", response_model=ThemeOut)
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)): async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
"""Update theme name, color, or enabled status."""
theme = await db.get(KeywordTheme, theme_id) theme = await db.get(KeywordTheme, theme_id)
if not theme: if not theme:
raise HTTPException(404, "Theme not found") raise HTTPException(404, "Theme not found")
if body.name is not None: if body.name is not None:
# check uniqueness
dup = await db.scalar( dup = await db.scalar(
select(KeywordTheme.id).where( select(KeywordTheme.id).where(KeywordTheme.name == body.name, KeywordTheme.id != theme_id)
KeywordTheme.name == body.name, KeywordTheme.id != theme_id
)
) )
if dup: if dup:
raise HTTPException(409, f"Theme '{body.name}' already exists") raise HTTPException(409, f"Theme '{body.name}' already exists")
@@ -172,24 +181,21 @@ async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depe
theme.enabled = body.enabled theme.enabled = body.enabled
await db.flush() await db.flush()
await db.refresh(theme) await db.refresh(theme)
keyword_scan_cache.clear()
return _theme_to_out(theme) return _theme_to_out(theme)
@router.delete("/themes/{theme_id}", status_code=204) @router.delete("/themes/{theme_id}", status_code=204)
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)): async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
"""Delete a theme and all its keywords."""
theme = await db.get(KeywordTheme, theme_id) theme = await db.get(KeywordTheme, theme_id)
if not theme: if not theme:
raise HTTPException(404, "Theme not found") raise HTTPException(404, "Theme not found")
await db.delete(theme) await db.delete(theme)
keyword_scan_cache.clear()
# ── Keyword CRUD ──────────────────────────────────────────────────────
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201) @router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)): async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
"""Add a single keyword to a theme."""
theme = await db.get(KeywordTheme, theme_id) theme = await db.get(KeywordTheme, theme_id)
if not theme: if not theme:
raise HTTPException(404, "Theme not found") raise HTTPException(404, "Theme not found")
@@ -197,6 +203,7 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
db.add(kw) db.add(kw)
await db.flush() await db.flush()
await db.refresh(kw) await db.refresh(kw)
keyword_scan_cache.clear()
return KeywordOut( return KeywordOut(
id=kw.id, theme_id=kw.theme_id, value=kw.value, id=kw.id, theme_id=kw.theme_id, value=kw.value,
is_regex=kw.is_regex, created_at=kw.created_at.isoformat(), is_regex=kw.is_regex, created_at=kw.created_at.isoformat(),
@@ -205,7 +212,6 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201) @router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)): async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
"""Add multiple keywords to a theme at once."""
theme = await db.get(KeywordTheme, theme_id) theme = await db.get(KeywordTheme, theme_id)
if not theme: if not theme:
raise HTTPException(404, "Theme not found") raise HTTPException(404, "Theme not found")
@@ -217,25 +223,88 @@ async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSes
db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex)) db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex))
added += 1 added += 1
await db.flush() await db.flush()
keyword_scan_cache.clear()
return {"added": added, "theme_id": theme_id} return {"added": added, "theme_id": theme_id}
@router.delete("/keywords/{keyword_id}", status_code=204) @router.delete("/keywords/{keyword_id}", status_code=204)
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)): async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
"""Delete a single keyword."""
kw = await db.get(Keyword, keyword_id) kw = await db.get(Keyword, keyword_id)
if not kw: if not kw:
raise HTTPException(404, "Keyword not found") raise HTTPException(404, "Keyword not found")
await db.delete(kw) await db.delete(kw)
keyword_scan_cache.clear()
# ── Scan endpoints ────────────────────────────────────────────────────
@router.post("/scan", response_model=ScanResponse) @router.post("/scan", response_model=ScanResponse)
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)): async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
"""Run AUP keyword scan across selected data sources."""
scanner = KeywordScanner(db) scanner = KeywordScanner(db)
if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
return {
"total_hits": 0,
"hits": [],
"themes_scanned": 0,
"keywords_scanned": 0,
"rows_scanned": 0,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}
can_use_cache = (
body.prefer_cache
and not body.force_rescan
and bool(body.dataset_ids)
and not body.scan_hunts
and not body.scan_annotations
and not body.scan_messages
)
if can_use_cache:
themes = await scanner._load_themes(body.theme_ids)
allowed_theme_names = {t.name for t in themes}
keywords_scanned = sum(len(theme.keywords) for theme in themes)
cached_entries: list[dict] = []
missing: list[str] = []
for dataset_id in (body.dataset_ids or []):
entry = keyword_scan_cache.get(dataset_id)
if not entry:
missing.append(dataset_id)
continue
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
if not missing and cached_entries:
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": True,
"cache_status": "hit",
"cached_at": merged["cached_at"],
}
if missing:
partial = await scanner.scan(dataset_ids=missing, theme_ids=body.theme_ids)
merged = _merge_cached_results(
cached_entries + [{"result": partial, "built_at": None}],
allowed_theme_names if body.theme_ids else None,
)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": len(cached_entries) > 0,
"cache_status": "partial" if cached_entries else "miss",
"cached_at": merged["cached_at"],
}
result = await scanner.scan( result = await scanner.scan(
dataset_ids=body.dataset_ids, dataset_ids=body.dataset_ids,
theme_ids=body.theme_ids, theme_ids=body.theme_ids,
@@ -243,7 +312,13 @@ async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
scan_annotations=body.scan_annotations, scan_annotations=body.scan_annotations,
scan_messages=body.scan_messages, scan_messages=body.scan_messages,
) )
return result
return {
**result,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}
@router.get("/scan/quick", response_model=ScanResponse) @router.get("/scan/quick", response_model=ScanResponse)
@@ -251,7 +326,22 @@ async def quick_scan(
dataset_id: str = Query(..., description="Dataset to scan"), dataset_id: str = Query(..., description="Dataset to scan"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""Quick scan a single dataset with all enabled themes.""" entry = keyword_scan_cache.get(dataset_id)
if entry is not None:
result = entry.result
return {
**result,
"cache_used": True,
"cache_status": "hit",
"cached_at": entry.built_at,
}
scanner = KeywordScanner(db) scanner = KeywordScanner(db)
result = await scanner.scan(dataset_ids=[dataset_id]) result = await scanner.scan(dataset_ids=[dataset_id])
return result keyword_scan_cache.put(dataset_id, result)
return {
**result,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}

View File

@@ -0,0 +1,146 @@
"""API routes for MITRE ATT&CK coverage visualization."""
import logging
from collections import defaultdict
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import (
TriageResult, HostProfile, Hypothesis, HuntReport, Dataset, Hunt
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/mitre", tags=["mitre"])
# Canonical MITRE ATT&CK tactics in kill-chain order.
# Used below to pre-seed the tactic buckets so the coverage matrix keeps a
# stable column order even for tactics with zero detections.
TACTICS = [
    "Reconnaissance", "Resource Development", "Initial Access",
    "Execution", "Persistence", "Privilege Escalation",
    "Defense Evasion", "Credential Access", "Discovery",
    "Lateral Movement", "Collection", "Command and Control",
    "Exfiltration", "Impact",
]
# Simplified technique-to-tactic mapping (top techniques)
TECHNIQUE_TACTIC: dict[str, str] = {
"T1059": "Execution", "T1059.001": "Execution", "T1059.003": "Execution",
"T1059.005": "Execution", "T1059.006": "Execution", "T1059.007": "Execution",
"T1053": "Persistence", "T1053.005": "Persistence",
"T1547": "Persistence", "T1547.001": "Persistence",
"T1543": "Persistence", "T1543.003": "Persistence",
"T1078": "Privilege Escalation", "T1078.001": "Privilege Escalation",
"T1078.002": "Privilege Escalation", "T1078.003": "Privilege Escalation",
"T1055": "Privilege Escalation", "T1055.001": "Privilege Escalation",
"T1548": "Privilege Escalation", "T1548.002": "Privilege Escalation",
"T1070": "Defense Evasion", "T1070.001": "Defense Evasion",
"T1070.004": "Defense Evasion",
"T1036": "Defense Evasion", "T1036.005": "Defense Evasion",
"T1027": "Defense Evasion", "T1140": "Defense Evasion",
"T1218": "Defense Evasion", "T1218.011": "Defense Evasion",
"T1003": "Credential Access", "T1003.001": "Credential Access",
"T1110": "Credential Access", "T1558": "Credential Access",
"T1087": "Discovery", "T1087.001": "Discovery", "T1087.002": "Discovery",
"T1082": "Discovery", "T1083": "Discovery", "T1057": "Discovery",
"T1018": "Discovery", "T1049": "Discovery", "T1016": "Discovery",
"T1021": "Lateral Movement", "T1021.001": "Lateral Movement",
"T1021.002": "Lateral Movement", "T1021.006": "Lateral Movement",
"T1570": "Lateral Movement",
"T1560": "Collection", "T1074": "Collection", "T1005": "Collection",
"T1071": "Command and Control", "T1071.001": "Command and Control",
"T1105": "Command and Control", "T1572": "Command and Control",
"T1095": "Command and Control",
"T1048": "Exfiltration", "T1041": "Exfiltration",
"T1486": "Impact", "T1490": "Impact", "T1489": "Impact",
"T1566": "Initial Access", "T1566.001": "Initial Access",
"T1566.002": "Initial Access",
"T1190": "Initial Access", "T1133": "Initial Access",
"T1195": "Initial Access", "T1195.002": "Initial Access",
}
def _get_tactic(technique_id: str) -> str:
"""Map a technique ID to its tactic."""
tech = technique_id.strip().upper()
if tech in TECHNIQUE_TACTIC:
return TECHNIQUE_TACTIC[tech]
# Try parent technique
if "." in tech:
parent = tech.split(".")[0]
if parent in TECHNIQUE_TACTIC:
return TECHNIQUE_TACTIC[parent]
return "Unknown"
@router.get("/coverage")
async def get_mitre_coverage(
hunt_id: str | None = None,
db: AsyncSession = Depends(get_db),
):
"""Aggregate all MITRE techniques from triage, host profiles, hypotheses, and reports."""
techniques: dict[str, dict] = {}
# Collect from triage results
triage_q = select(TriageResult)
if hunt_id:
triage_q = triage_q.join(Dataset).where(Dataset.hunt_id == hunt_id)
result = await db.execute(triage_q.limit(500))
for t in result.scalars().all():
for tech in (t.mitre_techniques or []):
if tech not in techniques:
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
techniques[tech]["count"] += 1
techniques[tech]["sources"].append({"type": "triage", "risk_score": t.risk_score})
# Collect from host profiles
profile_q = select(HostProfile)
if hunt_id:
profile_q = profile_q.where(HostProfile.hunt_id == hunt_id)
result = await db.execute(profile_q.limit(200))
for p in result.scalars().all():
for tech in (p.mitre_techniques or []):
if tech not in techniques:
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
techniques[tech]["count"] += 1
techniques[tech]["sources"].append({"type": "host_profile", "hostname": p.hostname})
# Collect from hypotheses
hyp_q = select(Hypothesis)
if hunt_id:
hyp_q = hyp_q.where(Hypothesis.hunt_id == hunt_id)
result = await db.execute(hyp_q.limit(200))
for h in result.scalars().all():
tech = h.mitre_technique
if tech:
if tech not in techniques:
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
techniques[tech]["count"] += 1
techniques[tech]["sources"].append({"type": "hypothesis", "title": h.title})
# Build tactic-grouped response
tactic_groups: dict[str, list] = {t: [] for t in TACTICS}
tactic_groups["Unknown"] = []
for tech in techniques.values():
tactic = tech["tactic"]
if tactic not in tactic_groups:
tactic_groups[tactic] = []
tactic_groups[tactic].append(tech)
total_techniques = len(techniques)
total_detections = sum(t["count"] for t in techniques.values())
return {
"tactics": TACTICS,
"technique_count": total_techniques,
"detection_count": total_detections,
"tactic_coverage": {
t: {"techniques": techs, "count": len(techs)}
for t, techs in tactic_groups.items()
if techs
},
"all_techniques": list(techniques.values()),
}

View File

@@ -0,0 +1,244 @@
"""Network topology API - host inventory endpoint with background caching."""
# NOTE(review): this region contained unresolved merge-conflict markers
# (<<<<<<< HEAD / ======= / >>>>>>>), which make the module unimportable.
# Resolved by keeping the local (HEAD) implementation, matching the merge
# policy "keep local versions for all conflicts".
import logging

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.db import get_db
from app.services.host_inventory import build_host_inventory, inventory_cache
from app.services.job_queue import job_queue, JobType

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/network", tags=["network"])
@router.get("/host-inventory")
async def get_host_inventory(
hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
force: bool = Query(False, description="Force rebuild, ignoring cache"),
db: AsyncSession = Depends(get_db),
):
"""Return a deduplicated host inventory for the hunt.
Returns instantly from cache if available (pre-built after upload or on startup).
If cache is cold, triggers a background build and returns 202 so the
frontend can poll /inventory-status and re-request when ready.
"""
# Force rebuild: invalidate cache, queue background job, return 202
if force:
inventory_cache.invalidate(hunt_id)
if not inventory_cache.is_building(hunt_id):
if job_queue.is_backlogged():
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(
status_code=202,
content={"status": "building", "message": "Rebuild queued"},
)
# Try cache first
cached = inventory_cache.get(hunt_id)
if cached is not None:
logger.info(f"Serving cached host inventory for {hunt_id}")
return cached
# Cache miss: trigger background build instead of blocking for 90+ seconds
if not inventory_cache.is_building(hunt_id):
logger.info(f"Cache miss for {hunt_id}, triggering background build")
if job_queue.is_backlogged():
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(
status_code=202,
content={"status": "building", "message": "Inventory is being built in the background"},
)
def _build_summary(inv: dict, top_n: int = 20) -> dict:
    """Condense a full inventory dict into its stats plus the top-N hosts and edges."""
    all_hosts = inv.get("hosts", [])
    all_edges = inv.get("connections", [])
    busiest = sorted(all_hosts, key=lambda h: h.get("row_count", 0), reverse=True)
    heaviest = sorted(all_edges, key=lambda e: e.get("count", 0), reverse=True)
    summary_hosts = [
        {
            "id": h.get("id"),
            "hostname": h.get("hostname"),
            "row_count": h.get("row_count", 0),
            "ip_count": len(h.get("ips", [])),
            "user_count": len(h.get("users", [])),
        }
        for h in busiest[:top_n]
    ]
    return {
        "stats": inv.get("stats", {}),
        "top_hosts": summary_hosts,
        "top_edges": heaviest[:top_n],
    }
def _build_subgraph(inv: dict, node_id: str | None, max_hosts: int, max_edges: int) -> dict:
    """Return a bounded host/edge subgraph, optionally focused on one node.

    Requested caps are clamped to the configured NETWORK_SUBGRAPH_* limits.
    With a focal node, its heaviest incident edges and their endpoints are
    kept; otherwise the hosts with the most rows are kept and edges are
    restricted to that host set.
    """
    all_hosts = inv.get("hosts", [])
    all_edges = inv.get("connections", [])
    host_cap = max(1, min(max_hosts, settings.NETWORK_SUBGRAPH_MAX_HOSTS))
    edge_cap = max(1, min(max_edges, settings.NETWORK_SUBGRAPH_MAX_EDGES))

    if node_id:
        touching = [
            e for e in all_edges
            if e.get("source") == node_id or e.get("target") == node_id
        ]
        kept_edges = sorted(touching, key=lambda e: e.get("count", 0), reverse=True)[:edge_cap]
        neighbourhood = {node_id}
        for e in kept_edges:
            neighbourhood.add(e.get("source"))
            neighbourhood.add(e.get("target"))
        kept_hosts = [h for h in all_hosts if h.get("id") in neighbourhood][:host_cap]
    else:
        kept_hosts = sorted(all_hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:host_cap]
        visible = {h.get("id") for h in kept_hosts}
        ranked = sorted(all_edges, key=lambda e: e.get("count", 0), reverse=True)
        kept_edges = [
            e for e in ranked
            if e.get("source") in visible and e.get("target") in visible
        ][:edge_cap]

    return {
        "hosts": kept_hosts,
        "connections": kept_edges,
        "stats": {
            **inv.get("stats", {}),
            "subgraph_hosts": len(kept_hosts),
            "subgraph_connections": len(kept_edges),
            "truncated": len(kept_hosts) < len(all_hosts) or len(kept_edges) < len(all_edges),
        },
    }
@router.get("/summary")
async def get_inventory_summary(
hunt_id: str = Query(..., description="Hunt ID"),
top_n: int = Query(20, ge=1, le=200),
):
"""Return a lightweight summary view for large hunts."""
cached = inventory_cache.get(hunt_id)
if cached is None:
if not inventory_cache.is_building(hunt_id):
if job_queue.is_backlogged():
return JSONResponse(
status_code=202,
content={"status": "deferred", "message": "Queue busy, retry shortly"},
)
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(status_code=202, content={"status": "building"})
return _build_summary(cached, top_n=top_n)
@router.get("/subgraph")
async def get_inventory_subgraph(
hunt_id: str = Query(..., description="Hunt ID"),
node_id: str | None = Query(None, description="Optional focal node"),
max_hosts: int = Query(200, ge=1, le=5000),
max_edges: int = Query(1500, ge=1, le=20000),
):
"""Return a bounded subgraph for scale-safe rendering."""
cached = inventory_cache.get(hunt_id)
if cached is None:
if not inventory_cache.is_building(hunt_id):
if job_queue.is_backlogged():
return JSONResponse(
status_code=202,
content={"status": "deferred", "message": "Queue busy, retry shortly"},
)
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(status_code=202, content={"status": "building"})
return _build_subgraph(cached, node_id=node_id, max_hosts=max_hosts, max_edges=max_edges)
@router.get("/inventory-status")
async def get_inventory_status(
hunt_id: str = Query(..., description="Hunt ID to check"),
):
"""Check whether pre-computed host inventory is ready for a hunt.
Returns: { status: "ready" | "building" | "none" }
"""
return {"hunt_id": hunt_id, "status": inventory_cache.status(hunt_id)}
@router.post("/rebuild-inventory")
async def trigger_rebuild(
hunt_id: str = Query(..., description="Hunt to rebuild inventory for"),
):
"""Trigger a background rebuild of the host inventory cache."""
inventory_cache.invalidate(hunt_id)
job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return {"job_id": job.id, "status": "queued"}
# NOTE(review): removed the remote side of an unresolved merge conflict that
# previously lived here (======= ... >>>>>>> 7c454036c7): Pydantic response
# models (HostEntry, PictureSummary, NetworkPictureResponse) and a blocking
# GET /picture endpoint backed by app.services.network_inventory.  The cached
# /host-inventory implementation above supersedes it, per the merge policy of
# keeping local versions for all conflicts.

View File

@@ -0,0 +1,360 @@
"""API routes for investigation notebooks and playbooks."""
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Notebook, PlaybookRun, _new_id, _utcnow
from app.services.playbook import (
get_builtin_playbooks,
get_playbook_template,
validate_notebook_cells,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/notebooks", tags=["notebooks"])
# ── Pydantic models ──────────────────────────────────────────────────
class NotebookCreate(BaseModel):
    """Request body for creating a notebook; only ``title`` is required."""
    title: str
    description: Optional[str] = None
    cells: Optional[list[dict]] = None  # raw cells; validated server-side before storage
    hunt_id: Optional[str] = None  # optional link to a hunt
    case_id: Optional[str] = None  # optional link to a case
    tags: Optional[list[str]] = None
class NotebookUpdate(BaseModel):
    """Partial update for a notebook; fields left as None are not changed."""
    title: Optional[str] = None
    description: Optional[str] = None
    cells: Optional[list[dict]] = None  # full replacement of the cell list when provided
    tags: Optional[list[str]] = None
class CellUpdate(BaseModel):
    """Update a single cell or add a new one."""
    cell_id: str  # client-chosen cell identifier; the cell is upserted by this id
    cell_type: Optional[str] = None  # defaults to "markdown" when a new cell is inserted
    source: Optional[str] = None
    output: Optional[str] = None
    metadata: Optional[dict] = None
class PlaybookStart(BaseModel):
    """Request body for starting a run of a built-in playbook template."""
    playbook_name: str  # must match a template known to get_playbook_template()
    hunt_id: Optional[str] = None
    case_id: Optional[str] = None
    started_by: Optional[str] = None
class StepComplete(BaseModel):
    """Marks the current playbook step finished (or skipped)."""
    notes: Optional[str] = None
    status: str = "completed"  # completed | skipped
# ── Helpers ───────────────────────────────────────────────────────────
def _notebook_to_dict(nb: Notebook) -> dict:
    """Serialize a Notebook ORM row into a JSON-safe dict."""
    cell_list = nb.cells or []
    payload = {
        "id": nb.id,
        "title": nb.title,
        "description": nb.description,
        "cells": cell_list,
        "hunt_id": nb.hunt_id,
        "case_id": nb.case_id,
        "owner_id": nb.owner_id,
        "tags": nb.tags or [],
        "cell_count": len(cell_list),
    }
    # Timestamps may be NULL on freshly-created rows.
    payload["created_at"] = nb.created_at.isoformat() if nb.created_at else None
    payload["updated_at"] = nb.updated_at.isoformat() if nb.updated_at else None
    return payload
def _run_to_dict(run: PlaybookRun) -> dict:
    """Serialize a PlaybookRun ORM row into a JSON-safe dict."""
    def _iso(ts):
        # Timestamps may be NULL (e.g. completed_at on an active run).
        return ts.isoformat() if ts else None

    return {
        "id": run.id,
        "playbook_name": run.playbook_name,
        "status": run.status,
        "current_step": run.current_step,
        "total_steps": run.total_steps,
        "step_results": run.step_results or [],
        "hunt_id": run.hunt_id,
        "case_id": run.case_id,
        "started_by": run.started_by,
        "created_at": _iso(run.created_at),
        "updated_at": _iso(run.updated_at),
        "completed_at": _iso(run.completed_at),
    }
# ── Notebook CRUD ─────────────────────────────────────────────────────
@router.get("", summary="List notebooks")
async def list_notebooks(
hunt_id: str | None = Query(None),
limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0),
db: AsyncSession = Depends(get_db),
):
stmt = select(Notebook)
count_stmt = select(func.count(Notebook.id))
if hunt_id:
stmt = stmt.where(Notebook.hunt_id == hunt_id)
count_stmt = count_stmt.where(Notebook.hunt_id == hunt_id)
total = (await db.execute(count_stmt)).scalar() or 0
results = (await db.execute(
stmt.order_by(desc(Notebook.updated_at)).offset(offset).limit(limit)
)).scalars().all()
return {"notebooks": [_notebook_to_dict(n) for n in results], "total": total}
@router.get("/{notebook_id}", summary="Get notebook")
async def get_notebook(notebook_id: str, db: AsyncSession = Depends(get_db)):
nb = await db.get(Notebook, notebook_id)
if not nb:
raise HTTPException(status_code=404, detail="Notebook not found")
return _notebook_to_dict(nb)
@router.post("", summary="Create notebook")
async def create_notebook(body: NotebookCreate, db: AsyncSession = Depends(get_db)):
cells = validate_notebook_cells(body.cells or [])
if not cells:
# Start with a default markdown cell
cells = [{"id": "cell-0", "cell_type": "markdown", "source": "# Investigation Notes\n\nStart documenting your findings here.", "output": None, "metadata": {}}]
nb = Notebook(
id=_new_id(),
title=body.title,
description=body.description,
cells=cells,
hunt_id=body.hunt_id,
case_id=body.case_id,
tags=body.tags,
)
db.add(nb)
await db.commit()
await db.refresh(nb)
return _notebook_to_dict(nb)
@router.put("/{notebook_id}", summary="Update notebook")
async def update_notebook(
notebook_id: str, body: NotebookUpdate, db: AsyncSession = Depends(get_db)
):
nb = await db.get(Notebook, notebook_id)
if not nb:
raise HTTPException(status_code=404, detail="Notebook not found")
if body.title is not None:
nb.title = body.title
if body.description is not None:
nb.description = body.description
if body.cells is not None:
nb.cells = validate_notebook_cells(body.cells)
if body.tags is not None:
nb.tags = body.tags
await db.commit()
await db.refresh(nb)
return _notebook_to_dict(nb)
@router.post("/{notebook_id}/cells", summary="Add or update a cell")
async def upsert_cell(
notebook_id: str, body: CellUpdate, db: AsyncSession = Depends(get_db)
):
nb = await db.get(Notebook, notebook_id)
if not nb:
raise HTTPException(status_code=404, detail="Notebook not found")
cells = list(nb.cells or [])
found = False
for i, c in enumerate(cells):
if c.get("id") == body.cell_id:
if body.cell_type is not None:
cells[i]["cell_type"] = body.cell_type
if body.source is not None:
cells[i]["source"] = body.source
if body.output is not None:
cells[i]["output"] = body.output
if body.metadata is not None:
cells[i]["metadata"] = body.metadata
found = True
break
if not found:
cells.append({
"id": body.cell_id,
"cell_type": body.cell_type or "markdown",
"source": body.source or "",
"output": body.output,
"metadata": body.metadata or {},
})
nb.cells = cells
await db.commit()
await db.refresh(nb)
return _notebook_to_dict(nb)
@router.delete("/{notebook_id}/cells/{cell_id}", summary="Delete a cell")
async def delete_cell(
notebook_id: str, cell_id: str, db: AsyncSession = Depends(get_db)
):
nb = await db.get(Notebook, notebook_id)
if not nb:
raise HTTPException(status_code=404, detail="Notebook not found")
cells = [c for c in (nb.cells or []) if c.get("id") != cell_id]
nb.cells = cells
await db.commit()
return {"ok": True, "remaining_cells": len(cells)}
@router.delete("/{notebook_id}", summary="Delete notebook")
async def delete_notebook(notebook_id: str, db: AsyncSession = Depends(get_db)):
nb = await db.get(Notebook, notebook_id)
if not nb:
raise HTTPException(status_code=404, detail="Notebook not found")
await db.delete(nb)
await db.commit()
return {"ok": True}
# ── Playbooks ─────────────────────────────────────────────────────────
@router.get("/playbooks/templates", summary="List built-in playbook templates")
async def list_playbook_templates():
templates = get_builtin_playbooks()
return {
"templates": [
{
"name": t["name"],
"description": t["description"],
"category": t["category"],
"tags": t["tags"],
"step_count": len(t["steps"]),
}
for t in templates
]
}
@router.get("/playbooks/templates/{name}", summary="Get playbook template detail")
async def get_playbook_template_detail(name: str):
template = get_playbook_template(name)
if not template:
raise HTTPException(status_code=404, detail="Playbook template not found")
return template
@router.post("/playbooks/start", summary="Start a playbook run")
async def start_playbook(body: PlaybookStart, db: AsyncSession = Depends(get_db)):
template = get_playbook_template(body.playbook_name)
if not template:
raise HTTPException(status_code=404, detail="Playbook template not found")
run = PlaybookRun(
id=_new_id(),
playbook_name=body.playbook_name,
status="in-progress",
current_step=1,
total_steps=len(template["steps"]),
step_results=[],
hunt_id=body.hunt_id,
case_id=body.case_id,
started_by=body.started_by,
)
db.add(run)
await db.commit()
await db.refresh(run)
return _run_to_dict(run)
@router.get("/playbooks/runs", summary="List playbook runs")
async def list_playbook_runs(
status: str | None = Query(None),
hunt_id: str | None = Query(None),
limit: int = Query(50, ge=1, le=200),
db: AsyncSession = Depends(get_db),
):
stmt = select(PlaybookRun)
if status:
stmt = stmt.where(PlaybookRun.status == status)
if hunt_id:
stmt = stmt.where(PlaybookRun.hunt_id == hunt_id)
results = (await db.execute(
stmt.order_by(desc(PlaybookRun.created_at)).limit(limit)
)).scalars().all()
return {"runs": [_run_to_dict(r) for r in results]}
@router.get("/playbooks/runs/{run_id}", summary="Get playbook run detail")
async def get_playbook_run(run_id: str, db: AsyncSession = Depends(get_db)):
run = await db.get(PlaybookRun, run_id)
if not run:
raise HTTPException(status_code=404, detail="Run not found")
# Also include the template steps
template = get_playbook_template(run.playbook_name)
result = _run_to_dict(run)
result["steps"] = template["steps"] if template else []
return result
@router.post("/playbooks/runs/{run_id}/complete-step", summary="Complete current playbook step")
async def complete_step(
run_id: str, body: StepComplete, db: AsyncSession = Depends(get_db)
):
run = await db.get(PlaybookRun, run_id)
if not run:
raise HTTPException(status_code=404, detail="Run not found")
if run.status != "in-progress":
raise HTTPException(status_code=400, detail="Run is not in progress")
step_results = list(run.step_results or [])
step_results.append({
"step": run.current_step,
"status": body.status,
"notes": body.notes,
"completed_at": _utcnow().isoformat(),
})
run.step_results = step_results
if run.current_step >= run.total_steps:
run.status = "completed"
run.completed_at = _utcnow()
else:
run.current_step += 1
await db.commit()
await db.refresh(run)
return _run_to_dict(run)
@router.post("/playbooks/runs/{run_id}/abort", summary="Abort a playbook run")
async def abort_run(run_id: str, db: AsyncSession = Depends(get_db)):
run = await db.get(PlaybookRun, run_id)
if not run:
raise HTTPException(status_code=404, detail="Run not found")
run.status = "aborted"
run.completed_at = _utcnow()
await db.commit()
return _run_to_dict(run)

View File

@@ -0,0 +1,217 @@
"""API routes for investigation playbooks."""
import logging
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Playbook, PlaybookStep
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/playbooks", tags=["playbooks"])
# -- Request / Response schemas ---
class StepCreate(BaseModel):
    """One step of a playbook being created."""
    title: str
    description: str | None = None
    step_type: str = "manual"  # e.g. manual | upload | action | analysis | review
    target_route: str | None = None  # frontend route the step links to, e.g. "/upload"
class PlaybookCreate(BaseModel):
    """Request body for creating a playbook, optionally with explicit steps."""
    name: str
    description: str | None = None
    hunt_id: str | None = None
    is_template: bool = False
    # default_factory makes the per-instance default explicit instead of
    # relying on Pydantic's copy-on-default behavior for the `[]` literal.
    steps: list[StepCreate] = Field(default_factory=list)
class PlaybookUpdate(BaseModel):
    """Partial update; fields left as None are not changed."""
    name: str | None = None
    description: str | None = None
    status: str | None = None
class StepUpdate(BaseModel):
    """Partial update for a single playbook step."""
    is_completed: bool | None = None  # toggling this also sets/clears completed_at
    notes: str | None = None
# -- Default investigation templates ---
# Built-in workflows surfaced by GET /templates; each step's target_route
# points at the frontend page where the step is performed.
DEFAULT_TEMPLATES = [
    {
        "name": "Standard Threat Hunt",
        "description": "Step-by-step investigation workflow for a typical threat hunting engagement.",
        "steps": [
            {"title": "Upload Artifacts", "description": "Import CSV exports from Velociraptor or other tools", "step_type": "upload", "target_route": "/upload"},
            {"title": "Create Hunt", "description": "Create a new hunt and associate uploaded datasets", "step_type": "action", "target_route": "/hunts"},
            {"title": "AUP Keyword Scan", "description": "Run AUP keyword scanner for policy violations", "step_type": "analysis", "target_route": "/aup"},
            {"title": "Auto-Triage", "description": "Trigger AI triage on all datasets", "step_type": "analysis", "target_route": "/analysis"},
            {"title": "Review Triage Results", "description": "Review flagged rows and risk scores", "step_type": "review", "target_route": "/analysis"},
            {"title": "Enrich IOCs", "description": "Enrich flagged IPs, hashes, and domains via external sources", "step_type": "analysis", "target_route": "/enrichment"},
            {"title": "Host Profiling", "description": "Generate deep host profiles for suspicious hosts", "step_type": "analysis", "target_route": "/analysis"},
            {"title": "Cross-Hunt Correlation", "description": "Identify shared IOCs and patterns across hunts", "step_type": "analysis", "target_route": "/correlation"},
            {"title": "Document Hypotheses", "description": "Record investigation hypotheses with MITRE mappings", "step_type": "action", "target_route": "/hypotheses"},
            {"title": "Generate Report", "description": "Generate final AI-assisted hunt report", "step_type": "action", "target_route": "/analysis"},
        ],
    },
    {
        "name": "Incident Response Triage",
        "description": "Fast-track workflow for active incident response.",
        "steps": [
            {"title": "Upload Artifacts", "description": "Import forensic data from affected hosts", "step_type": "upload", "target_route": "/upload"},
            {"title": "Auto-Triage", "description": "Immediate AI triage for threat indicators", "step_type": "analysis", "target_route": "/analysis"},
            {"title": "IOC Extraction", "description": "Extract all IOCs from flagged data", "step_type": "analysis", "target_route": "/analysis"},
            {"title": "Enrich Critical IOCs", "description": "Priority enrichment of high-risk indicators", "step_type": "analysis", "target_route": "/enrichment"},
            {"title": "Network Map", "description": "Visualize host connections and lateral movement", "step_type": "review", "target_route": "/network"},
            {"title": "Generate Situation Report", "description": "Create executive summary for incident command", "step_type": "action", "target_route": "/analysis"},
        ],
    },
]
# -- Routes ---
@router.get("")
async def list_playbooks(
include_templates: bool = True,
hunt_id: str | None = None,
db: AsyncSession = Depends(get_db),
):
q = select(Playbook)
if hunt_id:
q = q.where(Playbook.hunt_id == hunt_id)
if not include_templates:
q = q.where(Playbook.is_template == False)
q = q.order_by(Playbook.created_at.desc())
result = await db.execute(q.limit(100))
playbooks = result.scalars().all()
return {"playbooks": [
{
"id": p.id, "name": p.name, "description": p.description,
"is_template": p.is_template, "hunt_id": p.hunt_id,
"status": p.status,
"total_steps": len(p.steps),
"completed_steps": sum(1 for s in p.steps if s.is_completed),
"created_at": p.created_at.isoformat() if p.created_at else None,
}
for p in playbooks
]}
@router.get("/templates")
async def get_templates():
"""Return built-in investigation templates."""
return {"templates": DEFAULT_TEMPLATES}
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_playbook(body: PlaybookCreate, db: AsyncSession = Depends(get_db)):
pb = Playbook(
name=body.name,
description=body.description,
hunt_id=body.hunt_id,
is_template=body.is_template,
)
db.add(pb)
await db.flush()
created_steps = []
for i, step in enumerate(body.steps):
s = PlaybookStep(
playbook_id=pb.id,
order_index=i,
title=step.title,
description=step.description,
step_type=step.step_type,
target_route=step.target_route,
)
db.add(s)
created_steps.append(s)
await db.flush()
return {
"id": pb.id, "name": pb.name, "description": pb.description,
"hunt_id": pb.hunt_id, "is_template": pb.is_template,
"steps": [
{"id": s.id, "order_index": s.order_index, "title": s.title,
"description": s.description, "step_type": s.step_type,
"target_route": s.target_route, "is_completed": False}
for s in created_steps
],
}
@router.get("/{playbook_id}")
async def get_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
pb = result.scalar_one_or_none()
if not pb:
raise HTTPException(status_code=404, detail="Playbook not found")
return {
"id": pb.id, "name": pb.name, "description": pb.description,
"is_template": pb.is_template, "hunt_id": pb.hunt_id,
"status": pb.status,
"created_at": pb.created_at.isoformat() if pb.created_at else None,
"steps": [
{
"id": s.id, "order_index": s.order_index, "title": s.title,
"description": s.description, "step_type": s.step_type,
"target_route": s.target_route,
"is_completed": s.is_completed,
"completed_at": s.completed_at.isoformat() if s.completed_at else None,
"notes": s.notes,
}
for s in sorted(pb.steps, key=lambda x: x.order_index)
],
}
@router.put("/{playbook_id}")
async def update_playbook(playbook_id: str, body: PlaybookUpdate, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
pb = result.scalar_one_or_none()
if not pb:
raise HTTPException(status_code=404, detail="Playbook not found")
if body.name is not None:
pb.name = body.name
if body.description is not None:
pb.description = body.description
if body.status is not None:
pb.status = body.status
return {"status": "updated"}
@router.delete("/{playbook_id}")
async def delete_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
pb = result.scalar_one_or_none()
if not pb:
raise HTTPException(status_code=404, detail="Playbook not found")
await db.delete(pb)
return {"status": "deleted"}
@router.put("/steps/{step_id}")
async def update_step(step_id: int, body: StepUpdate, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(PlaybookStep).where(PlaybookStep.id == step_id))
step = result.scalar_one_or_none()
if not step:
raise HTTPException(status_code=404, detail="Step not found")
if body.is_completed is not None:
step.is_completed = body.is_completed
step.completed_at = datetime.now(timezone.utc) if body.is_completed else None
if body.notes is not None:
step.notes = body.notes
return {"status": "updated", "is_completed": step.is_completed}

View File

@@ -0,0 +1,164 @@
"""API routes for saved searches and bookmarked queries."""
import logging
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import SavedSearch
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/searches", tags=["saved-searches"])
class SearchCreate(BaseModel):
    """Request body for saving a search definition."""
    name: str
    description: str | None = None
    search_type: str  # "nlp_query", "ioc_search", "keyword_scan", "correlation"
    query_params: dict  # type-specific parameters, e.g. {"ioc_value": ...} for ioc_search
    threshold: float | None = None
class SearchUpdate(BaseModel):
    """Partial update; fields left as None are not changed."""
    name: str | None = None
    description: str | None = None
    query_params: dict | None = None
    threshold: float | None = None
@router.get("")
async def list_searches(
search_type: str | None = None,
db: AsyncSession = Depends(get_db),
):
q = select(SavedSearch).order_by(SavedSearch.created_at.desc())
if search_type:
q = q.where(SavedSearch.search_type == search_type)
result = await db.execute(q.limit(100))
searches = result.scalars().all()
return {"searches": [
{
"id": s.id, "name": s.name, "description": s.description,
"search_type": s.search_type, "query_params": s.query_params,
"threshold": s.threshold,
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
"last_result_count": s.last_result_count,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in searches
]}
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_search(body: SearchCreate, db: AsyncSession = Depends(get_db)):
s = SavedSearch(
name=body.name,
description=body.description,
search_type=body.search_type,
query_params=body.query_params,
threshold=body.threshold,
)
db.add(s)
await db.flush()
return {
"id": s.id, "name": s.name, "search_type": s.search_type,
"query_params": s.query_params, "threshold": s.threshold,
}
@router.get("/{search_id}")
async def get_search(search_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
s = result.scalar_one_or_none()
if not s:
raise HTTPException(status_code=404, detail="Saved search not found")
return {
"id": s.id, "name": s.name, "description": s.description,
"search_type": s.search_type, "query_params": s.query_params,
"threshold": s.threshold,
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
"last_result_count": s.last_result_count,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
@router.put("/{search_id}")
async def update_search(search_id: str, body: SearchUpdate, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
s = result.scalar_one_or_none()
if not s:
raise HTTPException(status_code=404, detail="Saved search not found")
if body.name is not None:
s.name = body.name
if body.description is not None:
s.description = body.description
if body.query_params is not None:
s.query_params = body.query_params
if body.threshold is not None:
s.threshold = body.threshold
return {"status": "updated"}
@router.delete("/{search_id}")
async def delete_search(search_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
s = result.scalar_one_or_none()
if not s:
raise HTTPException(status_code=404, detail="Saved search not found")
await db.delete(s)
return {"status": "deleted"}
@router.post("/{search_id}/run")
async def run_saved_search(search_id: str, db: AsyncSession = Depends(get_db)):
    """Execute a saved search and return results with delta from last run.

    Supported search types:
      - "ioc_search": substring match of query_params["ioc_value"] against
        EnrichmentResult.ioc_value (capped at 100 rows).
      - "keyword_scan": counts keywords across all enabled KeywordTheme rows.
    Any other type yields zero results. Updates last_run_at/last_result_count
    on the search so the next run can report a delta.
    """
    result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
    s = result.scalar_one_or_none()
    if not s:
        raise HTTPException(status_code=404, detail="Saved search not found")
    previous_count = s.last_result_count or 0
    results: list = []
    count = 0
    if s.search_type == "ioc_search":
        from app.db.models import EnrichmentResult
        # query_params may be NULL in the DB; guard before .get().
        ioc_value = (s.query_params or {}).get("ioc_value", "")
        if ioc_value:
            q = select(EnrichmentResult).where(
                EnrichmentResult.ioc_value.contains(ioc_value)
            )
            res = await db.execute(q.limit(100))
            for er in res.scalars().all():
                results.append({
                    "ioc_value": er.ioc_value, "ioc_type": er.ioc_type,
                    "source": er.source, "verdict": er.verdict,
                })
            count = len(results)
    elif s.search_type == "keyword_scan":
        from app.db.models import KeywordTheme
        res = await db.execute(select(KeywordTheme).where(KeywordTheme.enabled == True))  # noqa: E712
        themes = res.scalars().all()
        # keywords may be NULL; treat as an empty list rather than raising.
        count = sum(len(t.keywords or []) for t in themes)
        results = [{"theme": t.name, "keyword_count": len(t.keywords or [])} for t in themes]
    # Record last-run metadata before computing the delta for the response.
    s.last_run_at = datetime.now(timezone.utc)
    s.last_result_count = count
    delta = count - previous_count
    return {
        "search_id": s.id, "search_name": s.name,
        "search_type": s.search_type,
        "result_count": count,
        "previous_count": previous_count,
        "delta": delta,
        "results": results[:50],
    }

View File

@@ -0,0 +1,184 @@
"""STIX 2.1 export endpoint.
Aggregates hunt data (IOCs, techniques, host profiles, hypotheses) into a
STIX 2.1 Bundle JSON download. No external dependencies required — we
build the JSON directly following the OASIS STIX 2.1 spec.
"""
import json
import uuid
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import Response
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import (
Hunt, Dataset, Hypothesis, TriageResult, HostProfile,
EnrichmentResult, HuntReport,
)
router = APIRouter(prefix="/api/export", tags=["export"])
# Stamped into every generated object's "spec_version" field.
STIX_SPEC_VERSION = "2.1"
def _stix_id(stype: str) -> str:
return f"{stype}--{uuid.uuid4()}"
def _now_iso() -> str:
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
def _build_identity(hunt_name: str) -> dict:
    """Build the STIX identity object naming this export's producing system."""
    identity = {
        "type": "identity",
        "spec_version": STIX_SPEC_VERSION,
        "id": _stix_id("identity"),
        "created": _now_iso(),
        "modified": _now_iso(),
        "name": f"ThreatHunt - {hunt_name}",
        "identity_class": "system",
    }
    return identity
def _ioc_to_indicator(ioc_value: str, ioc_type: str, identity_id: str, verdict: Optional[str] = None) -> dict:
    """Convert one IOC into a STIX 2.1 indicator object.

    Args:
        ioc_value: raw observable value (IP, domain, hash, URL, email, ...).
        ioc_type: key selecting the pattern template; unknown types fall back
            to an artifact:payload_bin pattern.
        identity_id: STIX id used as created_by_ref.
        verdict: optional enrichment verdict used as the label
            (defaults to "suspicious").
    """
    # STIX patterning requires backslashes and single quotes inside string
    # literals to be escaped; otherwise a value containing ' breaks the pattern.
    escaped = ioc_value.replace("\\", "\\\\").replace("'", "\\'")
    pattern_map = {
        "ipv4": f"[ipv4-addr:value = '{escaped}']",
        "ipv6": f"[ipv6-addr:value = '{escaped}']",
        "domain": f"[domain-name:value = '{escaped}']",
        "url": f"[url:value = '{escaped}']",
        "hash_md5": f"[file:hashes.'MD5' = '{escaped}']",
        "hash_sha1": f"[file:hashes.'SHA-1' = '{escaped}']",
        "hash_sha256": f"[file:hashes.'SHA-256' = '{escaped}']",
        "email": f"[email-addr:value = '{escaped}']",
    }
    pattern = pattern_map.get(ioc_type, f"[artifact:payload_bin = '{escaped}']")
    now = _now_iso()
    return {
        "type": "indicator",
        "spec_version": STIX_SPEC_VERSION,
        "id": _stix_id("indicator"),
        "created": now,
        "modified": now,
        # The human-readable name keeps the unescaped value.
        "name": f"{ioc_type}: {ioc_value}",
        "pattern": pattern,
        "pattern_type": "stix",
        "valid_from": now,
        "created_by_ref": identity_id,
        "labels": [verdict or "suspicious"],
    }
def _technique_to_attack_pattern(technique_id: str, identity_id: str) -> dict:
    """Wrap a MITRE ATT&CK technique id (e.g. "T1059.001") in a STIX attack-pattern."""
    stamp = _now_iso()
    # Sub-technique ids use a dot ("T1059.001") but ATT&CK URLs use a slash.
    ref_url = f"https://attack.mitre.org/techniques/{technique_id.replace('.', '/')}/"
    return {
        "type": "attack-pattern",
        "spec_version": STIX_SPEC_VERSION,
        "id": _stix_id("attack-pattern"),
        "created": stamp,
        "modified": stamp,
        "name": technique_id,
        "created_by_ref": identity_id,
        "external_references": [
            {
                "source_name": "mitre-attack",
                "external_id": technique_id,
                "url": ref_url,
            }
        ],
    }
def _hypothesis_to_report(hyp, identity_id: str) -> dict:
    """Render a hunt hypothesis row as a STIX report object (no object_refs yet)."""
    stamp = _now_iso()
    report = {
        "type": "report",
        "spec_version": STIX_SPEC_VERSION,
        "id": _stix_id("report"),
        "created": stamp,
        "modified": stamp,
        "name": hyp.title,
        "description": hyp.description or "",
        "published": stamp,
        "created_by_ref": identity_id,
        "labels": ["threat-hunt-hypothesis"],
        "object_refs": [],
    }
    return report
@router.get("/stix/{hunt_id}")
async def export_stix(hunt_id: str, db: AsyncSession = Depends(get_db)):
    """Export hunt data as a STIX 2.1 Bundle JSON file.

    Aggregates deduplicated IOCs (indicators), MITRE techniques
    (attack-patterns) from triage results, host profiles and hypotheses,
    plus the hypotheses themselves (reports), and returns the bundle as a
    JSON attachment. Raises 404 if the hunt does not exist.
    """
    hunt = (await db.execute(select(Hunt).where(Hunt.id == hunt_id))).scalar_one_or_none()
    if not hunt:
        raise HTTPException(404, "Hunt not found")
    identity = _build_identity(hunt.name)
    objects: list[dict] = [identity]
    seen_techniques: set[str] = set()
    seen_iocs: set[str] = set()

    def _add_technique(tid: str) -> None:
        # Deduplicate attack-pattern objects across all sources.
        if tid not in seen_techniques:
            seen_techniques.add(tid)
            objects.append(_technique_to_attack_pattern(tid, identity["id"]))

    # IOCs + triage techniques come from the hunt's datasets.
    datasets_q = await db.execute(select(Dataset.id).where(Dataset.hunt_id == hunt_id))
    ds_ids = [r[0] for r in datasets_q.all()]
    if ds_ids:
        enrichments = (await db.execute(
            select(EnrichmentResult).where(EnrichmentResult.dataset_id.in_(ds_ids))
        )).scalars().all()
        for e in enrichments:
            key = f"{e.ioc_type}:{e.ioc_value}"
            if key not in seen_iocs:
                seen_iocs.add(key)
                objects.append(_ioc_to_indicator(e.ioc_value, e.ioc_type, identity["id"], e.verdict))
        triages = (await db.execute(
            select(TriageResult).where(TriageResult.dataset_id.in_(ds_ids))
        )).scalars().all()
        for t in triages:
            for tech in (t.mitre_techniques or []):
                # Technique entries may be plain ids or dicts with a technique_id key.
                _add_technique(tech if isinstance(tech, str) else tech.get("technique_id", str(tech)))
    # Techniques observed on host profiles.
    profiles = (await db.execute(
        select(HostProfile).where(HostProfile.hunt_id == hunt_id)
    )).scalars().all()
    for p in profiles:
        for tech in (p.mitre_techniques or []):
            _add_technique(tech if isinstance(tech, str) else tech.get("technique_id", str(tech)))
    # Hypotheses become reports; each may also reference one technique.
    hypos = (await db.execute(
        select(Hypothesis).where(Hypothesis.hunt_id == hunt_id)
    )).scalars().all()
    for h in hypos:
        objects.append(_hypothesis_to_report(h, identity["id"]))
        if h.mitre_technique:
            _add_technique(h.mitre_technique)
    bundle = {
        "type": "bundle",
        "id": _stix_id("bundle"),
        "objects": objects,
    }
    filename = f"threathunt-{hunt.name.replace(' ', '_')}-stix.json"
    # Bug fix: the computed filename was previously not interpolated into the
    # Content-Disposition header, so every download was named literally.
    return Response(
        content=json.dumps(bundle, indent=2),
        media_type="application/json",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )

View File

@@ -0,0 +1,128 @@
"""API routes for forensic timeline visualization."""
import logging
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Dataset, DatasetRow, Hunt
# Module-level logger, named after this module per the standard convention.
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/timeline", tags=["timeline"])
def _parse_timestamp(val: str | None) -> str | None:
"""Try to parse a timestamp string, return ISO format or None."""
if not val:
return None
val = str(val).strip()
if not val:
return None
# Try common formats
for fmt in [
"%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ",
"%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S",
"%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S",
"%Y/%m/%d %H:%M:%S", "%m/%d/%Y %H:%M:%S",
]:
try:
return datetime.strptime(val, fmt).isoformat() + "Z"
except ValueError:
continue
return None
# Columns likely to contain timestamps.
# Compared against lower-cased dataset column names when building timelines.
TIME_COLUMNS = {
    "timestamp", "time", "datetime", "date", "created", "modified",
    "eventtime", "event_time", "start_time", "end_time",
    "lastmodified", "last_modified", "created_at", "updated_at",
    "mtime", "atime", "ctime", "btime",
    "timecreated", "timegenerated", "sourcetime",
}
@router.get("/hunt/{hunt_id}")
async def get_hunt_timeline(
    hunt_id: str,
    limit: int = 2000,
    db: AsyncSession = Depends(get_db),
):
    """Build a chronological event timeline across all datasets in a hunt.

    For each dataset, time-bearing columns are located (normalized column
    names first, then raw schema names containing "time"/"date"), up to a
    per-dataset share of `limit` rows are sampled, and each row with a
    parseable timestamp becomes one event.

    Returns: hunt metadata, per-dataset summaries, and timestamp-sorted events.
    Raises: HTTPException 404 if the hunt does not exist.
    """
    result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    hunt = result.scalar_one_or_none()
    if not hunt:
        raise HTTPException(status_code=404, detail="Hunt not found")
    result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
    datasets = result.scalars().all()
    if not datasets:
        return {"hunt_id": hunt_id, "events": [], "datasets": []}
    events = []
    dataset_info = []
    # Fair share of the row budget per dataset, floored at 1 so a hunt with
    # more datasets than `limit` doesn't issue .limit(0) and drop everything.
    per_dataset_limit = max(1, limit // len(datasets))
    for ds in datasets:
        artifact_type = getattr(ds, "artifact_type", None) or "Unknown"
        dataset_info.append({
            "id": ds.id, "name": ds.name, "artifact_type": artifact_type,
            "row_count": ds.row_count,
        })
        # Prefer normalized column names; fall back to raw-schema heuristics.
        schema = ds.column_schema or {}
        time_cols = [c for c in (ds.normalized_columns or {}).values()
                     if c.lower() in TIME_COLUMNS]
        if not time_cols:
            time_cols = [c for c in schema
                         if c.lower() in TIME_COLUMNS
                         or "time" in c.lower() or "date" in c.lower()]
        if not time_cols:
            continue  # no recognizable timestamp column in this dataset
        rows_result = await db.execute(
            select(DatasetRow)
            .where(DatasetRow.dataset_id == ds.id)
            .order_by(DatasetRow.row_index)
            .limit(per_dataset_limit)
        )
        for r in rows_result.scalars().all():
            # A row may have neither normalized nor raw payload; guard with {}.
            data = r.normalized_data or r.data or {}
            ts = None
            for tc in time_cols:
                ts = _parse_timestamp(data.get(tc))
                if ts:
                    break
            if not ts:
                continue
            hostname = data.get("hostname") or data.get("Hostname") or data.get("Fqdn") or ""
            process = data.get("process_name") or data.get("Name") or data.get("ProcessName") or ""
            summary = data.get("command_line") or data.get("CommandLine") or data.get("Details") or ""
            events.append({
                "timestamp": ts,
                "dataset_id": ds.id,
                "dataset_name": ds.name,
                "artifact_type": artifact_type,
                "row_index": r.row_index,
                "hostname": str(hostname)[:128],
                "process": str(process)[:128],
                "summary": str(summary)[:256],
                # Small preview of the raw row, truncated for payload size.
                "data": {k: str(v)[:100] for k, v in list(data.items())[:8]},
            })
    # ISO-8601 strings sort chronologically as plain strings.
    events.sort(key=lambda e: e["timestamp"])
    return {
        "hunt_id": hunt_id,
        "hunt_name": hunt.name,
        "event_count": len(events),
        "datasets": dataset_info,
        "events": events[:limit],
    }

Some files were not shown because too many files have changed in this diff Show More