Compare commits
10 Commits
copilot/im
...
7c454036c7
| Author | SHA1 | Date | |
|---|---|---|---|
| 7c454036c7 | |||
| 365cf87c90 | |||
| bb562a91ca | |||
| 04a9946891 | |||
| ab8038867a | |||
| 9b98ab9614 | |||
| d0c9f88268 | |||
| dc2dcd02c1 | |||
| 73a2efcde3 | |||
| 77509b08f5 |
53
.env.example
Normal file
@@ -0,0 +1,53 @@
|
||||
# ── ThreatHunt Configuration ──────────────────────────────────────────
|
||||
# All backend env vars are prefixed with TH_ and match AppConfig field names.
|
||||
# Copy this file to .env and adjust values.
|
||||
|
||||
# ── General ───────────────────────────────────────────────────────────
|
||||
TH_DEBUG=false
|
||||
|
||||
# ── Database ──────────────────────────────────────────────────────────
|
||||
# SQLite for local dev (zero-config):
|
||||
TH_DATABASE_URL=sqlite+aiosqlite:///./threathunt.db
|
||||
# PostgreSQL for production:
|
||||
# TH_DATABASE_URL=postgresql+asyncpg://threathunt:password@localhost:5432/threathunt
|
||||
|
||||
# ── CORS ──────────────────────────────────────────────────────────────
|
||||
TH_ALLOWED_ORIGINS=http://localhost:3000,http://localhost:8000
|
||||
|
||||
# ── File uploads ──────────────────────────────────────────────────────
|
||||
TH_MAX_UPLOAD_SIZE_MB=500
|
||||
|
||||
# ── LLM Cluster (Wile & Roadrunner) ──────────────────────────────────
|
||||
TH_OPENWEBUI_URL=https://ai.guapo613.beer
|
||||
TH_OPENWEBUI_API_KEY=
|
||||
TH_WILE_HOST=100.110.190.12
|
||||
TH_WILE_OLLAMA_PORT=11434
|
||||
TH_ROADRUNNER_HOST=100.110.190.11
|
||||
TH_ROADRUNNER_OLLAMA_PORT=11434
|
||||
|
||||
# ── Default models (auto-selected by TaskRouter) ─────────────────────
|
||||
TH_DEFAULT_FAST_MODEL=llama3.1:latest
|
||||
TH_DEFAULT_HEAVY_MODEL=llama3.1:70b-instruct-q4_K_M
|
||||
TH_DEFAULT_CODE_MODEL=qwen2.5-coder:32b
|
||||
TH_DEFAULT_VISION_MODEL=llama3.2-vision:11b
|
||||
TH_DEFAULT_EMBEDDING_MODEL=bge-m3:latest
|
||||
|
||||
# ── Agent behaviour ──────────────────────────────────────────────────
|
||||
TH_AGENT_MAX_TOKENS=2048
|
||||
TH_AGENT_TEMPERATURE=0.3
|
||||
TH_AGENT_HISTORY_LENGTH=10
|
||||
TH_FILTER_SENSITIVE_DATA=true
|
||||
|
||||
# ── Enrichment API keys (optional) ───────────────────────────────────
|
||||
TH_VIRUSTOTAL_API_KEY=
|
||||
TH_ABUSEIPDB_API_KEY=
|
||||
TH_SHODAN_API_KEY=
|
||||
|
||||
# ── Auth ─────────────────────────────────────────────────────────────
|
||||
TH_JWT_SECRET=CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET
|
||||
TH_JWT_ACCESS_TOKEN_MINUTES=60
|
||||
TH_JWT_REFRESH_TOKEN_DAYS=7
|
||||
|
||||
# ── Frontend ─────────────────────────────────────────────────────────
|
||||
REACT_APP_API_URL=http://localhost:8000
|
||||
|
||||
56
.gitignore
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
# ── Python ────────────────────────────────────
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
*.egg
|
||||
.eggs/
|
||||
|
||||
# ── Virtual environments ─────────────────────
|
||||
venv/
|
||||
.venv/
|
||||
env/
|
||||
|
||||
# ── IDE / Editor ─────────────────────────────
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# ── OS ────────────────────────────────────────
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# ── Environment / Secrets ────────────────────
|
||||
.env
|
||||
*.env.local
|
||||
|
||||
# ── Database ─────────────────────────────────
|
||||
*.db
|
||||
*.sqlite3
|
||||
|
||||
# ── Uploads ──────────────────────────────────
|
||||
uploads/
|
||||
|
||||
# ── Node / Frontend ──────────────────────────
|
||||
node_modules/
|
||||
frontend/build/
|
||||
frontend/.env.local
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
|
||||
# ── Docker ───────────────────────────────────
|
||||
docker-compose.override.yml
|
||||
|
||||
# ── Test / Coverage ──────────────────────────
|
||||
.coverage
|
||||
htmlcov/
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
|
||||
# ── Alembic ──────────────────────────────────
|
||||
alembic/versions/*.pyc
|
||||
1
.playwright-mcp/console-2026-02-20T16-32-53-248Z.log
Normal file
@@ -0,0 +1 @@
|
||||
[ 656ms] [WARNING] No routes matched location "/network-map" @ http://localhost:3000/static/js/main.c0a7ab6d.js:1
|
||||
1
.playwright-mcp/console-2026-02-20T18-16-44-089Z.log
Normal file
@@ -0,0 +1 @@
|
||||
[ 4269ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.6d916bcf.js:1
|
||||
1
.playwright-mcp/console-2026-02-20T18-26-05-692Z.log
Normal file
@@ -0,0 +1 @@
|
||||
[ 496ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.28ae077d.js:1
|
||||
76
.playwright-mcp/console-2026-02-20T18-30-45-724Z.log
Normal file
@@ -0,0 +1,76 @@
|
||||
[ 402ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
|
||||
[ 60389ms] [ERROR] Failed to load resource: the server responded with a status of 500 (Internal Server Error) @ http://localhost:3000/api/analysis/process-tree?hunt_id=4bb956a4225e45459a464da1146d3cf5:0
|
||||
[ 114742ms] [ERROR] Failed to load resource: the server responded with a status of 500 (Internal Server Error) @ http://localhost:3000/api/analysis/process-tree?hunt_id=4bb956a4225e45459a464da1146d3cf5:0
|
||||
[ 116603ms] [ERROR] Failed to load resource: the server responded with a status of 500 (Internal Server Error) @ http://localhost:3000/api/analysis/process-tree?hunt_id=4bb956a4225e45459a464da1146d3cf5:0
|
||||
[ 362021ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
|
||||
[ 379006ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
|
||||
[ 379019ms] [ERROR] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227378)
|
||||
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
|
||||
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228635)
|
||||
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:229095)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
|
||||
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228898)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785) @ http://localhost:3000/static/js/main.cb47c3a0.js:1
|
||||
[ 379021ms] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227378)
|
||||
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
|
||||
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228635)
|
||||
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:229095)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
|
||||
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228898)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
|
||||
[ 382647ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
|
||||
[ 386088ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
|
||||
[ 386343ms] [ERROR] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227378)
|
||||
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
|
||||
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228635)
|
||||
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:229095)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
|
||||
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228898)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785) @ http://localhost:3000/static/js/main.cb47c3a0.js:1
|
||||
[ 386345ms] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227378)
|
||||
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
|
||||
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228635)
|
||||
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:229095)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
|
||||
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228898)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
|
||||
[ 397704ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
|
||||
[ 519009ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.cb47c3a0.js:1
|
||||
[ 519273ms] [ERROR] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227378)
|
||||
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
|
||||
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228635)
|
||||
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:229095)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
|
||||
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228898)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785) @ http://localhost:3000/static/js/main.cb47c3a0.js:1
|
||||
[ 519274ms] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227378)
|
||||
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
|
||||
at ds (http://localhost:3000/static/js/main.cb47c3a0.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.cb47c3a0.js:2:227824)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228635)
|
||||
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:229095)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
|
||||
at vs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228898)
|
||||
at hs (http://localhost:3000/static/js/main.cb47c3a0.js:2:228785)
|
||||
1
.playwright-mcp/console-2026-02-20T18-44-41-738Z.log
Normal file
@@ -0,0 +1 @@
|
||||
[ 1803ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.b2c21c5a.js:1
|
||||
48
.playwright-mcp/console-2026-02-20T18-46-54-542Z.log
Normal file
@@ -0,0 +1,48 @@
|
||||
[ 2196ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.0e63bc98.js:1
|
||||
[ 46100ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.0e63bc98.js:1
|
||||
[ 46117ms] [ERROR] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
|
||||
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227378)
|
||||
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
|
||||
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
|
||||
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228635)
|
||||
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:229095)
|
||||
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785)
|
||||
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:228898)
|
||||
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785) @ http://localhost:3000/static/js/main.0e63bc98.js:1
|
||||
[ 46118ms] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
|
||||
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227378)
|
||||
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
|
||||
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
|
||||
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228635)
|
||||
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:229095)
|
||||
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785)
|
||||
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:228898)
|
||||
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785)
|
||||
[ 52506ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.0e63bc98.js:1
|
||||
[ 54912ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.0e63bc98.js:1
|
||||
[ 54928ms] [ERROR] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
|
||||
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227378)
|
||||
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
|
||||
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
|
||||
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228635)
|
||||
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:229095)
|
||||
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785)
|
||||
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:228898)
|
||||
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785) @ http://localhost:3000/static/js/main.0e63bc98.js:1
|
||||
[ 54929ms] NotFoundError: Failed to execute 'removeChild' on 'Node': The node to be removed is not a child of this node.
|
||||
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227378)
|
||||
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
|
||||
at ds (http://localhost:3000/static/js/main.0e63bc98.js:2:227062)
|
||||
at ps (http://localhost:3000/static/js/main.0e63bc98.js:2:227824)
|
||||
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228635)
|
||||
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:229095)
|
||||
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785)
|
||||
at vs (http://localhost:3000/static/js/main.0e63bc98.js:2:228898)
|
||||
at hs (http://localhost:3000/static/js/main.0e63bc98.js:2:228785)
|
||||
7
.playwright-mcp/console-2026-02-20T18-50-52-269Z.log
Normal file
@@ -0,0 +1,7 @@
|
||||
[ 2548ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.c311038e.js:1
|
||||
[ 32912ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.c311038e.js:1
|
||||
[ 55583ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.c311038e.js:1
|
||||
[ 58208ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.c311038e.js:1
|
||||
[ 1168933ms] [ERROR] Failed to load resource: the server responded with a status of 504 (Gateway Time-out) @ http://localhost:3000/api/analysis/llm-analyze:0
|
||||
[ 1477343ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.c311038e.js:1
|
||||
[ 1482908ms] [WARNING] You have set a custom wheel sensitivity. This will make your app zoom unnaturally when using mainstream mice. You should change this value from the default only if you can guarantee that all your users will use the same hardware and OS configuration as your current machine. @ http://localhost:3000/static/js/main.c311038e.js:1
|
||||
7
.playwright-mcp/console-2026-02-20T19-16-43-503Z.log
Normal file
@@ -0,0 +1,7 @@
|
||||
[ 9612ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/:0
|
||||
[ 17464ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/enterprise:0
|
||||
[ 20742ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/enterprise:0
|
||||
[ 53258ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/pricing:0
|
||||
[ 59240ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/features/copilot#pricing:0
|
||||
[ 67668ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/features/spark?utm_source=web-copilot-ce-cta&utm_campaign=spark-launch-sep-2025:0
|
||||
[ 72166ms] [WARNING] The resource https://github.githubassets.com/assets/mona-sans-14595085164a.woff2 was preloaded using link preload but not used within a few seconds from the window's load event. Please make sure it has an appropriate `as` value and it is preloaded intentionally. @ https://github.com/features/spark?utm_source=web-copilot-ce-cta&utm_campaign=spark-launch-sep-2025:0
|
||||
3923
.playwright-mcp/console-2026-02-20T19-27-06-976Z.log
Normal file
BIN
.playwright-mcp/page-2026-02-20T16-33-40-311Z.png
Normal file
|
After Width: | Height: | Size: 41 KiB |
BIN
.playwright-mcp/page-2026-02-20T16-34-14-809Z.png
Normal file
|
After Width: | Height: | Size: 54 KiB |
BIN
.playwright-mcp/page-2026-02-20T16-38-20-099Z.png
Normal file
|
After Width: | Height: | Size: 70 KiB |
BIN
.playwright-mcp/page-2026-02-20T16-42-11-611Z.png
Normal file
|
After Width: | Height: | Size: 103 KiB |
BIN
.playwright-mcp/page-2026-02-20T18-17-19-668Z.png
Normal file
|
After Width: | Height: | Size: 558 KiB |
BIN
.playwright-mcp/page-2026-02-20T18-26-49-357Z.png
Normal file
|
After Width: | Height: | Size: 607 KiB |
BIN
.playwright-mcp/page-2026-02-20T18-31-29-013Z.png
Normal file
|
After Width: | Height: | Size: 341 KiB |
BIN
.playwright-mcp/page-2026-02-20T18-44-58-287Z.png
Normal file
|
After Width: | Height: | Size: 53 KiB |
BIN
.playwright-mcp/page-2026-02-20T18-45-12-934Z.png
Normal file
|
After Width: | Height: | Size: 55 KiB |
BIN
.playwright-mcp/page-2026-02-20T18-47-14-660Z.png
Normal file
|
After Width: | Height: | Size: 193 KiB |
BIN
.playwright-mcp/page-2026-02-20T18-51-32-804Z.png
Normal file
|
After Width: | Height: | Size: 184 KiB |
32
Dockerfile.backend
Normal file
@@ -0,0 +1,32 @@
|
||||
# ThreatHunt Backend API - Python 3.13
|
||||
FROM python:3.13-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements
|
||||
COPY backend/requirements.txt .
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy backend code
|
||||
COPY backend/ .
|
||||
|
||||
# Create non-root user & data directory
|
||||
RUN useradd -m -u 1000 appuser && mkdir -p /app/data && chown -R appuser:appuser /app
|
||||
USER appuser
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/ || exit 1
|
||||
|
||||
# Run Alembic migrations then start Uvicorn
|
||||
CMD ["sh", "-c", "python -m alembic upgrade head && python run.py"]
|
||||
36
Dockerfile.frontend
Normal file
@@ -0,0 +1,36 @@
|
||||
# ThreatHunt Frontend - Node.js React
|
||||
FROM node:20-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy package files
|
||||
COPY frontend/package.json frontend/package-lock.json* ./
|
||||
|
||||
# Install dependencies
|
||||
RUN npm ci
|
||||
|
||||
# Copy source
|
||||
COPY frontend/public ./public
|
||||
COPY frontend/src ./src
|
||||
COPY frontend/tsconfig.json ./
|
||||
|
||||
# Build application
|
||||
RUN npm run build
|
||||
|
||||
# Production stage — nginx reverse-proxy + static files
|
||||
FROM nginx:alpine
|
||||
|
||||
# Copy built React app
|
||||
COPY --from=builder /app/build /usr/share/nginx/html
|
||||
|
||||
# Copy custom nginx config (proxies /api to backend)
|
||||
COPY frontend/nginx.conf /etc/nginx/conf.d/default.conf
|
||||
|
||||
# Expose port
|
||||
EXPOSE 3000
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD wget --quiet --tries=1 --spider http://localhost:3000/ || exit 1
|
||||
|
||||
CMD ["nginx", "-g", "daemon off;"]
|
||||
497
README.md
@@ -1 +1,496 @@
|
||||
# ThreatHunt
|
||||
# ThreatHunt - Analyst-Assist Threat Hunting Platform
|
||||
|
||||
A modern threat hunting platform with integrated analyst-assist agent guidance. Analyze CSV artifact data exported from Velociraptor with AI-powered suggestions for investigation directions, analytical pivots, and hypothesis formation.
|
||||
|
||||
## Overview
|
||||
|
||||
ThreatHunt is a web application designed to help security analysts efficiently hunt for threats by:
|
||||
- Importing CSV artifacts from Velociraptor or other sources
|
||||
- Displaying data in an organized, queryable interface
|
||||
- Providing AI-powered guidance through an analyst-assist agent
|
||||
- Suggesting analytical directions, filters, and pivots
|
||||
- Highlighting anomalies and patterns of interest
|
||||
|
||||
> **Agent Policy**: The analyst-assist agent provides read-only guidance only. It does not execute actions, escalate alerts, or modify data. All decisions remain with the analyst.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Docker (Recommended)
|
||||
|
||||
```bash
|
||||
# Clone and navigate
|
||||
git clone https://github.com/mblanke/ThreatHunt.git
|
||||
cd ThreatHunt
|
||||
|
||||
# Configure provider (choose one)
|
||||
cp .env.example .env
|
||||
# Edit .env and set your LLM provider:
|
||||
# Option 1: Online (OpenAI, etc.)
|
||||
# THREAT_HUNT_AGENT_PROVIDER=online
|
||||
# THREAT_HUNT_ONLINE_API_KEY=sk-your-key
|
||||
# Option 2: Local (Ollama, GGML, etc.)
|
||||
# THREAT_HUNT_AGENT_PROVIDER=local
|
||||
# THREAT_HUNT_LOCAL_MODEL_PATH=/path/to/model
|
||||
# Option 3: Networked (Internal inference service)
|
||||
# THREAT_HUNT_AGENT_PROVIDER=networked
|
||||
# THREAT_HUNT_NETWORKED_ENDPOINT=http://service:5000
|
||||
|
||||
# Start services
|
||||
docker-compose up -d
|
||||
|
||||
# Verify
|
||||
curl http://localhost:8000/api/agent/health
|
||||
curl http://localhost:3000
|
||||
```
|
||||
|
||||
Access at http://localhost:3000
|
||||
|
||||
### Local Development
|
||||
|
||||
**Backend**:
|
||||
```bash
|
||||
cd backend
|
||||
python -m venv venv
|
||||
source venv/bin/activate # Windows: venv\Scripts\activate
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Configure provider
|
||||
export THREAT_HUNT_ONLINE_API_KEY=sk-your-key
|
||||
# OR set another provider env var
|
||||
|
||||
# Run
|
||||
python run.py
|
||||
# API at http://localhost:8000/docs
|
||||
```
|
||||
|
||||
**Frontend** (new terminal):
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
npm start
|
||||
# App at http://localhost:3000
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
### Analyst-Assist Agent 🤖
|
||||
- **Read-only guidance**: Explains data patterns and suggests investigation directions
|
||||
- **Context-aware**: Understands current dataset, host, and artifact type
|
||||
- **Pluggable providers**: Local, networked, or online LLM backends
|
||||
- **Transparent reasoning**: Explains logic with caveats and confidence scores
|
||||
- **Governance-compliant**: Strictly adheres to agent policy (no execution, no escalation)
|
||||
|
||||
### Chat Interface
|
||||
- Analyst asks questions about artifact data
|
||||
- Agent provides guidance with suggested pivots and filters
|
||||
- Conversation history for context continuity
|
||||
- Real-time typing and response indicators
|
||||
|
||||
### Data Management
|
||||
- Import CSV artifacts from Velociraptor
|
||||
- Browse and filter findings by severity, host, artifact type
|
||||
- Annotate findings with analyst notes
|
||||
- Track investigation progress
|
||||
|
||||
## Architecture
|
||||
|
||||
### Backend
|
||||
- **Framework**: FastAPI (Python 3.11)
|
||||
- **Agent Module**: Pluggable LLM provider interface
|
||||
- **API**: RESTful endpoints with OpenAPI documentation
|
||||
- **Structure**: Modular design with clear separation of concerns
|
||||
|
||||
### Frontend
|
||||
- **Framework**: React 18 with TypeScript
|
||||
- **Components**: Agent chat panel + analysis dashboard
|
||||
- **Styling**: CSS with responsive design
|
||||
- **State Management**: React hooks + Context API
|
||||
|
||||
### LLM Providers
|
||||
Supports three provider architectures:
|
||||
|
||||
1. **Local**: On-device or on-prem models (GGML, Ollama, vLLM)
|
||||
2. **Networked**: Shared internal inference services
|
||||
3. **Online**: External hosted APIs (OpenAI, Anthropic, Google)
|
||||
|
||||
Auto-detection: Automatically uses the first available provider.
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
ThreatHunt/
|
||||
├── backend/
|
||||
│ ├── app/
|
||||
│ │ ├── agents/ # Analyst-assist agent
|
||||
│ │ │ ├── core.py # ThreatHuntAgent class
|
||||
│ │ │ ├── providers.py # LLM provider interface
|
||||
│ │ │ ├── config.py # Configuration
|
||||
│ │ │ └── __init__.py
|
||||
│ │ ├── api/routes/ # API endpoints
|
||||
│ │ │ ├── agent.py # /api/agent/* routes
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ ├── main.py # FastAPI app
|
||||
│ │ └── __init__.py
|
||||
│ ├── requirements.txt
|
||||
│ ├── run.py
|
||||
│ └── Dockerfile
|
||||
├── frontend/
|
||||
│ ├── src/
|
||||
│ │ ├── components/
|
||||
│ │ │ ├── AgentPanel.tsx # Chat interface
|
||||
│ │ │ └── AgentPanel.css
|
||||
│ │ ├── utils/
|
||||
│ │ │ └── agentApi.ts # API communication
|
||||
│ │ ├── App.tsx
|
||||
│ │ ├── App.css
|
||||
│ │ ├── index.tsx
|
||||
│ │ └── index.css
|
||||
│ ├── public/index.html
|
||||
│ ├── package.json
|
||||
│ ├── tsconfig.json
|
||||
│ └── Dockerfile
|
||||
├── docker-compose.yml
|
||||
├── .env.example
|
||||
├── .gitignore
|
||||
├── AGENT_IMPLEMENTATION.md # Technical guide
|
||||
├── INTEGRATION_GUIDE.md # Deployment guide
|
||||
├── IMPLEMENTATION_SUMMARY.md # Overview
|
||||
├── README.md # This file
|
||||
├── ROADMAP.md
|
||||
└── THREATHUNT_INTENT.md
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Agent Assistance
|
||||
- **POST /api/agent/assist** - Request guidance on artifact data
|
||||
- **GET /api/agent/health** - Check agent availability
|
||||
|
||||
See full API documentation at http://localhost:8000/docs
|
||||
|
||||
## Configuration
|
||||
|
||||
### LLM Provider Selection
|
||||
|
||||
Set via `THREAT_HUNT_AGENT_PROVIDER` environment variable:
|
||||
|
||||
```bash
|
||||
# Auto-detect (tries local → networked → online)
|
||||
THREAT_HUNT_AGENT_PROVIDER=auto
|
||||
|
||||
# Local (on-device/on-prem)
|
||||
THREAT_HUNT_AGENT_PROVIDER=local
|
||||
THREAT_HUNT_LOCAL_MODEL_PATH=/models/model.gguf
|
||||
|
||||
# Networked (internal service)
|
||||
THREAT_HUNT_AGENT_PROVIDER=networked
|
||||
THREAT_HUNT_NETWORKED_ENDPOINT=http://inference:5000
|
||||
THREAT_HUNT_NETWORKED_KEY=api-key
|
||||
|
||||
# Online (hosted API)
|
||||
THREAT_HUNT_AGENT_PROVIDER=online
|
||||
THREAT_HUNT_ONLINE_API_KEY=sk-your-key
|
||||
THREAT_HUNT_ONLINE_PROVIDER=openai
|
||||
THREAT_HUNT_ONLINE_MODEL=gpt-3.5-turbo
|
||||
```
|
||||
|
||||
### Agent Behavior
|
||||
|
||||
```bash
|
||||
THREAT_HUNT_AGENT_MAX_TOKENS=1024
|
||||
THREAT_HUNT_AGENT_REASONING=true
|
||||
THREAT_HUNT_AGENT_HISTORY_LENGTH=10
|
||||
THREAT_HUNT_AGENT_FILTER_SENSITIVE=true
|
||||
```
|
||||
|
||||
See `.env.example` for all configuration options.
|
||||
|
||||
## Governance & Compliance
|
||||
|
||||
This implementation strictly follows governance principles:
|
||||
|
||||
- ✅ **Agents assist analysts** - No autonomous execution
|
||||
- ✅ **No tool execution** - Agent provides guidance only
|
||||
- ✅ **No alert escalation** - Analyst controls alerts
|
||||
- ✅ **No data modification** - Read-only analysis
|
||||
- ✅ **Transparent reasoning** - Explains guidance with caveats
|
||||
- ✅ **Analyst authority** - All decisions remain with analyst
|
||||
|
||||
**References**:
|
||||
- `goose-core/governance/AGENT_POLICY.md`
|
||||
- `goose-core/governance/AI_RULES.md`
|
||||
- `THREATHUNT_INTENT.md`
|
||||
|
||||
## Documentation
|
||||
|
||||
- **[AGENT_IMPLEMENTATION.md](AGENT_IMPLEMENTATION.md)** - Detailed technical architecture
|
||||
- **[INTEGRATION_GUIDE.md](INTEGRATION_GUIDE.md)** - Deployment and configuration
|
||||
- **[IMPLEMENTATION_SUMMARY.md](IMPLEMENTATION_SUMMARY.md)** - Feature overview
|
||||
|
||||
## Testing the Agent
|
||||
|
||||
### Check Health
|
||||
```bash
|
||||
curl http://localhost:8000/api/agent/health
|
||||
```
|
||||
|
||||
### Test API
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/api/agent/assist \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"query": "What patterns suggest suspicious activity?",
|
||||
"dataset_name": "FileList",
|
||||
"artifact_type": "FileList",
|
||||
"host_identifier": "DESKTOP-ABC123"
|
||||
}'
|
||||
```
|
||||
|
||||
### Use UI
|
||||
1. Open http://localhost:3000
|
||||
2. Enter a question in the agent panel
|
||||
3. View guidance with suggested pivots and filters
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Agent Unavailable (503)
|
||||
- Check environment variables for provider configuration
|
||||
- Verify LLM provider is accessible
|
||||
- See logs: `docker-compose logs backend`
|
||||
|
||||
### No Frontend Response
|
||||
- Verify backend health: `curl http://localhost:8000/api/agent/health`
|
||||
- Check browser console for errors
|
||||
- See logs: `docker-compose logs frontend`
|
||||
|
||||
See [INTEGRATION_GUIDE.md](INTEGRATION_GUIDE.md) for detailed troubleshooting.
|
||||
|
||||
## Development
|
||||
|
||||
### Running Tests
|
||||
```bash
|
||||
cd backend
|
||||
pytest
|
||||
|
||||
cd ../frontend
|
||||
npm test
|
||||
```
|
||||
|
||||
### Building Images
|
||||
```bash
|
||||
docker-compose build
|
||||
```
|
||||
|
||||
### Logs
|
||||
```bash
|
||||
docker-compose logs -f backend
|
||||
docker-compose logs -f frontend
|
||||
```
|
||||
|
||||
## Security Notes
|
||||
|
||||
For production deployment:
|
||||
1. Add authentication to API endpoints
|
||||
2. Enable HTTPS/TLS
|
||||
3. Implement rate limiting
|
||||
4. Filter sensitive data before LLM
|
||||
5. Add audit logging
|
||||
6. Use secrets management for API keys
|
||||
|
||||
See [INTEGRATION_GUIDE.md](INTEGRATION_GUIDE.md#security-notes) for details.
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- [ ] Integration with actual CVE databases
|
||||
- [ ] Fine-tuned models for cybersecurity domain
|
||||
- [ ] Structured output from LLMs (JSON mode)
|
||||
- [ ] Feedback loop on guidance quality
|
||||
- [ ] Multi-modal support (images, documents)
|
||||
- [ ] Compliance reporting and audit trails
|
||||
- [ ] Performance optimization and caching
|
||||
|
||||
## Contributing
|
||||
|
||||
Follow the architecture and governance principles in `goose-core`. All changes must:
|
||||
- Adhere to agent policy (read-only, advisory only)
|
||||
- Conform to shared terminology in goose-core
|
||||
- Include appropriate documentation
|
||||
- Pass tests and lint checks
|
||||
|
||||
## License
|
||||
|
||||
See LICENSE file
|
||||
|
||||
## Support
|
||||
|
||||
For issues or questions:
|
||||
1. Check [INTEGRATION_GUIDE.md](INTEGRATION_GUIDE.md)
|
||||
2. Review [AGENT_IMPLEMENTATION.md](AGENT_IMPLEMENTATION.md)
|
||||
3. See API docs at http://localhost:8000/docs
|
||||
4. Check backend logs for errors
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Docker and Docker Compose
|
||||
- Python 3.11+ (for local development)
|
||||
- Node.js 18+ (for local development)
|
||||
|
||||
### Quick Start with Docker
|
||||
|
||||
1. Clone the repository:
|
||||
```bash
|
||||
git clone https://github.com/mblanke/ThreatHunt.git
|
||||
cd ThreatHunt
|
||||
```
|
||||
|
||||
2. Start all services:
|
||||
```bash
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
3. Access the application:
|
||||
- Frontend: http://localhost:3000
|
||||
- Backend API: http://localhost:8000
|
||||
- API Documentation: http://localhost:8000/docs
|
||||
|
||||
### Local Development
|
||||
|
||||
#### Backend
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
python -m venv venv
|
||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Set up environment variables
|
||||
cp .env.example .env
|
||||
# Edit .env with your settings
|
||||
|
||||
# Run migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Start development server
|
||||
uvicorn app.main:app --reload
|
||||
```
|
||||
|
||||
#### Frontend
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
npm start
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Authentication
|
||||
- `POST /api/auth/register` - Register a new user
|
||||
- `POST /api/auth/login` - Login and receive JWT token
|
||||
- `GET /api/auth/me` - Get current user profile
|
||||
- `PUT /api/auth/me` - Update current user profile
|
||||
|
||||
### User Management (Admin only)
|
||||
- `GET /api/users` - List all users in tenant
|
||||
- `GET /api/users/{user_id}` - Get user by ID
|
||||
- `PUT /api/users/{user_id}` - Update user
|
||||
- `DELETE /api/users/{user_id}` - Deactivate user
|
||||
|
||||
### Tenants
|
||||
- `GET /api/tenants` - List tenants
|
||||
- `POST /api/tenants` - Create tenant (admin)
|
||||
- `GET /api/tenants/{tenant_id}` - Get tenant by ID
|
||||
|
||||
### Hosts
|
||||
- `GET /api/hosts` - List hosts (scoped to tenant)
|
||||
- `POST /api/hosts` - Create host
|
||||
- `GET /api/hosts/{host_id}` - Get host by ID
|
||||
|
||||
### Ingestion
|
||||
- `POST /api/ingestion/ingest` - Upload and parse CSV files exported from Velociraptor
|
||||
|
||||
### VirusTotal
|
||||
- `POST /api/vt/lookup` - Lookup hash in VirusTotal
|
||||
|
||||
## Authentication Flow
|
||||
|
||||
1. User registers or logs in via `/api/auth/login`
|
||||
2. Backend returns JWT token with user_id, tenant_id, and role
|
||||
3. Frontend stores token in localStorage
|
||||
4. All subsequent API requests include token in Authorization header
|
||||
5. Backend validates token and enforces tenant scoping
|
||||
|
||||
## Multi-Tenancy
|
||||
|
||||
- All data is scoped to tenant_id
|
||||
- Users can only access data within their tenant
|
||||
- Admin users have elevated permissions within their tenant
|
||||
- Cross-tenant access requires explicit permissions
|
||||
|
||||
## Database Migrations
|
||||
|
||||
Create a new migration:
|
||||
```bash
|
||||
cd backend
|
||||
alembic revision --autogenerate -m "Description of changes"
|
||||
```
|
||||
|
||||
Apply migrations:
|
||||
```bash
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
Rollback migrations:
|
||||
```bash
|
||||
alembic downgrade -1
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
### Backend
|
||||
- `DATABASE_URL` - PostgreSQL connection string
|
||||
- `SECRET_KEY` - Secret key for JWT signing (min 32 characters)
|
||||
- `ACCESS_TOKEN_EXPIRE_MINUTES` - JWT token expiration time (default: 30)
|
||||
- `VT_API_KEY` - VirusTotal API key for hash lookups
|
||||
|
||||
### Frontend
|
||||
- `REACT_APP_API_URL` - Backend API URL (default: http://localhost:8000)
|
||||
|
||||
## Security
|
||||
|
||||
- Passwords are hashed using bcrypt
|
||||
- JWT tokens include expiration time
|
||||
- All API endpoints (except login/register) require authentication
|
||||
- Role-based access control for admin operations
|
||||
- Data isolation through tenant scoping
|
||||
|
||||
## Testing
|
||||
|
||||
### Backend
|
||||
```bash
|
||||
cd backend
|
||||
pytest
|
||||
```
|
||||
|
||||
### Frontend
|
||||
```bash
|
||||
cd frontend
|
||||
npm test
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
1. Fork the repository
|
||||
2. Create a feature branch
|
||||
3. Make your changes
|
||||
4. Submit a pull request
|
||||
|
||||
## License
|
||||
|
||||
[Your License Here]
|
||||
|
||||
## Support
|
||||
|
||||
For issues and questions, please open an issue on GitHub.
|
||||
|
||||
21
SKILLS/00-operating-model.md
Normal file
@@ -0,0 +1,21 @@
|
||||
|
||||
# Operating Model
|
||||
|
||||
## Default cadence
|
||||
- Prefer iterative progress over big bangs.
|
||||
- Keep diffs small: target ≤ 300 changed lines per PR unless justified.
|
||||
- Update tests/docs as part of the same change when possible.
|
||||
|
||||
## Working agreement
|
||||
- Start with a PLAN for non-trivial tasks.
|
||||
- Implement the smallest slice that satisfies acceptance criteria.
|
||||
- Verify via DoD.
|
||||
- Write a crisp PR summary: what changed, why, and how verified.
|
||||
|
||||
## Stop conditions (plan first)
|
||||
Stop and produce a PLAN (do not code yet) if:
|
||||
- scope is unclear
|
||||
- more than 3 files will change
|
||||
- data model changes
|
||||
- auth/security boundaries
|
||||
- performance-critical paths
|
||||
36
SKILLS/05-agent-taxonomy.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# Agent Types & Roles (Practical Taxonomy)
|
||||
|
||||
Use this skill to choose the *right* kind of agent workflow for the job.
|
||||
|
||||
## Common agent "types" (in practice)
|
||||
|
||||
### 1) Chat assistant (no tools)
|
||||
Best for: explanations, brainstorming, small edits.
|
||||
Risk: can hallucinate; no grounding in repo state.
|
||||
|
||||
### 2) Tool-using single agent
|
||||
Best for: well-scoped tasks where the agent can read/write files and run commands.
|
||||
Key control: strict DoD gates + minimal permissions.
|
||||
|
||||
### 3) Planner + Executor (2-role pattern)
|
||||
Best for: medium complexity work (multi-file changes, feature work).
|
||||
Flow: Planner writes plan + acceptance criteria → Executor implements → Reviewer checks.
|
||||
|
||||
### 4) Multi-agent (specialists)
|
||||
Best for: bigger features with separable workstreams (UI, backend, docs, tests).
|
||||
Rule: isolate context per role; use separate branches/worktrees.
|
||||
|
||||
### 5) Supervisor / orchestrator
|
||||
Best for: long-running workflows with checkpoints (pipelines, report generation, PAD docs).
|
||||
Rule: supervisor delegates, enforces gates, and composes final output.
|
||||
|
||||
## Decision rules (fast)
|
||||
- If you can describe it in ≤ 5 steps → single tool-using agent.
|
||||
- If you need tradeoffs/design → Planner + Executor.
|
||||
- If UI + backend + docs/tests all move → multi-agent specialists.
|
||||
- If it's a pipeline that runs repeatedly → orchestrator.
|
||||
|
||||
## Guardrails (always)
|
||||
- DoD is the truth gate.
|
||||
- Separate branches/worktrees for parallel work.
|
||||
- Log decisions + commands in AGENT_LOG.md.
|
||||
24
SKILLS/10-definition-of-done.md
Normal file
@@ -0,0 +1,24 @@
|
||||
|
||||
# Definition of Done (DoD)
|
||||
|
||||
A change is "done" only when:
|
||||
|
||||
## Code correctness
|
||||
- Builds successfully (if applicable)
|
||||
- Tests pass
|
||||
- Linting/formatting passes
|
||||
- Types/checks pass (if applicable)
|
||||
|
||||
## Quality
|
||||
- No new warnings introduced
|
||||
- Edge cases handled (inputs validated, errors meaningful)
|
||||
- Hot paths not regressed (if applicable)
|
||||
|
||||
## Hygiene
|
||||
- No secrets committed
|
||||
- Docs updated if behavior or usage changed
|
||||
- PR summary includes verification steps
|
||||
|
||||
## Commands
|
||||
- macOS/Linux: `./scripts/dod.sh`
|
||||
- Windows: `.\scripts\dod.ps1`
|
||||
16
SKILLS/20-repo-map.md
Normal file
@@ -0,0 +1,16 @@
|
||||
|
||||
# Repo Mapping Skill
|
||||
|
||||
When entering a repo:
|
||||
1) Read README.md
|
||||
2) Identify entrypoints (app main / server startup / CLI)
|
||||
3) Identify config (env vars, .env.example, config files)
|
||||
4) Identify test/lint scripts (package.json, pyproject.toml, Makefile, etc.)
|
||||
5) Write a 10-line "repo map" in the PLAN before changing code
|
||||
|
||||
Output format:
|
||||
- Purpose:
|
||||
- Key modules:
|
||||
- Data flow:
|
||||
- Commands:
|
||||
- Risks:
|
||||
20
SKILLS/25-algorithms-performance.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Algorithms & Performance
|
||||
|
||||
Use this skill when performance matters (large inputs, hot paths, or repeated calls).
|
||||
|
||||
## Checklist
|
||||
- Identify the **state** you're recomputing.
|
||||
- Add **memoization / caching** when the same subproblem repeats.
|
||||
- Prefer **linear scans** + caches over nested loops when possible.
|
||||
- If you can write it as a **recurrence**, you can test it.
|
||||
|
||||
## Practical heuristics
|
||||
- Measure first when possible (timing + input sizes).
|
||||
- Optimize the biggest wins: avoid repeated I/O, repeated parsing, repeated network calls.
|
||||
- Keep caches bounded (size/TTL) and invalidate safely.
|
||||
- Choose data structures intentionally: dict/set for membership, heap for top-k, deque for queues.
|
||||
|
||||
## Review notes (for PRs)
|
||||
- Call out accidental O(n²) patterns.
|
||||
- Suggest table/DP or memoization when repeated work is obvious.
|
||||
- Add tests that cover base cases + typical cases + worst-case size.
|
||||
31
SKILLS/26-vibe-coding-fundamentals.md
Normal file
@@ -0,0 +1,31 @@
|
||||
# Vibe Coding With Fundamentals (Safety Rails)
|
||||
|
||||
Use this skill when you're using "vibe coding" (fast, conversational building) but want production-grade outcomes.
|
||||
|
||||
## The good
|
||||
- Rapid scaffolding and iteration
|
||||
- Fast UI prototypes
|
||||
- Quick exploration of architectures and options
|
||||
|
||||
## The failure mode
|
||||
- "It works on my machine" code with weak tests
|
||||
- Security foot-guns (auth, input validation, secrets)
|
||||
- Performance cliffs (accidental O(n²), repeated I/O)
|
||||
- Unmaintainable abstractions
|
||||
|
||||
## Safety rails (apply every time)
|
||||
- Always start with acceptance criteria (what "done" means).
|
||||
- Prefer small PRs; never dump a huge AI diff.
|
||||
- Require DoD gates (lint/test/build) before merge.
|
||||
- Write tests for behavior changes.
|
||||
- For anything security/data related: do a Reviewer pass.
|
||||
|
||||
## When to slow down
|
||||
- Auth/session/token work
|
||||
- Anything touching payments, PII, secrets
|
||||
- Data migrations/schema changes
|
||||
- Performance-critical paths
|
||||
- "It's flaky" or "it only fails in CI"
|
||||
|
||||
## Practical prompt pattern (use in PLAN)
|
||||
- "State assumptions, list files to touch, propose tests, and include rollback steps."
|
||||
31
SKILLS/27-performance-profiling.md
Normal file
@@ -0,0 +1,31 @@
|
||||
# Performance Profiling (Bun/Node)
|
||||
|
||||
Use this skill when:
|
||||
- a hot path feels slow
|
||||
- CPU usage is high
|
||||
- you suspect accidental O(n²) or repeated work
|
||||
- you need evidence before optimizing
|
||||
|
||||
## Bun CPU profiling
|
||||
Bun supports CPU profiling via `--cpu-prof` (generates a `.cpuprofile` you can open in Chrome DevTools).
|
||||
|
||||
Upcoming: `bun --cpu-prof-md <script>` outputs a CPU profile as **Markdown** so LLMs can read/grep it easily.
|
||||
|
||||
### Workflow (Bun)
|
||||
1) Run the workload with profiling enabled
|
||||
- Today: `bun --cpu-prof ./path/to/script.ts`
|
||||
- Upcoming: `bun --cpu-prof-md ./path/to/script.ts`
|
||||
2) Save the output (or `.cpuprofile`) into `./profiles/` with a timestamp.
|
||||
3) Ask the Reviewer agent to:
|
||||
- identify the top 5 hottest functions
|
||||
- propose the smallest fix
|
||||
- add a regression test or benchmark
|
||||
|
||||
## Node CPU profiling (fallback)
|
||||
- `node --cpu-prof ./script.js` writes a `.cpuprofile` file.
|
||||
- Open in Chrome DevTools → Performance → Load profile.
|
||||
|
||||
## Rules
|
||||
- Optimize based on measured hotspots, not vibes.
|
||||
- Prefer algorithmic wins (remove repeated work) over micro-optimizations.
|
||||
- Keep profiling artifacts out of git unless explicitly needed (use `.gitignore`).
|
||||
16
SKILLS/30-implementation-rules.md
Normal file
@@ -0,0 +1,16 @@
|
||||
|
||||
# Implementation Rules
|
||||
|
||||
## Change policy
|
||||
- Prefer edits over rewrites.
|
||||
- Keep changes localized.
|
||||
- One change = one purpose.
|
||||
- Avoid unnecessary abstraction.
|
||||
|
||||
## Dependency policy
|
||||
- Default: do not add dependencies.
|
||||
- If adding: explain why, alternatives considered, and impact.
|
||||
|
||||
## Error handling
|
||||
- Validate inputs at boundaries.
|
||||
- Error messages must be actionable: what failed + what to do next.
|
||||
14
SKILLS/40-testing-quality.md
Normal file
@@ -0,0 +1,14 @@
|
||||
|
||||
# Testing & Quality
|
||||
|
||||
## Strategy
|
||||
- If behavior changes: add/update tests.
|
||||
- Unit tests for logic; integration tests for boundaries; E2E only where needed.
|
||||
|
||||
## Minimum for every PR
|
||||
- A test plan in the PR summary (even if "existing tests cover this").
|
||||
- Run DoD.
|
||||
|
||||
## Flaky tests
|
||||
- Capture repro steps.
|
||||
- Quarantine only with justification + follow-up issue.
|
||||
16
SKILLS/50-pr-review.md
Normal file
@@ -0,0 +1,16 @@
|
||||
|
||||
# PR Review Skill
|
||||
|
||||
Reviewer must check:
|
||||
- Correctness: does it do what it claims?
|
||||
- Safety: secrets, injection, auth boundaries
|
||||
- Maintainability: readability, naming, duplication
|
||||
- Tests: added/updated appropriately
|
||||
- DoD: did it pass?
|
||||
|
||||
Reviewer output format:
|
||||
1) Summary
|
||||
2) Must-fix
|
||||
3) Nice-to-have
|
||||
4) Risks
|
||||
5) Verification suggestions
|
||||
41
SKILLS/56-ui-material-ui.md
Normal file
@@ -0,0 +1,41 @@
|
||||
# Material UI (MUI) Design System
|
||||
|
||||
Use this skill for any React/Next "portal/admin/dashboard" UI so you stay consistent and avoid random component soup.
|
||||
|
||||
## Standard choice
|
||||
- Preferred UI library: **MUI (Material UI)**.
|
||||
- Prefer MUI components over ad-hoc HTML/CSS unless there's a good reason.
|
||||
- One design system per repo (do not mix Chakra/Ant/Bootstrap/etc.).
|
||||
|
||||
## Setup (Next.js/React)
|
||||
- Install: `@mui/material @emotion/react @emotion/styled`
|
||||
- If using icons: `@mui/icons-material`
|
||||
- If using data grid: `@mui/x-data-grid` (or pro if licensed)
|
||||
|
||||
## Theming rules
|
||||
- Define a single theme (typography, spacing, palette) and reuse everywhere.
|
||||
- Use semantic colors (primary/secondary/error/warning/success/info), not hard-coded hex everywhere.
|
||||
- Prefer MUI's `sx` for small styling; use `styled()` for reusable components.
|
||||
|
||||
## "Portal" patterns (modals, popovers, menus)
|
||||
- Use MUI Dialog/Modal/Popover/Menu components instead of DIY portals.
|
||||
- Accessibility requirements:
|
||||
- Focus is trapped in Dialog/Modal.
|
||||
- Escape closes modal unless explicitly prevented.
|
||||
- All inputs have labels; buttons have clear text/aria-labels.
|
||||
- Keyboard navigation works end-to-end.
|
||||
|
||||
## Layout conventions (for portals)
|
||||
- Use: AppBar + Drawer (or NavigationRail equivalent) + main content.
|
||||
- Keep pages as composition of small components: Page → Sections → Widgets.
|
||||
- Keep forms consistent: FormControl + helper text + validation messages.
|
||||
|
||||
## Performance hygiene
|
||||
- Avoid re-render storms: memoize heavy lists; use virtualization for large tables (DataGrid).
|
||||
- Prefer server pagination for huge datasets.
|
||||
|
||||
## PR review checklist
|
||||
- Theme is used (no random styling).
|
||||
- Components are MUI where reasonable.
|
||||
- Modal/popover accessibility is correct.
|
||||
- No mixed UI libraries.
|
||||
15
SKILLS/60-security-safety.md
Normal file
@@ -0,0 +1,15 @@
|
||||
|
||||
# Security & Safety
|
||||
|
||||
## Secrets
|
||||
- Never output secrets or tokens.
|
||||
- Never log sensitive inputs.
|
||||
- Never commit credentials.
|
||||
|
||||
## Inputs
|
||||
- Validate external inputs at boundaries.
|
||||
- Fail closed for auth/security decisions.
|
||||
|
||||
## Tooling
|
||||
- No destructive commands unless requested and scoped.
|
||||
- Prefer read-only operations first.
|
||||
13
SKILLS/70-docs-artifacts.md
Normal file
@@ -0,0 +1,13 @@
|
||||
|
||||
# Docs & Artifacts
|
||||
|
||||
Update documentation when:
|
||||
- setup steps change
|
||||
- env vars change
|
||||
- endpoints/CLI behavior changes
|
||||
- data formats change
|
||||
|
||||
Docs standards:
|
||||
- Provide copy/paste commands
|
||||
- Provide expected outputs where helpful
|
||||
- Keep it short and accurate
|
||||
11
SKILLS/80-mcp-tools.md
Normal file
@@ -0,0 +1,11 @@
|
||||
|
||||
# MCP Tools Skill (Optional)
|
||||
|
||||
If this repo defines MCP servers/tools:
|
||||
|
||||
Rules:
|
||||
- Tool calls must be explicit and logged.
|
||||
- Maintain an allowlist of tools; deny by default.
|
||||
- Every tool must have: purpose, inputs/outputs schema, examples, and tests.
|
||||
- Prefer idempotent tool operations.
|
||||
- Never add tools that can exfiltrate secrets without strict guards.
|
||||
51
SKILLS/82-mcp-server-design.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# MCP Server Design (Agent-First)
|
||||
|
||||
Build MCP servers like you're designing a UI for a non-human user.
|
||||
|
||||
This skill distills Phil Schmid's MCP server best practices into concrete repo rules.
|
||||
Source: "MCP is Not the Problem, It's your Server" (Jan 21, 2026).
|
||||
|
||||
## 1) Outcomes, not operations
|
||||
- Do **not** wrap REST endpoints 1:1 as tools.
|
||||
- Expose high-level, outcome-oriented tools.
|
||||
- Bad: `get_user`, `list_orders`, `get_order_status`
|
||||
- Good: `track_latest_order(email)` (server orchestrates internally)
|
||||
|
||||
## 2) Flatten arguments
|
||||
- Prefer top-level primitives + constrained enums.
|
||||
- Avoid nested `dict`/config objects (agents hallucinate keys).
|
||||
- Defaults reduce decision load.
|
||||
|
||||
## 3) Instructions are context
|
||||
- Tool docstrings are *instructions*:
|
||||
- when to use the tool
|
||||
- argument formatting rules
|
||||
- what the return means
|
||||
- Error strings are also context:
|
||||
- return actionable, self-correcting messages (not raw stack traces)
|
||||
|
||||
## 4) Curate ruthlessly
|
||||
- Aim for **5–15 tools** per server.
|
||||
- One server, one job. Split by persona if needed.
|
||||
- Delete unused tools. Don't dump raw data into context.
|
||||
|
||||
## 5) Name tools for discovery
|
||||
- Avoid generic names (`create_issue`).
|
||||
- Prefer `{service}_{action}_{resource}`:
|
||||
- `velociraptor_run_hunt`
|
||||
- `github_list_prs`
|
||||
- `slack_send_message`
|
||||
|
||||
## 6) Paginate large results
|
||||
- Always support `limit` (default ~20–50).
|
||||
- Return metadata: `has_more`, `next_offset`, `total_count`.
|
||||
- Never return hundreds of rows unbounded.
|
||||
|
||||
## Repo conventions
|
||||
- Put MCP tool specs in `mcp/` (schemas, examples, fixtures).
|
||||
- Provide at least 1 "golden path" example call per tool.
|
||||
- Add an eval that checks:
|
||||
- tool names follow discovery convention
|
||||
- args are flat + typed
|
||||
- responses are concise + stable
|
||||
- pagination works
|
||||
40
SKILLS/83-fastmcp-3-patterns.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# FastMCP 3 Patterns (Providers + Transforms)
|
||||
|
||||
Use this skill when you are building MCP servers in Python and want:
|
||||
- composable tool sets
|
||||
- per-user/per-session behavior
|
||||
- auth, versioning, observability, and long-running tasks
|
||||
|
||||
## Mental model (FastMCP 3)
|
||||
FastMCP 3 treats everything as three composable primitives:
|
||||
- **Components**: what you expose (tools, resources, prompts)
|
||||
- **Providers**: where components come from (decorators, files, OpenAPI, remote MCP, etc.)
|
||||
- **Transforms**: how you reshape what clients see (namespace, filters, auth, versioning, visibility)
|
||||
|
||||
## Recommended architecture for Marc's platform
|
||||
Build a **single "Cyber MCP Gateway"** that composes providers:
|
||||
- LocalProvider: core cyber tools (run hunt, parse triage, generate report)
|
||||
- OpenAPIProvider: wrap stable internal APIs (ticketing, asset DB) without 1:1 endpoint exposure
|
||||
- ProxyProvider/FastMCPProvider: mount sub-servers (e.g., Velociraptor tools, Intel feeds)
|
||||
|
||||
Then apply transforms:
|
||||
- Namespace per domain: `hunt.*`, `intel.*`, `pad.*`
|
||||
- Visibility per session: hide dangerous tools unless user/role allows
|
||||
- VersionFilter: keep old clients working while you evolve tools
|
||||
|
||||
## Production must-haves
|
||||
- **Tool timeouts**: never let a tool hang forever
|
||||
- **Pagination**: all list tools must be bounded
|
||||
- **Background tasks**: use for long hunts / ingest jobs
|
||||
- **Tracing**: emit OpenTelemetry traces so you can debug agent/tool behavior
|
||||
|
||||
## Auth rules
|
||||
- Prefer component-level auth for "dangerous" tools.
|
||||
- Default stance: read-only tools visible; write/execute tools gated.
|
||||
|
||||
## Versioning rules
|
||||
- Version your components when you change schemas or semantics.
|
||||
- Keep 1 previous version callable during migrations.
|
||||
|
||||
## Upgrade guidance
|
||||
FastMCP 3 is in beta; pin to v2 for stability in production until you've tested.
|
||||
149
backend/alembic.ini
Normal file
@@ -0,0 +1,149 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
script_location = %(here)s/alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
|
||||
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory. for multiple paths, the path separator
|
||||
# is defined by "path_separator" below.
|
||||
prepend_sys_path = .
|
||||
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the tzdata library which can be installed by adding
|
||||
# `alembic[tz]` to the pip requirements.
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "path_separator"
|
||||
# below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||
|
||||
# path_separator; This indicates what character is used to split lists of file
|
||||
# paths, including version_locations and prepend_sys_path within configparser
|
||||
# files such as alembic.ini.
|
||||
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||
# to provide os-dependent path splitting.
|
||||
#
|
||||
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||
# take place if path_separator is not present in alembic.ini. If this
|
||||
# option is omitted entirely, fallback logic is as follows:
|
||||
#
|
||||
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||
# behavior of splitting on spaces, commas, or colons.
|
||||
#
|
||||
# Valid values for path_separator are:
|
||||
#
|
||||
# path_separator = :
|
||||
# path_separator = ;
|
||||
# path_separator = space
|
||||
# path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
sqlalchemy.url = sqlite+aiosqlite:///./threathunt.db
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# hooks = ruff
|
||||
# ruff.type = module
|
||||
# ruff.module = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration. This is also consumed by the user-maintained
|
||||
# env.py script only.
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
1
backend/alembic/README
Normal file
@@ -0,0 +1 @@
|
||||
Generic single-database configuration.
|
||||
67
backend/alembic/env.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Alembic async env — autogenerate from app.db.models."""
|
||||
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
|
||||
from alembic import context
|
||||
|
||||
# Alembic Config object: provides access to the values in alembic.ini.
config = context.config

# Configure Python logging from the [loggers]/[handlers]/[formatters]
# sections of the ini file, when a config file is in use.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Import the declarative Base and the models module so that every table is
# registered on Base.metadata before autogenerate inspects it.
from app.db.engine import Base  # noqa: E402
from app.db import models as _models  # noqa: E402, F401

# Metadata that autogenerate diffs against the live database schema.
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
    """Emit migration SQL without a live DB connection ('offline' mode).

    Alembic renders the statements against the URL from alembic.ini and
    writes them out instead of executing them.
    """
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        # Batch mode is required so SQLite can emulate ALTER TABLE.
        render_as_batch=True,
    )
    with context.begin_transaction():
        context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection):
    """Configure the migration context on *connection* and run migrations.

    Invoked synchronously via ``AsyncConnection.run_sync`` from the async
    online path.
    """
    configure_opts = dict(
        connection=connection,
        target_metadata=target_metadata,
        # SQLite-compatible ALTER TABLE emulation.
        render_as_batch=True,
    )
    context.configure(**configure_opts)
    with context.begin_transaction():
        context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
    """Run migrations in 'online' mode with an async engine.

    Builds an AsyncEngine from the [alembic] ini section and applies the
    migrations over a single connection.
    """
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,  # migrations need no connection pooling
    )
    try:
        async with connectable.connect() as connection:
            await connection.run_sync(do_run_migrations)
    finally:
        # Dispose even when a migration raises, so the engine's resources
        # are released (previously the engine leaked on error).
        await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
    """Entry point for 'online' mode: drive the async runner to completion."""
    asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
# Dispatch on the invocation mode chosen by the alembic command line.
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
|
||||
28
backend/alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
210
backend/alembic/versions/9790f482da06_initial_schema.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""initial schema
|
||||
|
||||
Revision ID: 9790f482da06
|
||||
Revises:
|
||||
Create Date: 2026-02-19 11:40:02.108830
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
revision: str = '9790f482da06'
down_revision: Union[str, Sequence[str], None] = None  # root of the migration history
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Upgrade schema: create the initial ThreatHunt tables.

    Creation order respects foreign-key dependencies: users -> hunts ->
    datasets -> (hypotheses, conversations, dataset_rows,
    enrichment_results, annotations, messages). Indexes are created via
    batch_alter_table so the migration also works on SQLite.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Application accounts; email is unique via constraint, username via
    # the unique index below.
    op.create_table('users',
    sa.Column('id', sa.String(length=32), nullable=False),
    sa.Column('username', sa.String(length=64), nullable=False),
    sa.Column('email', sa.String(length=256), nullable=False),
    sa.Column('hashed_password', sa.String(length=256), nullable=False),
    sa.Column('role', sa.String(length=16), nullable=False),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
    sa.PrimaryKeyConstraint('id'),
    sa.UniqueConstraint('email')
    )
    with op.batch_alter_table('users', schema=None) as batch_op:
        batch_op.create_index(batch_op.f('ix_users_username'), ['username'], unique=True)

    # Threat-hunting engagements, optionally owned by a user.
    op.create_table('hunts',
    sa.Column('id', sa.String(length=32), nullable=False),
    sa.Column('name', sa.String(length=256), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('status', sa.String(length=32), nullable=False),
    sa.Column('owner_id', sa.String(length=32), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
    sa.ForeignKeyConstraint(['owner_id'], ['users.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    # Uploaded data files plus parsed metadata (schema, IOC columns,
    # time range).
    op.create_table('datasets',
    sa.Column('id', sa.String(length=32), nullable=False),
    sa.Column('name', sa.String(length=256), nullable=False),
    sa.Column('filename', sa.String(length=512), nullable=False),
    sa.Column('source_tool', sa.String(length=64), nullable=True),
    sa.Column('row_count', sa.Integer(), nullable=False),
    sa.Column('column_schema', sa.JSON(), nullable=True),
    sa.Column('normalized_columns', sa.JSON(), nullable=True),
    sa.Column('ioc_columns', sa.JSON(), nullable=True),
    sa.Column('file_size_bytes', sa.Integer(), nullable=False),
    sa.Column('encoding', sa.String(length=32), nullable=True),
    sa.Column('delimiter', sa.String(length=4), nullable=True),
    sa.Column('time_range_start', sa.DateTime(timezone=True), nullable=True),
    sa.Column('time_range_end', sa.DateTime(timezone=True), nullable=True),
    sa.Column('hunt_id', sa.String(length=32), nullable=True),
    sa.Column('uploaded_by', sa.String(length=32), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
    sa.ForeignKeyConstraint(['hunt_id'], ['hunts.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.create_index('ix_datasets_hunt', ['hunt_id'], unique=False)
        batch_op.create_index(batch_op.f('ix_datasets_name'), ['name'], unique=False)

    # Analyst hypotheses tied to a hunt, with evidence references.
    op.create_table('hypotheses',
    sa.Column('id', sa.String(length=32), nullable=False),
    sa.Column('hunt_id', sa.String(length=32), nullable=True),
    sa.Column('title', sa.String(length=256), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('mitre_technique', sa.String(length=32), nullable=True),
    sa.Column('status', sa.String(length=16), nullable=False),
    sa.Column('evidence_row_ids', sa.JSON(), nullable=True),
    sa.Column('evidence_notes', sa.Text(), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
    sa.ForeignKeyConstraint(['hunt_id'], ['hunts.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('hypotheses', schema=None) as batch_op:
        batch_op.create_index('ix_hypotheses_hunt', ['hunt_id'], unique=False)

    # Chat conversations, optionally scoped to a hunt and/or dataset.
    op.create_table('conversations',
    sa.Column('id', sa.String(length=32), nullable=False),
    sa.Column('title', sa.String(length=256), nullable=True),
    sa.Column('hunt_id', sa.String(length=32), nullable=True),
    sa.Column('dataset_id', sa.String(length=32), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
    sa.ForeignKeyConstraint(['dataset_id'], ['datasets.id'], ),
    sa.ForeignKeyConstraint(['hunt_id'], ['hunts.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    # Per-row dataset contents; rows are removed with their dataset
    # (ondelete CASCADE). Note: integer autoincrement PK, unlike the
    # string IDs used elsewhere.
    op.create_table('dataset_rows',
    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
    sa.Column('dataset_id', sa.String(length=32), nullable=False),
    sa.Column('row_index', sa.Integer(), nullable=False),
    sa.Column('data', sa.JSON(), nullable=False),
    sa.Column('normalized_data', sa.JSON(), nullable=True),
    sa.ForeignKeyConstraint(['dataset_id'], ['datasets.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('dataset_rows', schema=None) as batch_op:
        batch_op.create_index('ix_dataset_rows_dataset', ['dataset_id'], unique=False)
        batch_op.create_index('ix_dataset_rows_dataset_idx', ['dataset_id', 'row_index'], unique=False)

    # Cached IOC enrichment lookups, keyed by (ioc_value, source) via index.
    op.create_table('enrichment_results',
    sa.Column('id', sa.String(length=32), nullable=False),
    sa.Column('ioc_value', sa.String(length=512), nullable=False),
    sa.Column('ioc_type', sa.String(length=32), nullable=False),
    sa.Column('source', sa.String(length=32), nullable=False),
    sa.Column('verdict', sa.String(length=16), nullable=True),
    sa.Column('confidence', sa.Float(), nullable=True),
    sa.Column('raw_result', sa.JSON(), nullable=True),
    sa.Column('summary', sa.Text(), nullable=True),
    sa.Column('dataset_id', sa.String(length=32), nullable=True),
    sa.Column('cached_at', sa.DateTime(timezone=True), nullable=False),
    sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True),
    sa.ForeignKeyConstraint(['dataset_id'], ['datasets.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('enrichment_results', schema=None) as batch_op:
        batch_op.create_index('ix_enrichment_ioc_source', ['ioc_value', 'source'], unique=False)
        batch_op.create_index(batch_op.f('ix_enrichment_results_ioc_value'), ['ioc_value'], unique=False)

    # Analyst notes on rows/datasets; row FK is nulled when the row goes.
    op.create_table('annotations',
    sa.Column('id', sa.String(length=32), nullable=False),
    sa.Column('row_id', sa.Integer(), nullable=True),
    sa.Column('dataset_id', sa.String(length=32), nullable=True),
    sa.Column('author_id', sa.String(length=32), nullable=True),
    sa.Column('text', sa.Text(), nullable=False),
    sa.Column('severity', sa.String(length=16), nullable=False),
    sa.Column('tag', sa.String(length=32), nullable=True),
    sa.Column('highlight_color', sa.String(length=16), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
    sa.ForeignKeyConstraint(['author_id'], ['users.id'], ),
    sa.ForeignKeyConstraint(['dataset_id'], ['datasets.id'], ),
    sa.ForeignKeyConstraint(['row_id'], ['dataset_rows.id'], ondelete='SET NULL'),
    sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('annotations', schema=None) as batch_op:
        batch_op.create_index('ix_annotations_dataset', ['dataset_id'], unique=False)
        batch_op.create_index('ix_annotations_row', ['row_id'], unique=False)

    # Chat messages with LLM telemetry; removed with their conversation.
    op.create_table('messages',
    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
    sa.Column('conversation_id', sa.String(length=32), nullable=False),
    sa.Column('role', sa.String(length=16), nullable=False),
    sa.Column('content', sa.Text(), nullable=False),
    sa.Column('model_used', sa.String(length=128), nullable=True),
    sa.Column('node_used', sa.String(length=64), nullable=True),
    sa.Column('token_count', sa.Integer(), nullable=True),
    sa.Column('latency_ms', sa.Integer(), nullable=True),
    sa.Column('response_meta', sa.JSON(), nullable=True),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
    sa.ForeignKeyConstraint(['conversation_id'], ['conversations.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('messages', schema=None) as batch_op:
        batch_op.create_index('ix_messages_conversation', ['conversation_id'], unique=False)

    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Downgrade schema: drop all initial tables.

    Tables are dropped in reverse dependency order (children before the
    tables they reference); explicit indexes are dropped first via
    batch_alter_table for SQLite compatibility.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('messages', schema=None) as batch_op:
        batch_op.drop_index('ix_messages_conversation')

    op.drop_table('messages')
    with op.batch_alter_table('annotations', schema=None) as batch_op:
        batch_op.drop_index('ix_annotations_row')
        batch_op.drop_index('ix_annotations_dataset')

    op.drop_table('annotations')
    with op.batch_alter_table('enrichment_results', schema=None) as batch_op:
        batch_op.drop_index(batch_op.f('ix_enrichment_results_ioc_value'))
        batch_op.drop_index('ix_enrichment_ioc_source')

    op.drop_table('enrichment_results')
    with op.batch_alter_table('dataset_rows', schema=None) as batch_op:
        batch_op.drop_index('ix_dataset_rows_dataset_idx')
        batch_op.drop_index('ix_dataset_rows_dataset')

    op.drop_table('dataset_rows')
    op.drop_table('conversations')
    with op.batch_alter_table('hypotheses', schema=None) as batch_op:
        batch_op.drop_index('ix_hypotheses_hunt')

    op.drop_table('hypotheses')
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.drop_index(batch_op.f('ix_datasets_name'))
        batch_op.drop_index('ix_datasets_hunt')

    op.drop_table('datasets')
    op.drop_table('hunts')
    with op.batch_alter_table('users', schema=None) as batch_op:
        batch_op.drop_index(batch_op.f('ix_users_username'))

    op.drop_table('users')
    # ### end Alembic commands ###
|
||||
@@ -0,0 +1,64 @@
|
||||
"""add_keyword_themes_and_keywords_tables
|
||||
|
||||
Revision ID: 98ab619418bc
|
||||
Revises: 9790f482da06
|
||||
Create Date: 2026-02-19 12:01:38.174653
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
revision: str = '98ab619418bc'
down_revision: Union[str, Sequence[str], None] = '9790f482da06'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Upgrade schema: add keyword_themes and keywords tables.

    A theme is a named, colour-coded group; keywords belong to exactly
    one theme and are removed with it (ondelete CASCADE).
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Theme metadata; names are unique via the index below.
    op.create_table('keyword_themes',
    sa.Column('id', sa.String(length=32), nullable=False),
    sa.Column('name', sa.String(length=128), nullable=False),
    sa.Column('color', sa.String(length=16), nullable=False),
    sa.Column('enabled', sa.Boolean(), nullable=False),
    sa.Column('is_builtin', sa.Boolean(), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
    sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('keyword_themes', schema=None) as batch_op:
        batch_op.create_index(batch_op.f('ix_keyword_themes_name'), ['name'], unique=True)

    # Individual keyword values; is_regex flags pattern vs literal match.
    op.create_table('keywords',
    sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
    sa.Column('theme_id', sa.String(length=32), nullable=False),
    sa.Column('value', sa.String(length=256), nullable=False),
    sa.Column('is_regex', sa.Boolean(), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
    sa.ForeignKeyConstraint(['theme_id'], ['keyword_themes.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('keywords', schema=None) as batch_op:
        batch_op.create_index('ix_keywords_theme', ['theme_id'], unique=False)
        batch_op.create_index('ix_keywords_value', ['value'], unique=False)

    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Downgrade schema: drop keywords before keyword_themes (FK order)."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('keywords', schema=None) as batch_op:
        batch_op.drop_index('ix_keywords_value')
        batch_op.drop_index('ix_keywords_theme')

    op.drop_table('keywords')
    with op.batch_alter_table('keyword_themes', schema=None) as batch_op:
        batch_op.drop_index(batch_op.f('ix_keyword_themes_name'))

    op.drop_table('keyword_themes')
    # ### end Alembic commands ###
|
||||
@@ -0,0 +1,112 @@
|
||||
"""add processing_status and AI analysis tables
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: 98ab619418bc
|
||||
Create Date: 2026-02-19 18:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
revision: str = "a1b2c3d4e5f6"
# NOTE(review): migration a3b1c2d4e5f6 also declares 98ab619418bc as its
# down_revision, which leaves the history with two heads. One of the two
# should be rebased (or an `alembic merge` revision added) before deploy —
# confirm with the migration history.
down_revision: Union[str, Sequence[str], None] = "98ab619418bc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Upgrade schema: add processing_status columns and AI-analysis tables.

    - Extends `datasets` with upload-processing state and file tracking.
    - Adds triage_results, host_profiles, hunt_reports and anomaly_results,
      all cascading from their parent dataset/hunt.
    """
    # Add columns to datasets table (batch mode for SQLite ALTER TABLE).
    with op.batch_alter_table("datasets") as batch_op:
        batch_op.add_column(sa.Column("processing_status", sa.String(20), server_default="ready"))
        batch_op.add_column(sa.Column("artifact_type", sa.String(128), nullable=True))
        batch_op.add_column(sa.Column("error_message", sa.Text(), nullable=True))
        batch_op.add_column(sa.Column("file_path", sa.String(512), nullable=True))
        batch_op.create_index("ix_datasets_status", ["processing_status"])

    # Create triage_results table: per-chunk LLM triage verdicts.
    op.create_table(
        "triage_results",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("row_start", sa.Integer(), nullable=False),
        sa.Column("row_end", sa.Integer(), nullable=False),
        sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column("verdict", sa.String(20), nullable=False, server_default="pending"),
        sa.Column("findings", sa.JSON(), nullable=True),
        sa.Column("suspicious_indicators", sa.JSON(), nullable=True),
        sa.Column("mitre_techniques", sa.JSON(), nullable=True),
        sa.Column("model_used", sa.String(128), nullable=True),
        sa.Column("node_used", sa.String(64), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # Create host_profiles table: per-host risk rollups within a hunt.
    op.create_table(
        "host_profiles",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("hostname", sa.String(256), nullable=False),
        sa.Column("fqdn", sa.String(512), nullable=True),
        sa.Column("client_id", sa.String(64), nullable=True),
        sa.Column("risk_score", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column("risk_level", sa.String(20), nullable=False, server_default="unknown"),
        sa.Column("artifact_summary", sa.JSON(), nullable=True),
        sa.Column("timeline_summary", sa.Text(), nullable=True),
        sa.Column("suspicious_findings", sa.JSON(), nullable=True),
        sa.Column("mitre_techniques", sa.JSON(), nullable=True),
        sa.Column("llm_analysis", sa.Text(), nullable=True),
        sa.Column("model_used", sa.String(128), nullable=True),
        sa.Column("node_used", sa.String(64), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # Create hunt_reports table: generated report artifacts per hunt.
    op.create_table(
        "hunt_reports",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=False, index=True),
        sa.Column("status", sa.String(20), nullable=False, server_default="pending"),
        sa.Column("exec_summary", sa.Text(), nullable=True),
        sa.Column("full_report", sa.Text(), nullable=True),
        sa.Column("findings", sa.JSON(), nullable=True),
        sa.Column("recommendations", sa.JSON(), nullable=True),
        sa.Column("mitre_mapping", sa.JSON(), nullable=True),
        sa.Column("ioc_table", sa.JSON(), nullable=True),
        sa.Column("host_risk_summary", sa.JSON(), nullable=True),
        sa.Column("models_used", sa.JSON(), nullable=True),
        sa.Column("generation_time_ms", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )

    # Create anomaly_results table: per-row outlier scores.
    op.create_table(
        "anomaly_results",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False, index=True),
        # FIX: dataset_rows.id is an autoincrement Integer PK (see the
        # initial-schema migration), so the FK column must be Integer —
        # the previous String(32) type mismatched the referenced column.
        sa.Column("row_id", sa.Integer(), sa.ForeignKey("dataset_rows.id", ondelete="CASCADE"), nullable=True),
        sa.Column("anomaly_score", sa.Float(), nullable=False, server_default="0.0"),
        sa.Column("distance_from_centroid", sa.Float(), nullable=True),
        sa.Column("cluster_id", sa.Integer(), nullable=True),
        sa.Column("is_outlier", sa.Boolean(), nullable=False, server_default="0"),
        sa.Column("explanation", sa.Text(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Downgrade schema: drop the AI-analysis tables and the new
    datasets columns (index first, then columns, in batch mode)."""
    op.drop_table("anomaly_results")
    op.drop_table("hunt_reports")
    op.drop_table("host_profiles")
    op.drop_table("triage_results")

    with op.batch_alter_table("datasets") as batch_op:
        batch_op.drop_index("ix_datasets_status")
        batch_op.drop_column("file_path")
        batch_op.drop_column("error_message")
        batch_op.drop_column("artifact_type")
        batch_op.drop_column("processing_status")
||||
@@ -0,0 +1,72 @@
|
||||
"""add cases and activity logs
|
||||
|
||||
Revision ID: a3b1c2d4e5f6
|
||||
Revises: 98ab619418bc
|
||||
Create Date: 2025-01-01 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
revision: str = "a3b1c2d4e5f6"
# NOTE(review): migration a1b2c3d4e5f6 also declares 98ab619418bc as its
# down_revision, producing two migration heads — rebase one revision or
# add an `alembic merge` revision before deploying; confirm with history.
down_revision: Union[str, Sequence[str], None] = "98ab619418bc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
||||
|
||||
|
||||
def upgrade() -> None:
    """Upgrade schema: add case management and activity-log tables.

    - cases: incident cases with TLP/PAP markings, linked to hunts/users.
    - case_tasks: ordered checklist items, removed with their case.
    - activity_logs: generic audit trail keyed by (entity_type, entity_id).
    """
    op.create_table(
        "cases",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("title", sa.String(512), nullable=False),
        sa.Column("description", sa.Text, nullable=True),
        sa.Column("severity", sa.String(16), server_default="medium"),
        # TLP/PAP sharing markings; default to amber.
        sa.Column("tlp", sa.String(16), server_default="amber"),
        sa.Column("pap", sa.String(16), server_default="amber"),
        sa.Column("status", sa.String(24), server_default="open"),
        sa.Column("priority", sa.Integer, server_default="2"),
        sa.Column("assignee", sa.String(128), nullable=True),
        sa.Column("tags", sa.JSON, nullable=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
        sa.Column("owner_id", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
        sa.Column("mitre_techniques", sa.JSON, nullable=True),
        sa.Column("iocs", sa.JSON, nullable=True),
        sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True),
        # No server defaults: timestamps are supplied by the application.
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_index("ix_cases_hunt", "cases", ["hunt_id"])
    op.create_index("ix_cases_status", "cases", ["status"])

    op.create_table(
        "case_tasks",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("case_id", sa.String(32), sa.ForeignKey("cases.id", ondelete="CASCADE"), nullable=False),
        sa.Column("title", sa.String(512), nullable=False),
        sa.Column("description", sa.Text, nullable=True),
        sa.Column("status", sa.String(24), server_default="todo"),
        sa.Column("assignee", sa.String(128), nullable=True),
        # Display ordering within the case's task list.
        sa.Column("order", sa.Integer, server_default="0"),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_index("ix_case_tasks_case", "case_tasks", ["case_id"])

    op.create_table(
        "activity_logs",
        sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
        # Polymorphic reference: no FK, any entity type/id pair.
        sa.Column("entity_type", sa.String(32), nullable=False),
        sa.Column("entity_id", sa.String(32), nullable=False),
        sa.Column("action", sa.String(64), nullable=False),
        sa.Column("details", sa.JSON, nullable=True),
        sa.Column("user_id", sa.String(32), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_index("ix_activity_entity", "activity_logs", ["entity_type", "entity_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the case-management tables created by this revision.

    Children first: case_tasks and activity_logs reference cases.
    """
    for table in ("activity_logs", "case_tasks", "cases"):
        op.drop_table(table)
|
||||
@@ -0,0 +1,63 @@
|
||||
"""add alerts and alert_rules tables
|
||||
|
||||
Revision ID: b4c2d3e5f6a7
|
||||
Revises: a3b1c2d4e5f6
|
||||
Create Date: 2025-01-01 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers
|
||||
revision: str = "b4c2d3e5f6a7"
|
||||
down_revision: Union[str, None] = "a3b1c2d4e5f6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the alerts and alert_rules tables plus their lookup indexes."""
    op.create_table(
        "alerts",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("title", sa.String(512), nullable=False),
        sa.Column("description", sa.Text, nullable=True),
        sa.Column("severity", sa.String(16), server_default="medium"),
        sa.Column("status", sa.String(24), server_default="new"),
        # Which analyzer produced the alert (required for triage routing).
        sa.Column("analyzer", sa.String(64), nullable=False),
        sa.Column("score", sa.Float, server_default="0"),
        sa.Column("evidence", sa.JSON, nullable=True),
        sa.Column("mitre_technique", sa.String(32), nullable=True),
        sa.Column("tags", sa.JSON, nullable=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
        sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id"), nullable=True),
        sa.Column("case_id", sa.String(32), sa.ForeignKey("cases.id"), nullable=True),
        sa.Column("assignee", sa.String(128), nullable=True),
        sa.Column("acknowledged_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
    # Alerts are filtered by severity/status and joined from hunts/datasets.
    op.create_index("ix_alerts_severity", "alerts", ["severity"])
    op.create_index("ix_alerts_status", "alerts", ["status"])
    op.create_index("ix_alerts_hunt", "alerts", ["hunt_id"])
    op.create_index("ix_alerts_dataset", "alerts", ["dataset_id"])

    op.create_table(
        "alert_rules",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("name", sa.String(256), nullable=False),
        sa.Column("description", sa.Text, nullable=True),
        sa.Column("analyzer", sa.String(64), nullable=False),
        sa.Column("config", sa.JSON, nullable=True),
        sa.Column("severity_override", sa.String(16), nullable=True),
        # sa.true() renders a dialect-appropriate boolean default.  The
        # previous sa.text("1") works on SQLite but is rejected by
        # PostgreSQL (integer default on a BOOLEAN column), and this
        # project supports both backends.
        sa.Column("enabled", sa.Boolean, server_default=sa.true()),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
    )
    op.create_index("ix_alert_rules_analyzer", "alert_rules", ["analyzer"])
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the alerting tables (alert_rules first — it was created last)."""
    for table in ("alert_rules", "alerts"):
        op.drop_table(table)
|
||||
@@ -0,0 +1,54 @@
|
||||
"""add notebooks and playbook_runs tables
|
||||
|
||||
Revision ID: c5d3e4f6a7b8
|
||||
Revises: b4c2d3e5f6a7
|
||||
Create Date: 2025-01-01 00:00:00.000000
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "c5d3e4f6a7b8"
|
||||
down_revision: Union[str, None] = "b4c2d3e5f6a7"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the notebooks and playbook_runs tables with their indexes."""

    def _audit_columns() -> list:
        # Fresh Column objects on every call: a SQLAlchemy Column may be
        # attached to only one table.
        return [
            sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
            sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
        ]

    op.create_table(
        "notebooks",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("title", sa.String(512), nullable=False),
        sa.Column("description", sa.Text, nullable=True),
        sa.Column("cells", sa.JSON, nullable=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
        sa.Column("case_id", sa.String(32), sa.ForeignKey("cases.id"), nullable=True),
        sa.Column("owner_id", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
        sa.Column("tags", sa.JSON, nullable=True),
        *_audit_columns(),
    )
    op.create_index("ix_notebooks_hunt", "notebooks", ["hunt_id"])

    op.create_table(
        "playbook_runs",
        sa.Column("id", sa.String(32), primary_key=True),
        sa.Column("playbook_name", sa.String(256), nullable=False),
        sa.Column("status", sa.String(24), server_default="in-progress"),
        sa.Column("current_step", sa.Integer, server_default="1"),
        sa.Column("total_steps", sa.Integer, server_default="0"),
        sa.Column("step_results", sa.JSON, nullable=True),
        sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
        sa.Column("case_id", sa.String(32), sa.ForeignKey("cases.id"), nullable=True),
        sa.Column("started_by", sa.String(128), nullable=True),
        *_audit_columns(),
        sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
    )
    op.create_index("ix_playbook_runs_hunt", "playbook_runs", ["hunt_id"])
    op.create_index("ix_playbook_runs_status", "playbook_runs", ["status"])
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the tables added by this revision, newest first."""
    for table in ("playbook_runs", "notebooks"):
        op.drop_table(table)
|
||||
1
backend/app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Backend initialization."""
|
||||
67
backend/app/agent/debate.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import asyncio
|
||||
|
||||
async def debated_generate(provider, prompt: str) -> str:
    """Run a minimal three-role advisory debate and return the merged answer.

    Planner, Critic, and Pragmatist prompts are answered concurrently by
    *provider*, then a Judge pass folds the three responses into a single
    advisory answer.  Purely advisory — nothing is executed.
    """

    planner = f"""
You are the Planner.
Give structured advisory guidance only.
No execution. No tools.

Request:
{prompt}
"""

    critic = f"""
You are the Critic.
Identify risks, missing steps, and assumptions.
No execution. No tools.

Request:
{prompt}
"""

    pragmatist = f"""
You are the Pragmatist.
Suggest the safest and simplest approach.
No execution. No tools.

Request:
{prompt}
"""

    # All three perspectives run concurrently against the same provider.
    planner_resp, critic_resp, prag_resp = await asyncio.gather(
        provider.generate(planner),
        provider.generate(critic),
        provider.generate(pragmatist),
    )

    judge = f"""
You are the Judge.

Merge the three responses into ONE final advisory answer.

Rules:
- Advisory only
- No execution
- Clearly list risks and assumptions
- Be concise

Planner:
{planner_resp}

Critic:
{critic_resp}

Pragmatist:
{prag_resp}
"""

    return await provider.generate(judge)
|
||||
16
backend/app/agents/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Analyst-assist agent module for ThreatHunt.
|
||||
|
||||
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
|
||||
Agents are advisory only and do not execute actions or modify data.
|
||||
"""
|
||||
|
||||
from .core import ThreatHuntAgent
|
||||
from .providers import LLMProvider, LocalProvider, NetworkedProvider, OnlineProvider
|
||||
|
||||
__all__ = [
|
||||
"ThreatHuntAgent",
|
||||
"LLMProvider",
|
||||
"LocalProvider",
|
||||
"NetworkedProvider",
|
||||
"OnlineProvider",
|
||||
]
|
||||
59
backend/app/agents/config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Configuration for agent settings."""
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class AgentConfig:
|
||||
"""Configuration for analyst-assist agents."""
|
||||
|
||||
# Provider type: 'local', 'networked', 'online', or 'auto'
|
||||
PROVIDER_TYPE: Literal["local", "networked", "online", "auto"] = os.getenv(
|
||||
"THREAT_HUNT_AGENT_PROVIDER", "auto"
|
||||
)
|
||||
|
||||
# Local provider settings
|
||||
LOCAL_MODEL_PATH: str | None = os.getenv("THREAT_HUNT_LOCAL_MODEL_PATH")
|
||||
|
||||
# Networked provider settings
|
||||
NETWORKED_ENDPOINT: str | None = os.getenv("THREAT_HUNT_NETWORKED_ENDPOINT")
|
||||
NETWORKED_API_KEY: str | None = os.getenv("THREAT_HUNT_NETWORKED_KEY")
|
||||
|
||||
# Online provider settings
|
||||
ONLINE_API_PROVIDER: str = os.getenv("THREAT_HUNT_ONLINE_PROVIDER", "openai")
|
||||
ONLINE_API_KEY: str | None = os.getenv("THREAT_HUNT_ONLINE_API_KEY")
|
||||
ONLINE_MODEL: str | None = os.getenv("THREAT_HUNT_ONLINE_MODEL")
|
||||
|
||||
# Agent behavior settings
|
||||
MAX_RESPONSE_TOKENS: int = int(
|
||||
os.getenv("THREAT_HUNT_AGENT_MAX_TOKENS", "1024")
|
||||
)
|
||||
ENABLE_REASONING: bool = os.getenv(
|
||||
"THREAT_HUNT_AGENT_REASONING", "true"
|
||||
).lower() in ("true", "1", "yes")
|
||||
CONVERSATION_HISTORY_LENGTH: int = int(
|
||||
os.getenv("THREAT_HUNT_AGENT_HISTORY_LENGTH", "10")
|
||||
)
|
||||
|
||||
# Privacy settings
|
||||
FILTER_SENSITIVE_DATA: bool = os.getenv(
|
||||
"THREAT_HUNT_AGENT_FILTER_SENSITIVE", "true"
|
||||
).lower() in ("true", "1", "yes")
|
||||
|
||||
@classmethod
|
||||
def is_agent_enabled(cls) -> bool:
|
||||
"""Check if agent is enabled and properly configured."""
|
||||
# Agent is disabled if no provider can be used
|
||||
if cls.PROVIDER_TYPE == "auto":
|
||||
return bool(
|
||||
cls.LOCAL_MODEL_PATH
|
||||
or cls.NETWORKED_ENDPOINT
|
||||
or cls.ONLINE_API_KEY
|
||||
)
|
||||
elif cls.PROVIDER_TYPE == "local":
|
||||
return bool(cls.LOCAL_MODEL_PATH)
|
||||
elif cls.PROVIDER_TYPE == "networked":
|
||||
return bool(cls.NETWORKED_ENDPOINT)
|
||||
elif cls.PROVIDER_TYPE == "online":
|
||||
return bool(cls.ONLINE_API_KEY)
|
||||
return False
|
||||
208
backend/app/agents/core.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Core ThreatHunt analyst-assist agent.
|
||||
|
||||
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
|
||||
Agents are advisory only - no execution, no alerts, no data modifications.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .providers import LLMProvider, get_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentContext(BaseModel):
    """Context for agent guidance requests.

    Bundles the analyst's question with whatever investigation context is
    known (dataset, artifact type, host, free-text data summary, recent
    conversation turns) so the agent can tailor its guidance.
    """

    # The analyst's free-text question — the only required field.
    query: str = Field(
        ..., description="Analyst question or request for guidance"
    )
    dataset_name: Optional[str] = Field(None, description="Name of CSV dataset")
    artifact_type: Optional[str] = Field(None, description="Artifact type (e.g., file, process, network)")
    host_identifier: Optional[str] = Field(
        None, description="Host name, IP, or identifier"
    )
    data_summary: Optional[str] = Field(
        None, description="Brief description of uploaded data"
    )
    # NOTE(review): annotated Optional, but default_factory=list means the
    # value is [] unless a caller explicitly passes None.
    conversation_history: Optional[list[dict]] = Field(
        default_factory=list, description="Previous messages in conversation"
    )
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
    """Response from the analyst-assist agent.

    All fields are advisory output for the analyst; nothing here triggers
    any action on its own.
    """

    guidance: str = Field(..., description="Advisory guidance for analyst")
    # Validated into [0, 1] by pydantic (ge/le constraints).
    confidence: float = Field(
        ..., ge=0.0, le=1.0, description="Confidence in guidance (0-1)"
    )
    suggested_pivots: list[str] = Field(
        default_factory=list, description="Suggested analytical directions"
    )
    suggested_filters: list[str] = Field(
        default_factory=list, description="Suggested data filters or queries"
    )
    caveats: Optional[str] = Field(
        None, description="Assumptions, limitations, or caveats"
    )
    reasoning: Optional[str] = Field(
        None, description="Explanation of how guidance was generated"
    )
|
||||
|
||||
|
||||
class ThreatHuntAgent:
    """Analyst-assist agent for ThreatHunt.

    Provides guidance on:
    - Interpreting CSV artifact data
    - Suggesting analytical pivots and filters
    - Forming and testing hypotheses

    Policy:
    - Advisory guidance only (no execution)
    - No database or schema changes
    - No alert escalation
    - Transparent reasoning
    """

    def __init__(self, provider: Optional[LLMProvider] = None):
        """Initialize agent with LLM provider.

        Args:
            provider: LLM provider instance. If None, uses get_provider() with auto mode.
        """
        if provider is None:
            try:
                provider = get_provider("auto")
            except RuntimeError as e:
                # Degrade gracefully: the agent stays constructible, but
                # assist() will raise until a provider is configured.
                # Lazy %-style args avoid formatting when the level is off.
                logger.warning("Could not initialize default provider: %s", e)
                provider = None

        self.provider = provider
        self.system_prompt = self._build_system_prompt()

    def _build_system_prompt(self) -> str:
        """Build the system prompt that governs agent behavior."""
        return """You are an analyst-assist agent for ThreatHunt, a threat hunting platform.

Your role:
- Interpret and explain CSV artifact data from Velociraptor
- Suggest analytical pivots, filters, and hypotheses
- Highlight anomalies, patterns, or points of interest
- Guide analysts without replacing their judgment

Your constraints:
- You ONLY provide guidance and suggestions
- You do NOT execute actions or tools
- You do NOT modify data or escalate alerts
- You do NOT make autonomous decisions
- You ONLY analyze data presented to you
- You explain your reasoning transparently
- You acknowledge limitations and assumptions
- You suggest next investigative steps

When responding:
1. Start with a clear, direct answer to the query
2. Explain your reasoning based on the data context provided
3. Suggest 2-4 analytical pivots the analyst might explore
4. Suggest 2-4 data filters or queries that might be useful
5. Include relevant caveats or assumptions
6. Be honest about what you cannot determine from the data

Remember: The analyst is the decision-maker. You are an assistant."""

    async def assist(self, context: AgentContext) -> AgentResponse:
        """Provide guidance on artifact data and analysis.

        Args:
            context: Request context including query and data context.

        Returns:
            Guidance response with suggestions and reasoning.

        Raises:
            RuntimeError: If no provider is available.
        """
        if not self.provider:
            raise RuntimeError(
                "No LLM provider available. Configure at least one of: "
                "THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
                "or THREAT_HUNT_ONLINE_API_KEY"
            )

        # Build prompt with context
        prompt = self._build_prompt(context)

        try:
            # Get guidance from LLM provider
            guidance = await self.provider.generate(prompt, max_tokens=1024)

            # Parse response into structured format
            response = self._parse_response(guidance, context)

            # Lazy %-args: slicing/formatting only happens if INFO is enabled.
            logger.info(
                "Agent assisted with query: %s... (dataset: %s)",
                context.query[:50],
                context.dataset_name,
            )

            return response

        except Exception:
            # logger.exception records the full traceback before re-raising.
            logger.exception("Error generating guidance")
            raise

    def _build_prompt(self, context: AgentContext) -> str:
        """Build the prompt for the LLM from the optional context fields."""
        prompt_parts = [
            f"Analyst query: {context.query}",
        ]

        if context.dataset_name:
            prompt_parts.append(f"Dataset: {context.dataset_name}")

        if context.artifact_type:
            prompt_parts.append(f"Artifact type: {context.artifact_type}")

        if context.host_identifier:
            prompt_parts.append(f"Host: {context.host_identifier}")

        if context.data_summary:
            prompt_parts.append(f"Data summary: {context.data_summary}")

        if context.conversation_history:
            prompt_parts.append("\nConversation history:")
            for msg in context.conversation_history[-5:]:  # Last 5 messages for context
                prompt_parts.append(f"  {msg.get('role', 'unknown')}: {msg.get('content', '')}")

        return "\n".join(prompt_parts)

    def _parse_response(self, response_text: str, context: AgentContext) -> AgentResponse:
        """Parse LLM response into structured format.

        Note: This is a simplified parser. In production, use structured output
        from the LLM (JSON mode, function calling, etc.) for better reliability.
        """
        # For now, return a structured response based on the raw guidance
        # In production, parse JSON or use structured output from LLM
        return AgentResponse(
            guidance=response_text,
            confidence=0.8,  # Placeholder
            suggested_pivots=[
                "Analyze temporal patterns",
                "Cross-reference with known indicators",
                "Examine outliers in the dataset",
                "Compare with baseline behavior",
            ],
            suggested_filters=[
                "Filter by high-risk indicators",
                "Sort by timestamp for timeline analysis",
                "Group by host or user",
                "Filter by anomaly score",
            ],
            caveats="Guidance is based on available data context. "
            "Analysts should verify findings with additional sources.",
            reasoning="Analysis generated based on artifact data patterns and analyst query.",
        )
|
||||
408
backend/app/agents/core_v2.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""Core ThreatHunt analyst-assist agent — v2.
|
||||
|
||||
Uses TaskRouter to select the right model/node for each query,
|
||||
real LLM providers (Ollama/OpenWebUI), and structured response parsing.
|
||||
Integrates SANS RAG context from Open WebUI.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.config import settings
|
||||
from app.services.sans_rag import sans_rag
|
||||
from .router import TaskRouter, TaskType, RoutingDecision, task_router
|
||||
from .providers_v2 import OllamaProvider, OpenWebUIProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class AgentContext(BaseModel):
    """Context for agent guidance requests (v2).

    Adds investigation state (hypotheses, annotations, enrichment) plus an
    execution mode and an optional model override on top of the basic
    query/dataset context.
    """

    query: str = Field(..., description="Analyst question or request for guidance")
    dataset_name: Optional[str] = Field(None, description="Name of CSV dataset")
    artifact_type: Optional[str] = Field(None, description="Artifact type")
    host_identifier: Optional[str] = Field(None, description="Host name, IP, or identifier")
    data_summary: Optional[str] = Field(None, description="Brief description of data")
    conversation_history: Optional[list[dict]] = Field(
        default_factory=list, description="Previous messages"
    )
    active_hypotheses: Optional[list[str]] = Field(
        default_factory=list, description="Active investigation hypotheses"
    )
    annotations_summary: Optional[str] = Field(
        None, description="Summary of analyst annotations"
    )
    enrichment_summary: Optional[str] = Field(
        None, description="Summary of enrichment results"
    )
    # Drives ThreatHuntAgent.assist(): "debate" dispatches to the
    # multi-perspective path; "deep" forces heavy-analysis routing.
    mode: str = Field(default="quick", description="quick | deep | debate")
    # When set, bypasses the TaskRouter's model choice.
    model_override: Optional[str] = Field(None, description="Force a specific model")
|
||||
|
||||
|
||||
class Perspective(BaseModel):
    """A single perspective from the debate agent.

    One Planner/Critic/Pragmatist response, tagged with which model and
    node produced it and how long the LLM call took.
    """
    # Role name, e.g. "Planner", "Critic", or "Pragmatist".
    role: str
    # Raw model output for this role.
    content: str
    # Model that generated this perspective.
    model_used: str
    # Cluster node that served the request.
    node_used: str
    # LLM call latency in milliseconds.
    latency_ms: int
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
    """Response from the analyst-assist agent (v2).

    Extends the basic guidance payload with SANS references and routing
    telemetry (model, node, latency); perspectives are populated only in
    debate mode.
    """

    guidance: str = Field(..., description="Advisory guidance for analyst")
    # Clamped to [0, 1] by pydantic constraints.
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence (0-1)")
    suggested_pivots: list[str] = Field(default_factory=list)
    suggested_filters: list[str] = Field(default_factory=list)
    caveats: Optional[str] = None
    reasoning: Optional[str] = None
    sans_references: list[str] = Field(
        default_factory=list, description="SANS course references"
    )
    # Routing telemetry — filled in by ThreatHuntAgent after parsing.
    model_used: str = Field(default="", description="Model that generated the response")
    node_used: str = Field(default="", description="Node that processed the request")
    latency_ms: int = Field(default=0, description="Total latency in ms")
    perspectives: Optional[list[Perspective]] = Field(
        None, description="Debate perspectives (only in debate mode)"
    )
|
||||
|
||||
|
||||
# ── System prompt ─────────────────────────────────────────────────────
|
||||
|
||||
SYSTEM_PROMPT = """You are an analyst-assist agent for ThreatHunt, a threat hunting platform.
|
||||
You have access to 300GB of SANS cybersecurity course material for reference.
|
||||
|
||||
Your role:
|
||||
- Interpret and explain CSV artifact data from Velociraptor and other forensic tools
|
||||
- Suggest analytical pivots, filters, and hypotheses
|
||||
- Highlight anomalies, patterns, or points of interest
|
||||
- Reference relevant SANS methodologies and techniques when applicable
|
||||
- Guide analysts without replacing their judgment
|
||||
|
||||
Your constraints:
|
||||
- You ONLY provide guidance and suggestions
|
||||
- You do NOT execute actions or tools
|
||||
- You do NOT modify data or escalate alerts
|
||||
- You explain your reasoning transparently
|
||||
|
||||
RESPONSE FORMAT — you MUST respond with valid JSON:
|
||||
{
|
||||
"guidance": "Your main guidance text here",
|
||||
"confidence": 0.85,
|
||||
"suggested_pivots": ["Pivot 1", "Pivot 2"],
|
||||
"suggested_filters": ["filter expression 1", "filter expression 2"],
|
||||
"caveats": "Any assumptions or limitations",
|
||||
"reasoning": "How you arrived at this guidance",
|
||||
"sans_references": ["SANS SEC504: ...", "SANS FOR508: ..."]
|
||||
}
|
||||
|
||||
Respond ONLY with the JSON object. No markdown, no code fences, no extra text."""
|
||||
|
||||
|
||||
# ── Agent ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ThreatHuntAgent:
|
||||
"""Analyst-assist agent backed by Wile + Roadrunner LLM cluster."""
|
||||
|
||||
def __init__(self, router: TaskRouter | None = None):
|
||||
self.router = router or task_router
|
||||
self.system_prompt = SYSTEM_PROMPT
|
||||
|
||||
async def assist(self, context: AgentContext) -> AgentResponse:
|
||||
"""Provide guidance on artifact data and analysis."""
|
||||
start = time.monotonic()
|
||||
|
||||
if context.mode == "debate":
|
||||
return await self._debate_assist(context)
|
||||
|
||||
# Classify task and route
|
||||
task_type = self.router.classify_task(context.query)
|
||||
if context.mode == "deep":
|
||||
task_type = TaskType.DEEP_ANALYSIS
|
||||
|
||||
decision = self.router.route(task_type, model_override=context.model_override)
|
||||
logger.info(f"Routing: {decision.reason}")
|
||||
|
||||
# Enrich prompt with SANS RAG context
|
||||
prompt = self._build_prompt(context)
|
||||
try:
|
||||
rag_context = await sans_rag.enrich_prompt(
|
||||
context.query,
|
||||
investigation_context=context.data_summary or "",
|
||||
)
|
||||
if rag_context:
|
||||
prompt = f"{prompt}\n\n{rag_context}"
|
||||
except Exception as e:
|
||||
logger.warning(f"SANS RAG enrichment failed: {e}")
|
||||
|
||||
# Call LLM
|
||||
provider = self.router.get_provider(decision)
|
||||
if isinstance(provider, OpenWebUIProvider):
|
||||
messages = [
|
||||
{"role": "system", "content": self.system_prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
result = await provider.chat(
|
||||
messages,
|
||||
max_tokens=settings.AGENT_MAX_TOKENS,
|
||||
temperature=settings.AGENT_TEMPERATURE,
|
||||
)
|
||||
else:
|
||||
result = await provider.generate(
|
||||
prompt,
|
||||
system=self.system_prompt,
|
||||
max_tokens=settings.AGENT_MAX_TOKENS,
|
||||
temperature=settings.AGENT_TEMPERATURE,
|
||||
)
|
||||
|
||||
raw_text = result.get("response", "")
|
||||
latency_ms = result.get("_latency_ms", 0)
|
||||
|
||||
# Parse structured response
|
||||
response = self._parse_response(raw_text, context)
|
||||
response.model_used = decision.model
|
||||
response.node_used = decision.node.value
|
||||
response.latency_ms = latency_ms
|
||||
|
||||
total_ms = int((time.monotonic() - start) * 1000)
|
||||
logger.info(
|
||||
f"Agent assist: {context.query[:60]}... → "
|
||||
f"{decision.model} on {decision.node.value} "
|
||||
f"({total_ms}ms total, {latency_ms}ms LLM)"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def assist_stream(
|
||||
self,
|
||||
context: AgentContext,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Stream agent response tokens."""
|
||||
task_type = self.router.classify_task(context.query)
|
||||
decision = self.router.route(task_type, model_override=context.model_override)
|
||||
prompt = self._build_prompt(context)
|
||||
|
||||
provider = self.router.get_provider(decision)
|
||||
if isinstance(provider, OllamaProvider):
|
||||
async for token in provider.generate_stream(
|
||||
prompt,
|
||||
system=self.system_prompt,
|
||||
max_tokens=settings.AGENT_MAX_TOKENS,
|
||||
temperature=settings.AGENT_TEMPERATURE,
|
||||
):
|
||||
yield token
|
||||
elif isinstance(provider, OpenWebUIProvider):
|
||||
messages = [
|
||||
{"role": "system", "content": self.system_prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
async for token in provider.chat_stream(
|
||||
messages,
|
||||
max_tokens=settings.AGENT_MAX_TOKENS,
|
||||
temperature=settings.AGENT_TEMPERATURE,
|
||||
):
|
||||
yield token
|
||||
|
||||
async def _debate_assist(self, context: AgentContext) -> AgentResponse:
|
||||
"""Multi-perspective analysis using diverse models on Wile."""
|
||||
import asyncio
|
||||
|
||||
start = time.monotonic()
|
||||
prompt = self._build_prompt(context)
|
||||
|
||||
# Route each perspective to a different heavy model
|
||||
roles = {
|
||||
TaskType.DEBATE_PLANNER: (
|
||||
"Planner",
|
||||
"You are the Planner for a threat hunting investigation.\n"
|
||||
"Provide a structured investigation strategy. Reference SANS methodologies.\n"
|
||||
"Focus on: investigation steps, data sources to examine, MITRE ATT&CK mapping.\n"
|
||||
"Be specific to the data context provided.\n\n",
|
||||
),
|
||||
TaskType.DEBATE_CRITIC: (
|
||||
"Critic",
|
||||
"You are the Critic for a threat hunting investigation.\n"
|
||||
"Identify risks, false positive scenarios, missing evidence, and assumptions.\n"
|
||||
"Reference SANS training on common analyst mistakes.\n"
|
||||
"Challenge the obvious interpretation.\n\n",
|
||||
),
|
||||
TaskType.DEBATE_PRAGMATIST: (
|
||||
"Pragmatist",
|
||||
"You are the Pragmatist for a threat hunting investigation.\n"
|
||||
"Suggest the most actionable, efficient next steps.\n"
|
||||
"Reference SANS incident response playbooks.\n"
|
||||
"Focus on: quick wins, triage priorities, what to escalate.\n\n",
|
||||
),
|
||||
}
|
||||
|
||||
async def _call_perspective(task_type: TaskType, role_name: str, prefix: str):
|
||||
decision = self.router.route(task_type)
|
||||
provider = self.router.get_provider(decision)
|
||||
full_prompt = prefix + prompt
|
||||
|
||||
if isinstance(provider, OpenWebUIProvider):
|
||||
result = await provider.generate(
|
||||
full_prompt,
|
||||
system=f"You are the {role_name}. Provide analysis only. No execution.",
|
||||
max_tokens=settings.AGENT_MAX_TOKENS,
|
||||
temperature=0.4,
|
||||
)
|
||||
else:
|
||||
result = await provider.generate(
|
||||
full_prompt,
|
||||
system=f"You are the {role_name}. Provide analysis only. No execution.",
|
||||
max_tokens=settings.AGENT_MAX_TOKENS,
|
||||
temperature=0.4,
|
||||
)
|
||||
|
||||
return Perspective(
|
||||
role=role_name,
|
||||
content=result.get("response", ""),
|
||||
model_used=decision.model,
|
||||
node_used=decision.node.value,
|
||||
latency_ms=result.get("_latency_ms", 0),
|
||||
)
|
||||
|
||||
# Run perspectives in parallel
|
||||
perspective_tasks = [
|
||||
_call_perspective(tt, name, prefix)
|
||||
for tt, (name, prefix) in roles.items()
|
||||
]
|
||||
perspectives = await asyncio.gather(*perspective_tasks)
|
||||
|
||||
# Judge merges the perspectives
|
||||
judge_prompt = (
|
||||
"You are the Judge. Merge these three threat hunting perspectives into "
|
||||
"ONE final advisory answer.\n\n"
|
||||
"Rules:\n"
|
||||
"- Advisory only — no execution\n"
|
||||
"- Clearly list risks and assumptions\n"
|
||||
"- Highlight where perspectives agree and disagree\n"
|
||||
"- Provide a unified recommendation\n"
|
||||
"- Reference SANS methodologies where relevant\n\n"
|
||||
)
|
||||
for p in perspectives:
|
||||
judge_prompt += f"=== {p.role} (via {p.model_used}) ===\n{p.content}\n\n"
|
||||
|
||||
judge_prompt += (
|
||||
f"\nOriginal analyst query:\n{context.query}\n\n"
|
||||
"Respond with the merged analysis in this JSON format:\n"
|
||||
'{"guidance": "...", "confidence": 0.85, "suggested_pivots": [...], '
|
||||
'"suggested_filters": [...], "caveats": "...", "reasoning": "...", '
|
||||
'"sans_references": [...]}'
|
||||
)
|
||||
|
||||
judge_decision = self.router.route(TaskType.DEBATE_JUDGE)
|
||||
judge_provider = self.router.get_provider(judge_decision)
|
||||
|
||||
if isinstance(judge_provider, OpenWebUIProvider):
|
||||
judge_result = await judge_provider.generate(
|
||||
judge_prompt,
|
||||
system="You are the Judge. Merge perspectives into a final advisory answer. Respond with JSON only.",
|
||||
max_tokens=settings.AGENT_MAX_TOKENS,
|
||||
temperature=0.2,
|
||||
)
|
||||
else:
|
||||
judge_result = await judge_provider.generate(
|
||||
judge_prompt,
|
||||
system="You are the Judge. Merge perspectives into a final advisory answer. Respond with JSON only.",
|
||||
max_tokens=settings.AGENT_MAX_TOKENS,
|
||||
temperature=0.2,
|
||||
)
|
||||
|
||||
raw_text = judge_result.get("response", "")
|
||||
response = self._parse_response(raw_text, context)
|
||||
response.model_used = judge_decision.model
|
||||
response.node_used = judge_decision.node.value
|
||||
response.latency_ms = int((time.monotonic() - start) * 1000)
|
||||
response.perspectives = list(perspectives)
|
||||
|
||||
return response
|
||||
|
||||
def _build_prompt(self, context: AgentContext) -> str:
|
||||
"""Build the prompt with all available context."""
|
||||
parts = [f"Analyst query: {context.query}"]
|
||||
|
||||
if context.dataset_name:
|
||||
parts.append(f"Dataset: {context.dataset_name}")
|
||||
if context.artifact_type:
|
||||
parts.append(f"Artifact type: {context.artifact_type}")
|
||||
if context.host_identifier:
|
||||
parts.append(f"Host: {context.host_identifier}")
|
||||
if context.data_summary:
|
||||
parts.append(f"Data summary: {context.data_summary}")
|
||||
if context.active_hypotheses:
|
||||
parts.append(f"Active hypotheses: {'; '.join(context.active_hypotheses)}")
|
||||
if context.annotations_summary:
|
||||
parts.append(f"Analyst annotations: {context.annotations_summary}")
|
||||
if context.enrichment_summary:
|
||||
parts.append(f"Enrichment data: {context.enrichment_summary}")
|
||||
if context.conversation_history:
|
||||
parts.append("\nRecent conversation:")
|
||||
for msg in context.conversation_history[-settings.AGENT_HISTORY_LENGTH:]:
|
||||
parts.append(f" {msg.get('role', 'unknown')}: {msg.get('content', '')[:500]}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _parse_response(self, raw: str, context: AgentContext) -> AgentResponse:
    """Parse LLM output into a structured AgentResponse.

    Tries JSON extraction first; falls back to using the raw text as
    guidance when no usable JSON object is found.

    Args:
        raw: Raw text returned by the model.
        context: The agent context (unused here; kept for signature
            stability with callers).

    Returns:
        An AgentResponse populated from the parsed JSON, or a default
        response wrapping the raw text.
    """
    parsed = self._try_parse_json(raw)
    # _try_parse_json may return any valid JSON value (list, string,
    # number) — only a JSON object carries the fields we expect, so
    # guard on dict instead of mere truthiness.
    if isinstance(parsed, dict):
        try:
            confidence = float(parsed.get("confidence", 0.7))
        except (TypeError, ValueError):
            # Non-numeric confidence from the model — fall back to default.
            confidence = 0.7
        return AgentResponse(
            # `or raw` also covers an explicit null/empty "guidance" field.
            guidance=parsed.get("guidance") or raw,
            confidence=min(max(confidence, 0.0), 1.0),
            # `or []` guards against explicit JSON nulls before slicing.
            suggested_pivots=(parsed.get("suggested_pivots") or [])[:6],
            suggested_filters=(parsed.get("suggested_filters") or [])[:6],
            caveats=parsed.get("caveats"),
            reasoning=parsed.get("reasoning"),
            sans_references=parsed.get("sans_references") or [],
        )

    # Fallback: use raw text as guidance
    return AgentResponse(
        guidance=raw.strip() or "No guidance generated. Please try rephrasing your question.",
        confidence=0.5,
        suggested_pivots=[],
        suggested_filters=[],
        caveats="Response was not in structured format. Pivots and filters may be embedded in the guidance text.",
        reasoning=None,
        sans_references=[],
    )
|
||||
|
||||
def _try_parse_json(self, text: str) -> dict | None:
|
||||
"""Try to extract JSON from LLM output."""
|
||||
# Direct parse
|
||||
try:
|
||||
return json.loads(text.strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Extract from code fences
|
||||
patterns = [
|
||||
r"```json\s*(.*?)\s*```",
|
||||
r"```\s*(.*?)\s*```",
|
||||
r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}",
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group(1) if match.lastindex else match.group(0))
|
||||
except (json.JSONDecodeError, IndexError):
|
||||
continue
|
||||
|
||||
return None
|
||||
190
backend/app/agents/providers.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Pluggable LLM provider interface for analyst-assist agents.
|
||||
|
||||
Supports three provider types:
|
||||
- Local: On-device or on-prem models
|
||||
- Networked: Shared internal inference services
|
||||
- Online: External hosted APIs
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class LLMProvider(ABC):
    """Abstract base class for LLM providers.

    Concrete subclasses (local, networked, online) implement the same
    two-method contract so callers can swap backends freely.
    """

    @abstractmethod
    async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
        """Produce a completion for *prompt*.

        Args:
            prompt: The input prompt text.
            max_tokens: Upper bound on the response length in tokens.

        Returns:
            The generated text.
        """
        ...

    @abstractmethod
    def is_available(self) -> bool:
        """Report whether the backing inference service can be used."""
        ...
|
||||
|
||||
|
||||
class LocalProvider(LLMProvider):
    """Local LLM provider (on-device or on-prem models)."""

    def __init__(self, model_path: Optional[str] = None):
        """Configure the provider.

        Args:
            model_path: Path to the local model; falls back to the
                THREAT_HUNT_LOCAL_MODEL_PATH environment variable.
        """
        self.model_path = model_path if model_path else os.getenv("THREAT_HUNT_LOCAL_MODEL_PATH")
        # Lazily-loaded inference handle; unused by the placeholder implementation.
        self.model = None

    def is_available(self) -> bool:
        """A configured path that exists on disk is the only requirement."""
        path = self.model_path
        if not path:
            return False
        # In production, would verify model file exists and can be loaded
        return os.path.exists(str(path))

    async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
        """Generate a response using the local model.

        Placeholder — wire up a real local inference engine here
        (llama-cpp-python for GGML models, Ollama API, vLLM, ...).
        """
        if not self.is_available():
            raise RuntimeError("Local model not available")

        # Placeholder implementation
        return f"[Local model response to: {prompt[:50]}...]"
|
||||
|
||||
|
||||
class NetworkedProvider(LLMProvider):
    """Networked LLM provider (shared internal inference services)."""

    def __init__(
        self,
        api_endpoint: Optional[str] = None,
        api_key: Optional[str] = None,
        model_name: str = "default",
    ):
        """Configure the provider.

        Args:
            api_endpoint: URL of the inference service; falls back to the
                THREAT_HUNT_NETWORKED_ENDPOINT environment variable.
            api_key: Credential for the service; falls back to the
                THREAT_HUNT_NETWORKED_KEY environment variable.
            model_name: Model name/ID exposed by the service.
        """
        self.api_endpoint = api_endpoint if api_endpoint else os.getenv("THREAT_HUNT_NETWORKED_ENDPOINT")
        self.api_key = api_key if api_key else os.getenv("THREAT_HUNT_NETWORKED_KEY")
        self.model_name = model_name

    def is_available(self) -> bool:
        """An endpoint URL is the only hard requirement; the key is optional."""
        return bool(self.api_endpoint)

    async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
        """Generate a response via the shared inference service.

        Placeholder — wire up the real internal inference API here
        (inference service API, LLM container cluster, enterprise gateway).
        """
        if not self.is_available():
            raise RuntimeError("Networked service not available")

        # Placeholder implementation
        return f"[Networked response from {self.model_name}: {prompt[:50]}...]"
|
||||
|
||||
|
||||
class OnlineProvider(LLMProvider):
    """Online LLM provider (external hosted APIs)."""

    def __init__(
        self,
        api_provider: str = "openai",
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
    ):
        """Configure the provider.

        Args:
            api_provider: Provider name (openai, anthropic, google, etc.)
            api_key: API key; falls back to the THREAT_HUNT_ONLINE_API_KEY
                environment variable.
            model_name: Model name; falls back to the THREAT_HUNT_ONLINE_MODEL
                environment variable, then to "<provider>-default".
        """
        self.api_provider = api_provider
        self.api_key = api_key if api_key else os.getenv("THREAT_HUNT_ONLINE_API_KEY")
        fallback_model = f"{api_provider}-default"
        self.model_name = model_name if model_name else os.getenv(
            "THREAT_HUNT_ONLINE_MODEL", fallback_model
        )

    def is_available(self) -> bool:
        """An API key is the only hard requirement."""
        return bool(self.api_key)

    async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
        """Generate a response through the hosted API.

        Placeholder — wire up the real provider SDK/HTTP call here
        (OpenAI GPT, Anthropic Claude, Google Gemini, or another host).
        """
        if not self.is_available():
            raise RuntimeError("Online API not available or API key not set")

        # Placeholder implementation
        return f"[Online {self.api_provider} response: {prompt[:50]}...]"
|
||||
|
||||
|
||||
def get_provider(provider_type: str = "auto") -> LLMProvider:
    """Get an LLM provider based on configuration.

    Args:
        provider_type: 'local', 'networked', 'online', or 'auto'. With
            'auto', the first available provider wins, tried in order:
            local -> networked -> online.

    Returns:
        Configured LLM provider instance.

    Raises:
        RuntimeError: If the requested (or any, for 'auto') provider is
            unavailable.
        ValueError: If provider_type is not a recognized name.
    """
    explicit = {
        "local": LocalProvider,
        "networked": NetworkedProvider,
        "online": OnlineProvider,
    }

    if provider_type == "auto":
        # Preference order: cheapest/most private first.
        for candidate_cls in (LocalProvider, NetworkedProvider, OnlineProvider):
            candidate = candidate_cls()
            if candidate.is_available():
                return candidate
        raise RuntimeError(
            "No LLM provider available. Configure at least one of: "
            "THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
            "or THREAT_HUNT_ONLINE_API_KEY"
        )

    if provider_type not in explicit:
        raise ValueError(f"Unknown provider type: {provider_type}")

    provider = explicit[provider_type]()
    if not provider.is_available():
        raise RuntimeError(f"{provider_type} provider not available")

    return provider
|
||||
362
backend/app/agents/providers_v2.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""LLM providers — real implementations for Ollama nodes and Open WebUI cluster.
|
||||
|
||||
Three providers:
|
||||
- OllamaProvider: Direct calls to Ollama on Wile/Roadrunner via Tailscale
|
||||
- OpenWebUIProvider: Calls to the Open WebUI cluster (OpenAI-compatible)
|
||||
- EmbeddingProvider: Embedding generation via Ollama /api/embeddings
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import AsyncIterator
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import settings
|
||||
from .registry import ModelEntry, Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Shared HTTP client with reasonable timeouts
|
||||
_client: httpx.AsyncClient | None = None
|
||||
|
||||
|
||||
def _get_client() -> httpx.AsyncClient:
    """Return the process-wide shared AsyncClient, (re)creating it when
    missing or closed.

    The long read timeout (300s) accommodates slow heavy-model generations.
    """
    global _client
    if _client is None or _client.is_closed:
        _client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=10, read=300, write=30, pool=10),
            limits=httpx.Limits(max_connections=20, max_keepalive_connections=10),
        )
    return _client
|
||||
|
||||
|
||||
async def cleanup_client():
    """Close and discard the shared HTTP client.

    Safe to call repeatedly or when no client was ever created; intended
    for application shutdown.
    """
    global _client
    if _client and not _client.is_closed:
        await _client.aclose()
    _client = None
|
||||
|
||||
|
||||
def _ollama_url(node: Node) -> str:
    """Resolve the direct Ollama base URL for a node.

    Only WILE and ROADRUNNER expose a direct Ollama endpoint; other nodes
    (e.g. the cluster) are reached through Open WebUI instead.
    """
    if node == Node.ROADRUNNER:
        return settings.roadrunner_url
    if node == Node.WILE:
        return settings.wile_url
    raise ValueError(f"No direct Ollama URL for node: {node}")
|
||||
|
||||
|
||||
# ── Ollama Provider ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class OllamaProvider:
    """Direct Ollama API calls to Wile or Roadrunner.

    Async wrapper around one (model, node) pair using Ollama's HTTP API.
    Non-streaming calls return the raw Ollama response dict augmented with
    two bookkeeping keys of our own: '_latency_ms' and '_node'.
    """

    def __init__(self, model: str, node: Node):
        # model may be "" when the instance is only used for is_available().
        self.model = model
        self.node = node
        # Raises ValueError for nodes without a direct Ollama URL (e.g. CLUSTER).
        self.base_url = _ollama_url(node)

    async def generate(
        self,
        prompt: str,
        system: str = "",
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Generate a completion. Returns dict with 'response', 'model', 'total_duration', etc."""
        client = _get_client()
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                # 'num_predict' is Ollama's name for the max-token cap.
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        }
        if system:
            # Ollama takes the system prompt as a top-level field.
            payload["system"] = system

        start = time.monotonic()
        try:
            resp = await client.post(
                f"{self.base_url}/api/generate",
                json=payload,
            )
            resp.raise_for_status()
            data = resp.json()
            latency_ms = int((time.monotonic() - start) * 1000)
            # Underscore-prefixed keys are our bookkeeping, not Ollama fields.
            data["_latency_ms"] = latency_ms
            data["_node"] = self.node.value
            logger.info(
                f"Ollama [{self.node.value}] {self.model}: "
                f"{latency_ms}ms, {data.get('eval_count', '?')} tokens"
            )
            return data
        except httpx.HTTPStatusError as e:
            # Truncate the body so a large error page doesn't flood the log.
            logger.error(f"Ollama HTTP error [{self.node.value}]: {e.response.status_code} {e.response.text[:200]}")
            raise
        except httpx.ConnectError as e:
            logger.error(f"Cannot reach Ollama on {self.node.value} ({self.base_url}): {e}")
            raise

    async def chat(
        self,
        messages: list[dict],
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Chat completion via Ollama /api/chat.

        Args:
            messages: Chat-style [{'role': ..., 'content': ...}, ...] list.
        """
        client = _get_client()
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": False,
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        }

        start = time.monotonic()
        resp = await client.post(f"{self.base_url}/api/chat", json=payload)
        resp.raise_for_status()
        data = resp.json()
        data["_latency_ms"] = int((time.monotonic() - start) * 1000)
        data["_node"] = self.node.value
        return data

    async def generate_stream(
        self,
        prompt: str,
        system: str = "",
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> AsyncIterator[str]:
        """Stream tokens from Ollama.

        Yields each incremental 'response' fragment and stops at the chunk
        flagged 'done'.
        """
        client = _get_client()
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": True,
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        }
        if system:
            payload["system"] = system

        async with client.stream(
            "POST", f"{self.base_url}/api/generate", json=payload
        ) as resp:
            resp.raise_for_status()
            # Ollama streams newline-delimited JSON objects, one per chunk.
            async for line in resp.aiter_lines():
                if line.strip():
                    try:
                        chunk = json.loads(line)
                        token = chunk.get("response", "")
                        if token:
                            yield token
                        if chunk.get("done"):
                            break
                    except json.JSONDecodeError:
                        # Malformed chunk — skip rather than abort the stream.
                        continue

    async def is_available(self) -> bool:
        """Ping the Ollama node (GET /api/tags); any failure means unavailable."""
        try:
            client = _get_client()
            resp = await client.get(f"{self.base_url}/api/tags", timeout=5)
            return resp.status_code == 200
        except Exception:
            return False
|
||||
|
||||
|
||||
# ── Open WebUI Provider (OpenAI-compatible) ───────────────────────────
|
||||
|
||||
|
||||
class OpenWebUIProvider:
    """Calls to Open WebUI cluster at ai.guapo613.beer.

    Uses the OpenAI-compatible /v1/chat/completions endpoint. Responses
    are normalized into the same dict shape OllamaProvider returns
    ('response', 'model', plus '_latency_ms'/'_node'/'_usage' keys).
    """

    def __init__(self, model: str = ""):
        # Fall back to the configured fast default when no model is given.
        self.model = model or settings.DEFAULT_FAST_MODEL
        self.base_url = settings.OPENWEBUI_URL.rstrip("/")
        self.api_key = settings.OPENWEBUI_API_KEY

    def _headers(self) -> dict:
        """Build request headers; bearer auth only when an API key is set."""
        h = {"Content-Type": "application/json"}
        if self.api_key:
            h["Authorization"] = f"Bearer {self.api_key}"
        return h

    async def chat(
        self,
        messages: list[dict],
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Chat completion via OpenAI-compatible endpoint."""
        client = _get_client()
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": False,
        }

        start = time.monotonic()
        resp = await client.post(
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            headers=self._headers(),
        )
        resp.raise_for_status()
        data = resp.json()
        latency_ms = int((time.monotonic() - start) * 1000)

        # Normalize to our format
        content = ""
        if data.get("choices"):
            content = data["choices"][0].get("message", {}).get("content", "")

        result = {
            "response": content,
            "model": data.get("model", self.model),
            "_latency_ms": latency_ms,
            "_node": "cluster",
            "_usage": data.get("usage", {}),
        }
        logger.info(
            f"OpenWebUI cluster {self.model}: {latency_ms}ms"
        )
        return result

    async def generate(
        self,
        prompt: str,
        system: str = "",
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Convert a prompt-style call to chat format and delegate to chat()."""
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        return await self.chat(messages, max_tokens, temperature)

    async def chat_stream(
        self,
        messages: list[dict],
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> AsyncIterator[str]:
        """Stream tokens from OpenWebUI.

        Yields content deltas parsed from the OpenAI-style SSE stream.
        """
        client = _get_client()
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": True,
        }

        async with client.stream(
            "POST",
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            headers=self._headers(),
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                # SSE frames are prefixed with "data: "; other lines are ignored.
                if line.startswith("data: "):
                    data_str = line[6:].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data_str)
                        delta = chunk.get("choices", [{}])[0].get("delta", {})
                        token = delta.get("content", "")
                        if token:
                            yield token
                    except json.JSONDecodeError:
                        # Malformed frame — skip rather than abort the stream.
                        continue

    async def is_available(self) -> bool:
        """Check if Open WebUI is reachable (GET /v1/models)."""
        try:
            client = _get_client()
            resp = await client.get(
                f"{self.base_url}/v1/models",
                headers=self._headers(),
                timeout=5,
            )
            return resp.status_code == 200
        except Exception:
            # Any transport/auth failure counts as unavailable.
            return False
|
||||
|
||||
|
||||
# ── Embedding Provider ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class EmbeddingProvider:
    """Generate embedding vectors through Ollama's /api/embeddings endpoint."""

    def __init__(self, model: str = "", node: Node = Node.ROADRUNNER):
        # Default to the configured embedding model on Roadrunner.
        self.model = model or settings.DEFAULT_EMBEDDING_MODEL
        self.node = node
        self.base_url = _ollama_url(node)

    async def embed(self, text: str) -> list[float]:
        """Return the embedding vector for one text (empty list if absent)."""
        response = await _get_client().post(
            f"{self.base_url}/api/embeddings",
            json={"model": self.model, "prompt": text},
        )
        response.raise_for_status()
        return response.json().get("embedding", [])

    async def embed_batch(self, texts: list[str], concurrency: int = 5) -> list[list[float]]:
        """Embed many texts, keeping at most *concurrency* requests in flight."""
        gate = asyncio.Semaphore(concurrency)

        async def _bounded(item: str) -> list[float]:
            async with gate:
                return await self.embed(item)

        return await asyncio.gather(*(_bounded(item) for item in texts))
|
||||
|
||||
|
||||
# ── Health check for all nodes ────────────────────────────────────────
|
||||
|
||||
|
||||
async def check_all_nodes() -> dict:
    """Probe every LLM node concurrently and report availability plus URL."""
    wile_probe = OllamaProvider("", Node.WILE)
    roadrunner_probe = OllamaProvider("", Node.ROADRUNNER)
    cluster_probe = OpenWebUIProvider()

    # return_exceptions keeps one failing probe from masking the others;
    # `is True` below treats both False and an exception as unavailable.
    wile_ok, rr_ok, cl_ok = await asyncio.gather(
        wile_probe.is_available(),
        roadrunner_probe.is_available(),
        cluster_probe.is_available(),
        return_exceptions=True,
    )

    return {
        "wile": {"available": wile_ok is True, "url": settings.wile_url},
        "roadrunner": {"available": rr_ok is True, "url": settings.roadrunner_url},
        "cluster": {"available": cl_ok is True, "url": settings.OPENWEBUI_URL},
    }
|
||||
161
backend/app/agents/registry.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Model registry — inventory of all Ollama models across Wile and Roadrunner.
|
||||
|
||||
Each model is tagged with capabilities (chat, code, vision, embedding) and
|
||||
performance tier (fast, medium, heavy) for the TaskRouter.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Capability(str, Enum):
    """What a model can do; the TaskRouter matches tasks against these."""

    CHAT = "chat"            # general conversation / analysis
    CODE = "code"            # source / script understanding
    VISION = "vision"        # image inputs
    EMBEDDING = "embedding"  # vector embedding generation
|
||||
|
||||
|
||||
class Tier(str, Enum):
    """Rough size/performance bucket used when selecting a model."""

    FAST = "fast"  # < 15B params — quick responses
    MEDIUM = "medium"  # 15–40B params — balanced
    HEAVY = "heavy"  # 40B+ params — deep analysis
|
||||
|
||||
|
||||
class Node(str, Enum):
    """Inference target a model lives on."""

    WILE = "wile"
    ROADRUNNER = "roadrunner"
    CLUSTER = "cluster"  # Open WebUI balances across both
|
||||
|
||||
|
||||
@dataclass
class ModelEntry:
    """One installed model on one node, tagged for routing."""

    name: str  # Ollama model tag, e.g. "llama3.1:70b-instruct-q4_K_M"
    node: Node  # where the model is installed
    capabilities: list[Capability]  # what tasks it can be routed for
    tier: Tier  # fast / medium / heavy bucket
    param_size: str = ""  # e.g. "7b", "70b"
    notes: str = ""  # free-form operator remarks
|
||||
|
||||
|
||||
# ── Roadrunner (100.110.190.11) ──────────────────────────────────────
|
||||
|
||||
# Inventory of models installed on Roadrunner.
# NOTE(review): keep in sync with `ollama list` on the node — TODO confirm
# there is no automated reconciliation.
ROADRUNNER_MODELS: list[ModelEntry] = [
    # General / chat
    ModelEntry("llama3.1:latest", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "8b"),
    ModelEntry("qwen2.5:14b-instruct", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "14b"),
    ModelEntry("mistral:7b-instruct", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "7b"),
    ModelEntry("mistral:7b", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "7b"),
    ModelEntry("qwen2.5:7b", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "7b"),
    ModelEntry("phi3:medium", Node.ROADRUNNER, [Capability.CHAT], Tier.MEDIUM, "14b"),
    # Code
    ModelEntry("qwen2.5-coder:7b", Node.ROADRUNNER, [Capability.CODE], Tier.FAST, "7b"),
    ModelEntry("qwen2.5-coder:latest", Node.ROADRUNNER, [Capability.CODE], Tier.FAST, "7b"),
    ModelEntry("codestral:latest", Node.ROADRUNNER, [Capability.CODE], Tier.MEDIUM, "22b"),
    ModelEntry("codellama:13b", Node.ROADRUNNER, [Capability.CODE], Tier.FAST, "13b"),
    # Vision
    ModelEntry("llama3.2-vision:11b", Node.ROADRUNNER, [Capability.VISION], Tier.FAST, "11b"),
    ModelEntry("minicpm-v:latest", Node.ROADRUNNER, [Capability.VISION], Tier.FAST, "8b"),
    ModelEntry("llava:13b", Node.ROADRUNNER, [Capability.VISION], Tier.FAST, "13b"),
    # Embeddings
    ModelEntry("bge-m3:latest", Node.ROADRUNNER, [Capability.EMBEDDING], Tier.FAST, "0.6b"),
    ModelEntry("nomic-embed-text:latest", Node.ROADRUNNER, [Capability.EMBEDDING], Tier.FAST, "0.1b"),
    # Heavy
    ModelEntry("llama3.1:70b-instruct-q4_K_M", Node.ROADRUNNER, [Capability.CHAT], Tier.HEAVY, "70b"),
]
|
||||
|
||||
# ── Wile (100.110.190.12) ────────────────────────────────────────────
|
||||
|
||||
# Inventory of models installed on Wile (hosts the heavy 70B+ tier).
# NOTE(review): keep in sync with `ollama list` on the node — TODO confirm
# there is no automated reconciliation.
WILE_MODELS: list[ModelEntry] = [
    # General / chat
    ModelEntry("llama3.1:latest", Node.WILE, [Capability.CHAT], Tier.FAST, "8b"),
    ModelEntry("llama3:latest", Node.WILE, [Capability.CHAT], Tier.FAST, "8b"),
    ModelEntry("gemma2:27b", Node.WILE, [Capability.CHAT], Tier.MEDIUM, "27b"),
    # Code
    ModelEntry("qwen2.5-coder:7b", Node.WILE, [Capability.CODE], Tier.FAST, "7b"),
    ModelEntry("qwen2.5-coder:latest", Node.WILE, [Capability.CODE], Tier.FAST, "7b"),
    ModelEntry("qwen2.5-coder:32b", Node.WILE, [Capability.CODE], Tier.MEDIUM, "32b"),
    ModelEntry("deepseek-coder:33b", Node.WILE, [Capability.CODE], Tier.MEDIUM, "33b"),
    ModelEntry("codestral:latest", Node.WILE, [Capability.CODE], Tier.MEDIUM, "22b"),
    # Vision
    ModelEntry("llava:13b", Node.WILE, [Capability.VISION], Tier.FAST, "13b"),
    # Embeddings
    ModelEntry("bge-m3:latest", Node.WILE, [Capability.EMBEDDING], Tier.FAST, "0.6b"),
    # Heavy
    ModelEntry("llama3.1:70b", Node.WILE, [Capability.CHAT], Tier.HEAVY, "70b"),
    ModelEntry("llama3.1:70b-instruct-q4_K_M", Node.WILE, [Capability.CHAT], Tier.HEAVY, "70b"),
    ModelEntry("llama3.1:70b-instruct-q5_K_M", Node.WILE, [Capability.CHAT], Tier.HEAVY, "70b"),
    ModelEntry("mixtral:8x22b-instruct", Node.WILE, [Capability.CHAT], Tier.HEAVY, "141b"),
    ModelEntry("qwen2:72b-instruct", Node.WILE, [Capability.CHAT], Tier.HEAVY, "72b"),
]

# Flat view across both nodes; duplicates (same tag on both nodes) are kept
# so per-node lookups stay complete.
ALL_MODELS = ROADRUNNER_MODELS + WILE_MODELS
|
||||
|
||||
|
||||
class ModelRegistry:
    """Registry of all available models and their capabilities.

    Builds name/capability/node lookup tables once at construction time;
    queries then filter the model list by any combination of criteria.
    """

    def __init__(self, models: list[ModelEntry] | None = None):
        self.models = models or ALL_MODELS
        self._by_name: dict[str, list[ModelEntry]] = {}
        self._by_capability: dict[Capability, list[ModelEntry]] = {}
        self._by_node: dict[Node, list[ModelEntry]] = {}
        self._index()

    def _index(self):
        """Populate the name, node, and capability lookup tables."""
        for entry in self.models:
            self._by_name.setdefault(entry.name, []).append(entry)
            self._by_node.setdefault(entry.node, []).append(entry)
            for capability in entry.capabilities:
                self._by_capability.setdefault(capability, []).append(entry)

    def find(
        self,
        capability: Capability | None = None,
        tier: Tier | None = None,
        node: Node | None = None,
    ) -> list[ModelEntry]:
        """Find models matching all given criteria (unset criteria are ignored)."""

        def _keep(entry: ModelEntry) -> bool:
            if capability and capability not in entry.capabilities:
                return False
            if tier and entry.tier != tier:
                return False
            return not (node and entry.node != node)

        return [entry for entry in self.models if _keep(entry)]

    def get_best(
        self,
        capability: Capability,
        prefer_tier: Tier | None = None,
        prefer_node: Node | None = None,
    ) -> ModelEntry | None:
        """Best model for a capability, relaxing preferences progressively.

        Tries (tier + node), then tier only, then capability alone; returns
        the first match in registry order, or None when nothing qualifies.
        """
        for tier_filter, node_filter in (
            (prefer_tier, prefer_node),
            (prefer_tier, None),
            (None, None),
        ):
            candidates = self.find(capability=capability, tier=tier_filter, node=node_filter)
            if candidates:
                return candidates[0]
        return None

    def list_nodes(self) -> list[Node]:
        """Nodes that host at least one registered model."""
        return [node for node in self._by_node]

    def list_models_on_node(self, node: Node) -> list[ModelEntry]:
        """All models registered on *node* (empty list if none)."""
        return self._by_node.get(node, [])

    def to_dict(self) -> list[dict]:
        """JSON-serializable summary of every registered model."""
        summaries = []
        for entry in self.models:
            summaries.append(
                {
                    "name": entry.name,
                    "node": entry.node.value,
                    "capabilities": [cap.value for cap in entry.capabilities],
                    "tier": entry.tier.value,
                    "param_size": entry.param_size,
                }
            )
        return summaries
|
||||
|
||||
|
||||
# Module-level singleton shared by the routing layer; import `registry`
# rather than constructing a new ModelRegistry.
registry = ModelRegistry()
|
||||
183
backend/app/agents/router.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Task router — auto-selects the right model + node for each task type.
|
||||
|
||||
Routes based on task characteristics:
|
||||
- Quick chat → fast models via cluster
|
||||
- Deep analysis → 70B+ models on Wile
|
||||
- Code/script analysis → code models (32b on Wile, 7b for quick)
|
||||
- Vision/image → vision models on Roadrunner
|
||||
- Embedding → embedding models on either node
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from app.config import settings
|
||||
from .registry import Capability, Tier, Node, ModelEntry, registry
|
||||
from .providers_v2 import OllamaProvider, OpenWebUIProvider, EmbeddingProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskType(str, Enum):
    """Kinds of work the router dispatches; each maps to a routing rule."""

    QUICK_CHAT = "quick_chat"        # short interactive Q&A
    DEEP_ANALYSIS = "deep_analysis"  # long-form reasoning on heavy models
    CODE_ANALYSIS = "code_analysis"  # script / code understanding
    VISION = "vision"                # image inputs
    EMBEDDING = "embedding"          # vector embedding generation
    # Roles in the multi-perspective debate flow:
    DEBATE_PLANNER = "debate_planner"
    DEBATE_CRITIC = "debate_critic"
    DEBATE_PRAGMATIST = "debate_pragmatist"
    DEBATE_JUDGE = "debate_judge"
|
||||
|
||||
|
||||
@dataclass
class RoutingDecision:
    """Result of the routing decision."""

    model: str  # model tag to run
    node: Node  # node that will serve the request
    task_type: TaskType  # the task this decision was made for
    provider_type: str  # "ollama" or "openwebui"
    reason: str  # human-readable explanation for logs/UI
|
||||
|
||||
|
||||
class TaskRouter:
|
||||
"""Routes tasks to the appropriate model and node."""
|
||||
|
||||
# Default routing rules: task_type → (capability, preferred_tier, preferred_node)
|
||||
ROUTING_RULES: dict[TaskType, tuple[Capability, Tier | None, Node | None]] = {
|
||||
TaskType.QUICK_CHAT: (Capability.CHAT, Tier.FAST, None),
|
||||
TaskType.DEEP_ANALYSIS: (Capability.CHAT, Tier.HEAVY, Node.WILE),
|
||||
TaskType.CODE_ANALYSIS: (Capability.CODE, Tier.MEDIUM, Node.WILE),
|
||||
TaskType.VISION: (Capability.VISION, None, Node.ROADRUNNER),
|
||||
TaskType.EMBEDDING: (Capability.EMBEDDING, Tier.FAST, None),
|
||||
TaskType.DEBATE_PLANNER: (Capability.CHAT, Tier.HEAVY, Node.WILE),
|
||||
TaskType.DEBATE_CRITIC: (Capability.CHAT, Tier.HEAVY, Node.WILE),
|
||||
TaskType.DEBATE_PRAGMATIST: (Capability.CHAT, Tier.HEAVY, Node.WILE),
|
||||
TaskType.DEBATE_JUDGE: (Capability.CHAT, Tier.MEDIUM, Node.WILE),
|
||||
}
|
||||
|
||||
# Specific model overrides for debate roles (use diverse models for diversity of thought)
|
||||
DEBATE_MODEL_OVERRIDES: dict[TaskType, str] = {
|
||||
TaskType.DEBATE_PLANNER: "llama3.1:70b-instruct-q4_K_M",
|
||||
TaskType.DEBATE_CRITIC: "qwen2:72b-instruct",
|
||||
TaskType.DEBATE_PRAGMATIST: "mixtral:8x22b-instruct",
|
||||
TaskType.DEBATE_JUDGE: "gemma2:27b",
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.registry = registry
|
||||
|
||||
def route(self, task_type: TaskType, model_override: str | None = None) -> RoutingDecision:
|
||||
"""Decide which model and node to use for a task."""
|
||||
|
||||
# Explicit model override
|
||||
if model_override:
|
||||
entries = self.registry.find()
|
||||
for entry in entries:
|
||||
if entry.name == model_override:
|
||||
return RoutingDecision(
|
||||
model=model_override,
|
||||
node=entry.node,
|
||||
task_type=task_type,
|
||||
provider_type="ollama",
|
||||
reason=f"Explicit model override: {model_override}",
|
||||
)
|
||||
# Model not in registry — try via cluster
|
||||
return RoutingDecision(
|
||||
model=model_override,
|
||||
node=Node.CLUSTER,
|
||||
task_type=task_type,
|
||||
provider_type="openwebui",
|
||||
reason=f"Override model {model_override} not in registry, routing to cluster",
|
||||
)
|
||||
|
||||
# Debate model overrides
|
||||
if task_type in self.DEBATE_MODEL_OVERRIDES:
|
||||
model_name = self.DEBATE_MODEL_OVERRIDES[task_type]
|
||||
entries = self.registry.find()
|
||||
for entry in entries:
|
||||
if entry.name == model_name:
|
||||
return RoutingDecision(
|
||||
model=model_name,
|
||||
node=entry.node,
|
||||
task_type=task_type,
|
||||
provider_type="ollama",
|
||||
reason=f"Debate role {task_type.value} → {model_name} on {entry.node.value}",
|
||||
)
|
||||
|
||||
# Standard routing
|
||||
cap, tier, node = self.ROUTING_RULES.get(
|
||||
task_type,
|
||||
(Capability.CHAT, Tier.FAST, None),
|
||||
)
|
||||
|
||||
entry = self.registry.get_best(cap, prefer_tier=tier, prefer_node=node)
|
||||
if entry:
|
||||
return RoutingDecision(
|
||||
model=entry.name,
|
||||
node=entry.node,
|
||||
task_type=task_type,
|
||||
provider_type="ollama",
|
||||
reason=f"Auto-routed {task_type.value}: {cap.value}/{tier.value if tier else 'any'} → {entry.name} on {entry.node.value}",
|
||||
)
|
||||
|
||||
# Fallback to cluster
|
||||
default_model = settings.DEFAULT_FAST_MODEL
|
||||
return RoutingDecision(
|
||||
model=default_model,
|
||||
node=Node.CLUSTER,
|
||||
task_type=task_type,
|
||||
provider_type="openwebui",
|
||||
reason=f"No registry match, falling back to cluster with {default_model}",
|
||||
)
|
||||
|
||||
def get_provider(self, decision: RoutingDecision):
|
||||
"""Create the appropriate provider for a routing decision."""
|
||||
if decision.provider_type == "openwebui":
|
||||
return OpenWebUIProvider(model=decision.model)
|
||||
else:
|
||||
return OllamaProvider(model=decision.model, node=decision.node)
|
||||
|
||||
def get_embedding_provider(self, model: str | None = None, node: Node | None = None) -> EmbeddingProvider:
    """Build an embedding provider.

    Falls back to ``settings.DEFAULT_EMBEDDING_MODEL`` and
    ``Node.ROADRUNNER`` when model/node are not supplied.
    """
    chosen_model = model if model else settings.DEFAULT_EMBEDDING_MODEL
    chosen_node = node if node else Node.ROADRUNNER
    return EmbeddingProvider(model=chosen_model, node=chosen_node)
|
||||
|
||||
def classify_task(self, query: str, has_image: bool = False) -> TaskType:
    """Heuristically bucket a query into a task type.

    Keyword matching is crude but cheap; a classifier model could
    replace it later without changing this interface.
    """
    # Image input always takes the vision path.
    if has_image:
        return TaskType.VISION

    lowered = query.lower()

    # Code/script indicators
    code_terms = (
        "deobfuscate", "decode", "powershell", "script", "base64",
        "command line", "cmdline", "commandline", "obfuscated",
        "malware", "shellcode", "vbs", "vbscript", "batch",
        "python script", "code review", "reverse engineer",
    )
    # Deep analysis indicators
    deep_terms = (
        "deep analysis", "detailed", "comprehensive", "thorough",
        "investigate", "root cause", "advanced", "explain in detail",
        "full analysis", "forensic",
    )

    # Code indicators win over deep-analysis indicators, matching the
    # original precedence.
    for term in code_terms:
        if term in lowered:
            return TaskType.CODE_ANALYSIS
    for term in deep_terms:
        if term in lowered:
            return TaskType.DEEP_ANALYSIS
    return TaskType.QUICK_CHAT
|
||||
|
||||
|
||||
# Singleton
# Shared module-level router instance; importers should use `task_router`
# rather than constructing their own TaskRouter.
task_router = TaskRouter()
|
||||
1
backend/app/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API routes initialization."""
|
||||
1
backend/app/api/routes/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API route modules."""
|
||||
170
backend/app/api/routes/agent.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""API routes for analyst-assist agent."""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agents.core import ThreatHuntAgent, AgentContext, AgentResponse
|
||||
from app.agents.config import AgentConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
||||
|
||||
# Global agent instance (lazy-loaded)
|
||||
_agent: ThreatHuntAgent | None = None
|
||||
|
||||
|
||||
def get_agent() -> ThreatHuntAgent:
    """Return the process-wide agent, creating it on first use.

    Raises:
        HTTPException: 503 when no LLM provider is configured.
    """
    global _agent
    if _agent is not None:
        return _agent
    # Refuse to build an agent until a provider has been configured.
    if not AgentConfig.is_agent_enabled():
        raise HTTPException(
            status_code=503,
            detail="Analyst-assist agent is not configured. "
            "Please configure an LLM provider.",
        )
    _agent = ThreatHuntAgent()
    return _agent
|
||||
|
||||
|
||||
class AssistRequest(BaseModel):
    """Request for agent assistance.

    Only ``query`` is required; the remaining fields carry optional
    investigation context that is forwarded to the agent.
    """

    query: str = Field(
        ..., description="Analyst question or request for guidance"
    )
    dataset_name: str | None = Field(
        None, description="Name of CSV dataset being analyzed"
    )
    artifact_type: str | None = Field(
        None, description="Type of artifact (e.g., FileList, ProcessList, NetworkConnections)"
    )
    host_identifier: str | None = Field(
        None, description="Host name, IP address, or identifier"
    )
    data_summary: str | None = Field(
        None, description="Brief summary or context about the uploaded data"
    )
    conversation_history: list[dict] | None = Field(
        None, description="Previous messages for context"
    )
|
||||
|
||||
|
||||
class AssistResponse(BaseModel):
    """Response with agent guidance.

    ``guidance`` is the main advisory text; ``suggested_pivots`` and
    ``suggested_filters`` are follow-up actions for the analyst;
    ``caveats`` and ``reasoning`` are optional transparency fields.
    """

    guidance: str
    confidence: float
    suggested_pivots: list[str]
    suggested_filters: list[str]
    caveats: str | None = None
    reasoning: str | None = None
|
||||
|
||||
|
||||
@router.post(
    "/assist",
    response_model=AssistResponse,
    summary="Get analyst-assist guidance",
    description="Request guidance on CSV artifact data, analytical pivots, and hypotheses. "
    "Agent provides advisory guidance only - no execution.",
)
async def agent_assist(request: AssistRequest) -> AssistResponse:
    """Provide analyst-assist guidance on artifact data.

    The agent will:
    - Explain and interpret the provided data context
    - Suggest analytical pivots the analyst might explore
    - Suggest data filters or queries that might be useful
    - Highlight assumptions, limitations, and caveats

    The agent will NOT:
    - Execute any tools or actions
    - Escalate findings to alerts
    - Modify any data or schema
    - Make autonomous decisions

    Args:
        request: Assistance request with query and context

    Returns:
        Guidance response with suggestions and reasoning

    Raises:
        HTTPException: If agent is not configured (503) or request fails
    """
    try:
        agent = get_agent()

        # Build context
        context = AgentContext(
            query=request.query,
            dataset_name=request.dataset_name,
            artifact_type=request.artifact_type,
            host_identifier=request.host_identifier,
            data_summary=request.data_summary,
            conversation_history=request.conversation_history or [],
        )

        # Get guidance
        response = await agent.assist(context)

        logger.info(
            f"Agent assisted analyst with query: {request.query[:50]}... "
            f"(host: {request.host_identifier}, artifact: {request.artifact_type})"
        )

        return AssistResponse(
            guidance=response.guidance,
            confidence=response.confidence,
            suggested_pivots=response.suggested_pivots,
            suggested_filters=response.suggested_filters,
            caveats=response.caveats,
            reasoning=response.reasoning,
        )

    except HTTPException:
        # Bug fix: get_agent() raises HTTPException(503) when no provider
        # is configured. Without this re-raise, the broad Exception handler
        # below converted that 503 into a generic 500.
        raise
    except RuntimeError as e:
        logger.error(f"Agent error: {e}")
        raise HTTPException(
            status_code=503,
            detail=f"Agent unavailable: {str(e)}",
        )
    except Exception as e:
        logger.exception(f"Unexpected error in agent_assist: {e}")
        raise HTTPException(
            status_code=500,
            detail="Error generating guidance. Please try again.",
        )
|
||||
|
||||
|
||||
@router.get(
    "/health",
    summary="Check agent health",
    description="Check if agent is configured and ready to assist.",
)
async def agent_health() -> dict:
    """Check agent availability and configuration.

    Returns:
        Health status with configuration details. When no provider is
        configured, a descriptive "unavailable" payload is returned
        instead of an error response.
    """
    try:
        agent = get_agent()
        # Report the concrete provider class name for diagnostics.
        provider_type = agent.provider.__class__.__name__ if agent.provider else "None"
        return {
            "status": "healthy",
            "provider": provider_type,
            "max_tokens": AgentConfig.MAX_RESPONSE_TOKENS,
            "reasoning_enabled": AgentConfig.ENABLE_REASONING,
        }
    except HTTPException:
        # get_agent() raises 503 when unconfigured — translate that into
        # a payload showing which provider options are (un)set.
        return {
            "status": "unavailable",
            "reason": "No LLM provider configured",
            "configured_providers": {
                "local": bool(AgentConfig.LOCAL_MODEL_PATH),
                "networked": bool(AgentConfig.NETWORKED_ENDPOINT),
                "online": bool(AgentConfig.ONLINE_API_KEY),
            },
        }
|
||||
265
backend/app/api/routes/agent_v2.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""API routes for analyst-assist agent — v2.
|
||||
|
||||
Supports quick, deep, and debate modes with streaming.
|
||||
Conversations are persisted to the database.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
from app.db.models import Conversation, Message
|
||||
from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
|
||||
from app.agents.providers_v2 import check_all_nodes
|
||||
from app.agents.registry import registry
|
||||
from app.services.sans_rag import sans_rag
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
||||
|
||||
# Global agent instance
|
||||
_agent: ThreatHuntAgent | None = None
|
||||
|
||||
|
||||
def get_agent() -> ThreatHuntAgent:
    """Return the lazily-created, process-wide agent instance."""
    global _agent
    if _agent is not None:
        return _agent
    _agent = ThreatHuntAgent()
    return _agent
|
||||
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
|
||||
|
||||
class AssistRequest(BaseModel):
    """Assistance request: the query plus optional investigation context,
    mode selection, and persistence targets."""

    query: str = Field(..., max_length=4000, description="Analyst question")
    dataset_name: str | None = None          # CSV dataset under analysis
    artifact_type: str | None = None         # e.g. FileList, ProcessList
    host_identifier: str | None = None       # host name / IP being investigated
    data_summary: str | None = None          # brief context about uploaded data
    conversation_history: list[dict] | None = None  # prior messages for context
    active_hypotheses: list[str] | None = None      # hunt hypotheses in play
    annotations_summary: str | None = None
    enrichment_summary: str | None = None
    mode: str = Field(default="quick", description="quick | deep | debate")
    model_override: str | None = None        # bypass auto-routing with this model
    conversation_id: str | None = Field(None, description="Persist messages to this conversation")
    hunt_id: str | None = None               # also triggers persistence when set
|
||||
|
||||
|
||||
class AssistResponseModel(BaseModel):
    """Agent guidance plus model-routing and latency metadata."""

    guidance: str
    confidence: float
    suggested_pivots: list[str]
    suggested_filters: list[str]
    caveats: str | None = None
    reasoning: str | None = None
    sans_references: list[str] = []          # SANS RAG citations, if any
    model_used: str = ""                     # model that produced the answer
    node_used: str = ""                      # node the model ran on
    latency_ms: int = 0
    perspectives: list[dict] | None = None   # per-role outputs (debate mode only)
    conversation_id: str | None = None       # set when the exchange was persisted
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/assist",
    response_model=AssistResponseModel,
    summary="Get analyst-assist guidance",
    description="Request guidance with auto-routed model selection. "
    "Supports quick (fast), deep (70B), and debate (multi-model) modes.",
)
async def agent_assist(
    request: AssistRequest,
    db: AsyncSession = Depends(get_db),
) -> AssistResponseModel:
    """Run the agent in quick/deep/debate mode and optionally persist.

    Messages are persisted when the request carries a ``conversation_id``
    or a ``hunt_id``; the resulting conversation id is echoed back.

    Raises:
        HTTPException: 500 on unexpected agent or persistence errors.
    """
    try:
        agent = get_agent()
        context = AgentContext(
            query=request.query,
            dataset_name=request.dataset_name,
            artifact_type=request.artifact_type,
            host_identifier=request.host_identifier,
            data_summary=request.data_summary,
            conversation_history=request.conversation_history or [],
            active_hypotheses=request.active_hypotheses or [],
            annotations_summary=request.annotations_summary,
            enrichment_summary=request.enrichment_summary,
            mode=request.mode,
            model_override=request.model_override,
        )

        response = await agent.assist(context)

        # Persist conversation
        conv_id = request.conversation_id
        if conv_id or request.hunt_id:
            conv_id = await _persist_conversation(
                db, conv_id, request, response
            )

        return AssistResponseModel(
            guidance=response.guidance,
            confidence=response.confidence,
            suggested_pivots=response.suggested_pivots,
            suggested_filters=response.suggested_filters,
            caveats=response.caveats,
            reasoning=response.reasoning,
            sans_references=response.sans_references,
            model_used=response.model_used,
            node_used=response.node_used,
            latency_ms=response.latency_ms,
            perspectives=[
                {
                    "role": p.role,
                    "content": p.content,
                    "model_used": p.model_used,
                    "node_used": p.node_used,
                    "latency_ms": p.latency_ms,
                }
                for p in response.perspectives
            ] if response.perspectives else None,
            conversation_id=conv_id,
        )

    except HTTPException:
        # Let FastAPI-native errors keep their intended status codes
        # instead of being rewrapped as 500s by the handler below.
        raise
    except Exception as e:
        logger.exception(f"Agent error: {e}")
        raise HTTPException(status_code=500, detail=f"Agent error: {str(e)}")
|
||||
|
||||
|
||||
@router.post(
    "/assist/stream",
    summary="Stream agent response",
    description="Stream tokens via SSE for real-time display.",
)
async def agent_assist_stream(request: AssistRequest):
    """Stream agent guidance token-by-token as Server-Sent Events.

    Always runs in quick mode regardless of ``request.mode``; the stream
    terminates with a literal ``[DONE]`` event. Responses are not
    persisted by this endpoint.
    """
    agent = get_agent()
    context = AgentContext(
        query=request.query,
        dataset_name=request.dataset_name,
        artifact_type=request.artifact_type,
        host_identifier=request.host_identifier,
        data_summary=request.data_summary,
        conversation_history=request.conversation_history or [],
        mode="quick",  # streaming only supports quick mode
    )

    async def _stream():
        # Each token is wrapped as a JSON SSE data frame.
        async for token in agent.assist_stream(context):
            yield f"data: {json.dumps({'token': token})}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        _stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable reverse-proxy buffering so tokens reach the client
            # immediately.
            "X-Accel-Buffering": "no",
        },
    )
|
||||
|
||||
|
||||
@router.get(
    "/health",
    summary="Check agent and node health",
    description="Returns availability of all LLM nodes and the cluster.",
)
async def agent_health() -> dict:
    """Report LLM node availability, RAG health, and effective model config."""
    nodes = await check_all_nodes()
    rag_health = await sans_rag.health_check()
    return {
        "status": "healthy",
        "nodes": nodes,
        "rag": rag_health,
        # Effective defaults as resolved from settings (TH_* env vars).
        "default_models": {
            "fast": settings.DEFAULT_FAST_MODEL,
            "heavy": settings.DEFAULT_HEAVY_MODEL,
            "code": settings.DEFAULT_CODE_MODEL,
            "vision": settings.DEFAULT_VISION_MODEL,
            "embedding": settings.DEFAULT_EMBEDDING_MODEL,
        },
        "config": {
            "max_tokens": settings.AGENT_MAX_TOKENS,
            "temperature": settings.AGENT_TEMPERATURE,
        },
    }
|
||||
|
||||
|
||||
@router.get(
    "/models",
    summary="List all available models",
    description="Returns the full model registry with capabilities and node assignments.",
)
async def list_models():
    """Expose the model registry, e.g. for UI model pickers."""
    catalog = registry.to_dict()
    return {"models": catalog, "total": len(registry.models)}
|
||||
|
||||
|
||||
# ── Conversation persistence ──────────────────────────────────────────
|
||||
|
||||
|
||||
async def _persist_conversation(
    db: AsyncSession,
    conversation_id: str | None,
    request: AssistRequest,
    response: AgentResponse,
) -> str:
    """Save user message and agent response to the database.

    When ``conversation_id`` is given, appends to that conversation
    (creating it with the same id if it does not exist yet); otherwise
    starts a new conversation titled from the query.

    Returns:
        The id of the conversation the messages were attached to.

    NOTE(review): this function only flushes — it assumes the session
    from get_db commits on success; confirm that dependency behavior.
    """
    if conversation_id:
        # Find existing conversation
        from sqlalchemy import select
        result = await db.execute(
            select(Conversation).where(Conversation.id == conversation_id)
        )
        conv = result.scalar_one_or_none()
        if not conv:
            # Client supplied an id we have not seen — create it in place.
            conv = Conversation(id=conversation_id, hunt_id=request.hunt_id)
            db.add(conv)
    else:
        # New conversation, titled with the first 100 chars of the query.
        conv = Conversation(
            title=request.query[:100],
            hunt_id=request.hunt_id,
        )
        db.add(conv)
        await db.flush()  # populate the generated conv.id before use below

    # User message
    user_msg = Message(
        conversation_id=conv.id,
        role="user",
        content=request.query,
    )
    db.add(user_msg)

    # Agent message (routing metadata + structured extras in response_meta)
    agent_msg = Message(
        conversation_id=conv.id,
        role="agent",
        content=response.guidance,
        model_used=response.model_used,
        node_used=response.node_used,
        latency_ms=response.latency_ms,
        response_meta={
            "confidence": response.confidence,
            "pivots": response.suggested_pivots,
            "filters": response.suggested_filters,
            "sans_refs": response.sans_references,
        },
    )
    db.add(agent_msg)
    await db.flush()

    return conv.id
|
||||
404
backend/app/api/routes/alerts.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""API routes for alerts — CRUD, analyze triggers, and alert rules."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Alert, AlertRule, _new_id, _utcnow
|
||||
from app.db.repositories.datasets import DatasetRepository
|
||||
from app.services.analyzers import (
|
||||
get_available_analyzers,
|
||||
get_analyzer,
|
||||
run_all_analyzers,
|
||||
AlertCandidate,
|
||||
)
|
||||
from app.services.process_tree import _fetch_rows
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/alerts", tags=["alerts"])
|
||||
|
||||
|
||||
# ── Pydantic models ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class AlertUpdate(BaseModel):
    """Partial alert update; fields left as None are not modified."""

    status: Optional[str] = None
    severity: Optional[str] = None
    assignee: Optional[str] = None
    case_id: Optional[str] = None
    tags: Optional[list[str]] = None
|
||||
|
||||
|
||||
class RuleCreate(BaseModel):
    """Payload for creating an alert rule bound to a named analyzer."""

    name: str
    description: Optional[str] = None
    analyzer: str                            # must match a registered analyzer
    config: Optional[dict] = None            # analyzer-specific options
    severity_override: Optional[str] = None  # force severity on produced alerts
    enabled: bool = True
    hunt_id: Optional[str] = None            # optionally scope rule to one hunt
|
||||
|
||||
|
||||
class RuleUpdate(BaseModel):
    """Partial rule update; None fields are untouched. The analyzer
    binding itself cannot be changed after creation."""

    name: Optional[str] = None
    description: Optional[str] = None
    config: Optional[dict] = None
    severity_override: Optional[str] = None
    enabled: Optional[bool] = None
|
||||
|
||||
|
||||
class AnalyzeRequest(BaseModel):
    """Parameters for an analyzer run over a dataset or a whole hunt."""

    dataset_id: Optional[str] = None
    hunt_id: Optional[str] = None
    analyzers: Optional[list[str]] = None  # None = run all
    config: Optional[dict] = None
    auto_create: bool = True  # automatically persist alerts
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _alert_to_dict(a: Alert) -> dict:
    """Serialize an Alert ORM row into a JSON-safe dict."""

    def _iso(dt):
        # None-safe ISO-8601 formatting for optional timestamps.
        return dt.isoformat() if dt else None

    return {
        "id": a.id,
        "title": a.title,
        "description": a.description,
        "severity": a.severity,
        "status": a.status,
        "analyzer": a.analyzer,
        "score": a.score,
        "evidence": a.evidence or [],
        "mitre_technique": a.mitre_technique,
        "tags": a.tags or [],
        "hunt_id": a.hunt_id,
        "dataset_id": a.dataset_id,
        "case_id": a.case_id,
        "assignee": a.assignee,
        "acknowledged_at": _iso(a.acknowledged_at),
        "resolved_at": _iso(a.resolved_at),
        "created_at": _iso(a.created_at),
        "updated_at": _iso(a.updated_at),
    }
|
||||
|
||||
|
||||
def _rule_to_dict(r: AlertRule) -> dict:
    """Serialize an AlertRule ORM row into a JSON-safe dict."""

    def _iso(dt):
        # None-safe ISO-8601 formatting for optional timestamps.
        return dt.isoformat() if dt else None

    return {
        "id": r.id,
        "name": r.name,
        "description": r.description,
        "analyzer": r.analyzer,
        "config": r.config,
        "severity_override": r.severity_override,
        "enabled": r.enabled,
        "hunt_id": r.hunt_id,
        "created_at": _iso(r.created_at),
        "updated_at": _iso(r.updated_at),
    }
|
||||
|
||||
|
||||
# ── Alert CRUD ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("", summary="List alerts")
async def list_alerts(
    status: str | None = Query(None),
    severity: str | None = Query(None),
    analyzer: str | None = Query(None),
    hunt_id: str | None = Query(None),
    dataset_id: str | None = Query(None),
    limit: int = Query(100, ge=1, le=500),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """Paginated alert listing, ordered by score then recency.

    All filters are optional and combined with AND. Returns the page of
    alerts plus the total count matching the same filters.
    """
    # Collect the active filters once and apply them to both the page
    # query and the count query, so the two can never drift apart
    # (previously each filter was duplicated across both statements).
    conditions = []
    if status:
        conditions.append(Alert.status == status)
    if severity:
        conditions.append(Alert.severity == severity)
    if analyzer:
        conditions.append(Alert.analyzer == analyzer)
    if hunt_id:
        conditions.append(Alert.hunt_id == hunt_id)
    if dataset_id:
        conditions.append(Alert.dataset_id == dataset_id)

    stmt = select(Alert).where(*conditions)
    count_stmt = select(func.count(Alert.id)).where(*conditions)

    total = (await db.execute(count_stmt)).scalar() or 0
    results = (await db.execute(
        stmt.order_by(desc(Alert.score), desc(Alert.created_at)).offset(offset).limit(limit)
    )).scalars().all()

    return {"alerts": [_alert_to_dict(a) for a in results], "total": total}
|
||||
|
||||
|
||||
@router.get("/stats", summary="Alert statistics dashboard")
async def alert_stats(
    hunt_id: str | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Return aggregated alert statistics.

    Breakdowns by severity, status, and analyzer, plus the ten most
    frequent MITRE techniques; all optionally scoped to one hunt.
    (Removed an unused ``base = select(Alert)`` statement that was built
    and filtered but never executed.)
    """

    def _scoped(stmt):
        # Apply the optional hunt filter uniformly to every aggregate.
        return stmt.where(Alert.hunt_id == hunt_id) if hunt_id else stmt

    # Severity breakdown
    sev_rows = (await db.execute(
        _scoped(select(Alert.severity, func.count(Alert.id)).group_by(Alert.severity))
    )).all()
    severity_counts = {s: c for s, c in sev_rows}

    # Status breakdown
    status_rows = (await db.execute(
        _scoped(select(Alert.status, func.count(Alert.id)).group_by(Alert.status))
    )).all()
    status_counts = {s: c for s, c in status_rows}

    # Analyzer breakdown
    analyzer_rows = (await db.execute(
        _scoped(select(Alert.analyzer, func.count(Alert.id)).group_by(Alert.analyzer))
    )).all()
    analyzer_counts = {a: c for a, c in analyzer_rows}

    # Top MITRE techniques (NULL techniques excluded)
    mitre_stmt = _scoped(
        select(Alert.mitre_technique, func.count(Alert.id))
        .where(Alert.mitre_technique.isnot(None))
        .group_by(Alert.mitre_technique)
        .order_by(desc(func.count(Alert.id)))
        .limit(10)
    )
    mitre_rows = (await db.execute(mitre_stmt)).all()
    top_mitre = [{"technique": t, "count": c} for t, c in mitre_rows]

    # Every alert has a severity, so the severity breakdown sums to the total.
    total = sum(severity_counts.values())

    return {
        "total": total,
        "severity_counts": severity_counts,
        "status_counts": status_counts,
        "analyzer_counts": analyzer_counts,
        "top_mitre": top_mitre,
    }
|
||||
|
||||
|
||||
@router.get("/{alert_id}", summary="Get alert detail")
async def get_alert(alert_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single alert by primary key, or 404."""
    alert = await db.get(Alert, alert_id)
    if alert is None:
        raise HTTPException(status_code=404, detail="Alert not found")
    return _alert_to_dict(alert)
|
||||
|
||||
|
||||
@router.put("/{alert_id}", summary="Update alert (status, assignee, etc.)")
async def update_alert(
    alert_id: str, body: AlertUpdate, db: AsyncSession = Depends(get_db)
):
    """Partially update an alert; only fields present in the body change.

    Status transitions stamp lifecycle timestamps: "acknowledged" sets
    acknowledged_at, "resolved"/"false-positive" set resolved_at — each
    set on first transition only, never overwritten.

    Raises:
        HTTPException: 404 when the alert does not exist.
    """
    alert = await db.get(Alert, alert_id)
    if not alert:
        raise HTTPException(status_code=404, detail="Alert not found")

    if body.status is not None:
        alert.status = body.status
        # Stamp timestamps on first transition only.
        if body.status == "acknowledged" and not alert.acknowledged_at:
            alert.acknowledged_at = _utcnow()
        if body.status in ("resolved", "false-positive") and not alert.resolved_at:
            alert.resolved_at = _utcnow()
    if body.severity is not None:
        alert.severity = body.severity
    if body.assignee is not None:
        alert.assignee = body.assignee
    if body.case_id is not None:
        alert.case_id = body.case_id
    if body.tags is not None:
        alert.tags = body.tags

    await db.commit()
    await db.refresh(alert)
    return _alert_to_dict(alert)
|
||||
|
||||
|
||||
@router.delete("/{alert_id}", summary="Delete alert")
async def delete_alert(alert_id: str, db: AsyncSession = Depends(get_db)):
    """Permanently remove an alert; 404 when it does not exist."""
    target = await db.get(Alert, alert_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Alert not found")
    await db.delete(target)
    await db.commit()
    return {"ok": True}
|
||||
|
||||
|
||||
# ── Bulk operations ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/bulk-update", summary="Bulk update alert statuses")
async def bulk_update_alerts(
    alert_ids: list[str],
    status: str = Query(...),
    db: AsyncSession = Depends(get_db),
):
    """Set the same status on many alerts in a single call.

    Unknown ids are skipped silently; the response reports how many
    alerts were actually updated. Timestamp stamping follows the same
    first-transition rules as the single-alert update endpoint.

    NOTE(review): fetches each alert individually (one query per id) —
    fine for small batches; consider a single IN-clause update if
    batches grow large.
    """
    updated = 0
    for aid in alert_ids:
        alert = await db.get(Alert, aid)
        if alert:
            alert.status = status
            if status == "acknowledged" and not alert.acknowledged_at:
                alert.acknowledged_at = _utcnow()
            if status in ("resolved", "false-positive") and not alert.resolved_at:
                alert.resolved_at = _utcnow()
            updated += 1
    await db.commit()
    return {"updated": updated}
|
||||
|
||||
|
||||
# ── Run Analyzers ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/analyzers/list", summary="List available analyzers")
async def list_analyzers():
    """Enumerate the analyzer names accepted by the /analyze endpoint."""
    available = get_available_analyzers()
    return {"analyzers": available}
|
||||
|
||||
|
||||
@router.post("/analyze", summary="Run analyzers on a dataset/hunt and optionally create alerts")
async def run_analysis(
    request: AnalyzeRequest, db: AsyncSession = Depends(get_db)
):
    """Run detection analyzers over a dataset's or hunt's rows.

    Loads up to 10k rows, runs the requested analyzers (all analyzers
    when ``request.analyzers`` is None), and — when ``auto_create`` —
    persists each candidate finding as an Alert row.

    Raises:
        HTTPException: 400 when neither dataset_id nor hunt_id is given;
            404 when no rows match the scope.
    """
    if not request.dataset_id and not request.hunt_id:
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")

    # Load rows (capped at 10000 to bound analyzer runtime)
    rows_objs = await _fetch_rows(
        db, dataset_id=request.dataset_id, hunt_id=request.hunt_id, limit=10000,
    )
    if not rows_objs:
        raise HTTPException(status_code=404, detail="No rows found")

    # Prefer the normalized form of each row, falling back to raw data.
    rows = [r.normalized_data or r.data for r in rows_objs]

    # Run analyzers
    candidates = await run_all_analyzers(rows, enabled=request.analyzers, config=request.config)

    created_alerts: list[dict] = []
    if request.auto_create and candidates:
        for c in candidates:
            alert = Alert(
                id=_new_id(),
                title=c.title,
                description=c.description,
                severity=c.severity,
                analyzer=c.analyzer,
                score=c.score,
                evidence=c.evidence,
                mitre_technique=c.mitre_technique,
                tags=c.tags,
                hunt_id=request.hunt_id,
                dataset_id=request.dataset_id,
            )
            db.add(alert)
            created_alerts.append(_alert_to_dict(alert))
        await db.commit()

    return {
        "candidates_found": len(candidates),
        "alerts_created": len(created_alerts),
        "alerts": created_alerts,
        "summary": {
            "by_severity": _count_by(candidates, "severity"),
            "by_analyzer": _count_by(candidates, "analyzer"),
            "rows_analyzed": len(rows),
        },
    }
|
||||
|
||||
|
||||
def _count_by(items: list[AlertCandidate], attr: str) -> dict[str, int]:
    """Tally items by the value of *attr* ("unknown" when absent)."""
    tally: dict[str, int] = {}
    for candidate in items:
        bucket = getattr(candidate, attr, "unknown")
        tally[bucket] = tally.get(bucket, 0) + 1
    return tally
|
||||
|
||||
|
||||
# ── Alert Rules CRUD ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/rules/list", summary="List alert rules")
async def list_rules(
    enabled: bool | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """List alert rules, optionally filtered by the enabled flag."""
    query = select(AlertRule)
    if enabled is not None:
        query = query.where(AlertRule.enabled == enabled)
    rows = (await db.execute(query.order_by(AlertRule.created_at))).scalars().all()
    return {"rules": [_rule_to_dict(r) for r in rows]}
|
||||
|
||||
|
||||
@router.post("/rules", summary="Create alert rule")
async def create_rule(body: RuleCreate, db: AsyncSession = Depends(get_db)):
    """Create a persistent alert rule bound to a registered analyzer.

    Raises:
        HTTPException: 400 when the named analyzer is not registered.
    """
    # Validate analyzer exists
    if not get_analyzer(body.analyzer):
        raise HTTPException(status_code=400, detail=f"Unknown analyzer: {body.analyzer}")

    rule = AlertRule(
        id=_new_id(),
        name=body.name,
        description=body.description,
        analyzer=body.analyzer,
        config=body.config,
        severity_override=body.severity_override,
        enabled=body.enabled,
        hunt_id=body.hunt_id,
    )
    db.add(rule)
    await db.commit()
    await db.refresh(rule)
    return _rule_to_dict(rule)
|
||||
|
||||
|
||||
@router.put("/rules/{rule_id}", summary="Update alert rule")
async def update_rule(
    rule_id: str, body: RuleUpdate, db: AsyncSession = Depends(get_db)
):
    """Partially update a rule; None fields are untouched. The analyzer
    binding itself is immutable (RuleUpdate has no analyzer field).

    Raises:
        HTTPException: 404 when the rule does not exist.
    """
    rule = await db.get(AlertRule, rule_id)
    if not rule:
        raise HTTPException(status_code=404, detail="Rule not found")

    if body.name is not None:
        rule.name = body.name
    if body.description is not None:
        rule.description = body.description
    if body.config is not None:
        rule.config = body.config
    if body.severity_override is not None:
        rule.severity_override = body.severity_override
    if body.enabled is not None:
        rule.enabled = body.enabled

    await db.commit()
    await db.refresh(rule)
    return _rule_to_dict(rule)
|
||||
|
||||
|
||||
@router.delete("/rules/{rule_id}", summary="Delete alert rule")
async def delete_rule(rule_id: str, db: AsyncSession = Depends(get_db)):
    """Permanently remove an alert rule; 404 when it does not exist."""
    target = await db.get(AlertRule, rule_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Rule not found")
    await db.delete(target)
    await db.commit()
    return {"ok": True}
|
||||
295
backend/app/api/routes/analysis.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""API routes for process trees, storyline graphs, risk scoring, LLM analysis, timeline, and field stats."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.repositories.datasets import DatasetRepository
|
||||
from app.services.process_tree import (
|
||||
build_process_tree,
|
||||
build_storyline,
|
||||
compute_risk_scores,
|
||||
_fetch_rows,
|
||||
)
|
||||
from app.services.llm_analysis import (
|
||||
AnalysisRequest,
|
||||
AnalysisResult,
|
||||
run_llm_analysis,
|
||||
)
|
||||
from app.services.timeline import (
|
||||
build_timeline_bins,
|
||||
compute_field_stats,
|
||||
search_rows,
|
||||
)
|
||||
from app.services.mitre import (
|
||||
map_to_attack,
|
||||
build_knowledge_graph,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/analysis", tags=["analysis"])
|
||||
|
||||
|
||||
# ── Response models ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ProcessTreeResponse(BaseModel):
    """Forest of hierarchical process trees plus a total node count."""

    trees: list[dict] = Field(default_factory=list)  # root nodes, children nested
    total_processes: int = 0  # count of all nodes across the forest
|
||||
|
||||
|
||||
class StorylineResponse(BaseModel):
    """Cytoscape-compatible storyline graph of events."""

    # Graph vertices (events/processes) as plain dicts.
    nodes: list[dict] = Field(default_factory=list)
    # Directed edges connecting nodes (lineage / temporal sequence).
    edges: list[dict] = Field(default_factory=list)
    # Aggregate metadata about the graph.
    summary: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class RiskHostEntry(BaseModel):
    """Per-host risk score with the signals that contributed to it."""

    hostname: str
    # Computed risk score for this host.
    score: int = 0
    # Human-readable suspicious-pattern labels detected on this host.
    signals: list[str] = Field(default_factory=list)
    # Raw event counts by category.
    event_count: int = 0
    process_count: int = 0
    network_count: int = 0
    file_count: int = 0
|
||||
|
||||
|
||||
class RiskSummaryResponse(BaseModel):
    """Hunt-wide risk summary aggregated over all hosts."""

    # One entry per host, each with its own score and signals.
    hosts: list[RiskHostEntry] = Field(default_factory=list)
    # Aggregate score across the whole hunt.
    overall_score: int = 0
    # Total events considered when scoring.
    total_events: int = 0
    # Mapping of severity label -> event count.
    severity_breakdown: dict[str, int] = Field(default_factory=dict)
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
    "/process-tree",
    response_model=ProcessTreeResponse,
    summary="Build process tree from dataset rows",
    description=(
        "Extracts parent→child process relationships from dataset rows "
        "and returns a hierarchical forest of process nodes."
    ),
)
async def get_process_tree(
    dataset_id: str | None = Query(None, description="Dataset ID"),
    hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
    hostname: str | None = Query(None, description="Filter by hostname"),
    db: AsyncSession = Depends(get_db),
):
    """Return process tree(s) for a dataset or hunt."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")

    forest = await build_process_tree(
        db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
    )

    # Count every node by walking the forest iteratively with an explicit stack.
    node_count = 0
    stack = list(forest)
    while stack:
        node = stack.pop()
        node_count += 1
        stack.extend(node.get("children", []))

    return ProcessTreeResponse(trees=forest, total_processes=node_count)
|
||||
|
||||
|
||||
@router.get(
    "/storyline",
    response_model=StorylineResponse,
    summary="Build CrowdStrike-style storyline attack graph",
    description=(
        "Creates a Cytoscape-compatible graph of events connected by "
        "process lineage (spawned) and temporal sequence within each host."
    ),
)
async def get_storyline(
    dataset_id: str | None = Query(None, description="Dataset ID"),
    hunt_id: str | None = Query(None, description="Hunt ID (scans all datasets in hunt)"),
    hostname: str | None = Query(None, description="Filter by hostname"),
    db: AsyncSession = Depends(get_db),
):
    """Return a storyline graph for a dataset or hunt."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")

    graph = await build_storyline(
        db, dataset_id=dataset_id, hunt_id=hunt_id, hostname_filter=hostname,
    )
    # build_storyline returns the nodes/edges/summary keys the model expects.
    return StorylineResponse(**graph)
|
||||
|
||||
|
||||
@router.get(
    "/risk-summary",
    response_model=RiskSummaryResponse,
    summary="Compute risk scores per host",
    description=(
        "Analyzes dataset rows for suspicious patterns (encoded PowerShell, "
        "credential dumping, lateral movement) and produces per-host risk scores."
    ),
)
async def get_risk_summary(
    hunt_id: str | None = Query(None, description="Hunt ID"),
    db: AsyncSession = Depends(get_db),
):
    """Return risk scores for all hosts in a hunt."""
    summary = await compute_risk_scores(db, hunt_id=hunt_id)
    return RiskSummaryResponse(**summary)
|
||||
|
||||
|
||||
# ── LLM Analysis ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/llm-analyze",
    response_model=AnalysisResult,
    summary="Run LLM-powered threat analysis on dataset",
    description=(
        "Loads dataset rows server-side, builds a summary, and sends to "
        "Wile (deep analysis) or Roadrunner (quick) for comprehensive "
        "threat analysis. Returns structured findings, IOCs, MITRE techniques."
    ),
)
async def llm_analyze(
    request: AnalysisRequest,
    db: AsyncSession = Depends(get_db),
):
    """Run LLM analysis on a dataset or hunt."""
    if not (request.dataset_id or request.hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")

    # Pull up to 2000 rows server-side for summarisation.
    row_records = await _fetch_rows(
        db,
        dataset_id=request.dataset_id,
        hunt_id=request.hunt_id,
        limit=2000,
    )
    if not row_records:
        raise HTTPException(status_code=404, detail="No rows found for analysis")

    # Prefer the normalized payload; fall back to the raw data dict.
    payloads = [record.normalized_data or record.data for record in row_records]

    # Resolve a human-readable dataset name for the prompt context.
    dataset_label = "hunt datasets"
    if request.dataset_id:
        ds = await DatasetRepository(db).get_dataset(request.dataset_id)
        if ds:
            dataset_label = ds.name

    return await run_llm_analysis(payloads, request, dataset_name=dataset_label)
|
||||
|
||||
|
||||
# ── Timeline ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
    "/timeline",
    summary="Get event timeline histogram bins",
)
async def get_timeline(
    dataset_id: str | None = Query(None),
    hunt_id: str | None = Query(None),
    bins: int = Query(60, ge=10, le=200),
    db: AsyncSession = Depends(get_db),
):
    """Return histogram bins of event counts over time for a dataset or hunt."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    histogram = await build_timeline_bins(
        db, dataset_id=dataset_id, hunt_id=hunt_id, bins=bins
    )
    return histogram
|
||||
|
||||
|
||||
@router.get(
    "/field-stats",
    summary="Get per-field value distributions",
)
async def get_field_stats(
    dataset_id: str | None = Query(None),
    hunt_id: str | None = Query(None),
    fields: str | None = Query(None, description="Comma-separated field names"),
    top_n: int = Query(20, ge=5, le=100),
    db: AsyncSession = Depends(get_db),
):
    """Return top-N value distributions per field for a dataset or hunt.

    Fix: blank segments in the comma-separated ``fields`` parameter
    (e.g. "a,,b" or a trailing comma) previously produced empty field
    names; they are now dropped, and an all-blank value degrades to
    "all fields" (None) instead of an empty list.
    """
    if not dataset_id and not hunt_id:
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")

    field_list = None
    if fields:
        # Strip whitespace and discard empty segments; fall back to None
        # (meaning "all fields") when nothing usable remains.
        field_list = [f.strip() for f in fields.split(",") if f.strip()] or None

    return await compute_field_stats(
        db, dataset_id=dataset_id, hunt_id=hunt_id,
        fields=field_list, top_n=top_n,
    )
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
    """Body for POST /search: free-text query plus structured filters."""

    # Exactly one of dataset_id / hunt_id is expected (validated in the route).
    dataset_id: Optional[str] = None
    hunt_id: Optional[str] = None
    # Free-text search string; empty matches everything.
    query: str = ""
    # field-name -> required-value filters.
    filters: dict[str, str] = Field(default_factory=dict)
    # Optional time window bounds — presumably ISO-8601 strings; confirm in search_rows.
    time_start: Optional[str] = None
    time_end: Optional[str] = None
    # Pagination controls.
    limit: int = 500
    offset: int = 0
|
||||
|
||||
|
||||
@router.post(
    "/search",
    summary="Search and filter dataset rows",
)
async def search_dataset_rows(
    request: SearchRequest,
    db: AsyncSession = Depends(get_db),
):
    """Proxy a row search to the timeline service for a dataset or hunt."""
    if not (request.dataset_id or request.hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")

    # Forward every request attribute the search service understands.
    search_kwargs = {
        "dataset_id": request.dataset_id,
        "hunt_id": request.hunt_id,
        "query": request.query,
        "filters": request.filters,
        "time_start": request.time_start,
        "time_end": request.time_end,
        "limit": request.limit,
        "offset": request.offset,
    }
    return await search_rows(db, **search_kwargs)
|
||||
|
||||
|
||||
# ── MITRE ATT&CK ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
    "/mitre-map",
    summary="Map dataset events to MITRE ATT&CK techniques",
)
async def get_mitre_map(
    dataset_id: str | None = Query(None),
    hunt_id: str | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Map events of a dataset or hunt onto ATT&CK techniques."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    mapping = await map_to_attack(db, dataset_id=dataset_id, hunt_id=hunt_id)
    return mapping
|
||||
|
||||
|
||||
@router.get(
    "/knowledge-graph",
    summary="Build entity-technique knowledge graph",
)
async def get_knowledge_graph(
    dataset_id: str | None = Query(None),
    hunt_id: str | None = Query(None),
    db: AsyncSession = Depends(get_db),
):
    """Build the entity↔technique knowledge graph for a dataset or hunt."""
    if not (dataset_id or hunt_id):
        raise HTTPException(status_code=400, detail="Provide dataset_id or hunt_id")
    graph = await build_knowledge_graph(db, dataset_id=dataset_id, hunt_id=hunt_id)
    return graph
|
||||
311
backend/app/api/routes/annotations.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""API routes for annotations and hypotheses."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Annotation, Hypothesis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["annotations"])
|
||||
|
||||
|
||||
# ── Annotation models ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class AnnotationCreate(BaseModel):
    """Payload for creating an annotation on a row or dataset."""

    # Target row within the dataset; None for dataset-level annotations.
    row_id: int | None = None
    dataset_id: str | None = None
    # Annotation body, capped at 2000 characters.
    text: str = Field(..., max_length=2000)
    severity: str = Field(default="info")  # info|low|medium|high|critical
    tag: str | None = None  # suspicious|benign|needs-review
    # Optional UI highlight colour for the annotated row.
    highlight_color: str | None = None
|
||||
|
||||
|
||||
class AnnotationUpdate(BaseModel):
    """Partial update for an annotation; None fields are left unchanged."""

    text: str | None = None
    severity: str | None = None
    tag: str | None = None
    highlight_color: str | None = None
|
||||
|
||||
|
||||
class AnnotationResponse(BaseModel):
    """Serialized annotation as returned by the API."""

    id: str
    row_id: int | None
    dataset_id: str | None
    # User that created the annotation, when known.
    author_id: str | None
    text: str
    severity: str
    tag: str | None
    highlight_color: str | None
    # ISO-8601 timestamps rendered from the ORM datetimes.
    created_at: str
    updated_at: str
|
||||
|
||||
|
||||
class AnnotationListResponse(BaseModel):
    """Page of annotations plus the total matching count."""

    annotations: list[AnnotationResponse]
    # Total rows matching the query, independent of limit/offset.
    total: int
|
||||
|
||||
|
||||
# ── Hypothesis models ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class HypothesisCreate(BaseModel):
    """Payload for creating a hunt hypothesis."""

    # Hunt this hypothesis belongs to, if any.
    hunt_id: str | None = None
    title: str = Field(..., max_length=256)
    description: str | None = None
    # Associated ATT&CK technique id — presumably e.g. "T1059"; confirm format upstream.
    mitre_technique: str | None = None
    status: str = Field(default="draft")
|
||||
|
||||
|
||||
class HypothesisUpdate(BaseModel):
    """Partial update for a hypothesis; None fields are left unchanged."""

    title: str | None = None
    description: str | None = None
    mitre_technique: str | None = None
    status: str | None = None  # draft|active|confirmed|rejected
    # Dataset row ids cited as supporting evidence.
    evidence_row_ids: list[int] | None = None
    evidence_notes: str | None = None
|
||||
|
||||
|
||||
class HypothesisResponse(BaseModel):
    """Serialized hypothesis as returned by the API."""

    id: str
    hunt_id: str | None
    title: str
    description: str | None
    mitre_technique: str | None
    status: str
    # Evidence row ids as stored (JSON list in the DB).
    evidence_row_ids: list | None
    evidence_notes: str | None
    # ISO-8601 timestamps rendered from the ORM datetimes.
    created_at: str
    updated_at: str
|
||||
|
||||
|
||||
class HypothesisListResponse(BaseModel):
    """Page of hypotheses plus the total matching count."""

    hypotheses: list[HypothesisResponse]
    # Total rows matching the query, independent of limit/offset.
    total: int
|
||||
|
||||
|
||||
# ── Annotation routes ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
ann_router = APIRouter(prefix="/api/annotations")
|
||||
|
||||
|
||||
@ann_router.post("", response_model=AnnotationResponse, summary="Create annotation")
async def create_annotation(body: AnnotationCreate, db: AsyncSession = Depends(get_db)):
    """Persist a new annotation and return its serialized form."""
    record = Annotation(
        row_id=body.row_id,
        dataset_id=body.dataset_id,
        text=body.text,
        severity=body.severity,
        tag=body.tag,
        highlight_color=body.highlight_color,
    )
    db.add(record)
    # Flush so server-side defaults (id, timestamps) are populated.
    await db.flush()
    return AnnotationResponse(
        id=record.id,
        row_id=record.row_id,
        dataset_id=record.dataset_id,
        author_id=record.author_id,
        text=record.text,
        severity=record.severity,
        tag=record.tag,
        highlight_color=record.highlight_color,
        created_at=record.created_at.isoformat(),
        updated_at=record.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@ann_router.get("", response_model=AnnotationListResponse, summary="List annotations")
async def list_annotations(
    dataset_id: str | None = Query(None),
    row_id: int | None = Query(None),
    tag: str | None = Query(None),
    severity: str | None = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List annotations, newest first, with optional filters and paging.

    Bug fixes:
    - ``row_id=0`` was silently ignored by the truthiness check; it is
      now compared against None so row 0 can be filtered.
    - ``total`` previously honoured only the dataset_id filter, so
      listings filtered by row_id/tag/severity reported a wrong total;
      the count query now applies the same conditions as the page query.
    """
    conditions = []
    if dataset_id:
        conditions.append(Annotation.dataset_id == dataset_id)
    if row_id is not None:
        conditions.append(Annotation.row_id == row_id)
    if tag:
        conditions.append(Annotation.tag == tag)
    if severity:
        conditions.append(Annotation.severity == severity)

    stmt = select(Annotation).order_by(Annotation.created_at.desc())
    if conditions:
        stmt = stmt.where(*conditions)
    stmt = stmt.limit(limit).offset(offset)
    result = await db.execute(stmt)
    annotations = result.scalars().all()

    # Count with the SAME filters so total matches the filtered view.
    count_stmt = select(func.count(Annotation.id))
    if conditions:
        count_stmt = count_stmt.where(*conditions)
    total = (await db.execute(count_stmt)).scalar_one()

    return AnnotationListResponse(
        annotations=[
            AnnotationResponse(
                id=a.id, row_id=a.row_id, dataset_id=a.dataset_id,
                author_id=a.author_id, text=a.text, severity=a.severity,
                tag=a.tag, highlight_color=a.highlight_color,
                created_at=a.created_at.isoformat(), updated_at=a.updated_at.isoformat(),
            )
            for a in annotations
        ],
        total=total,
    )
|
||||
|
||||
|
||||
@ann_router.put("/{annotation_id}", response_model=AnnotationResponse, summary="Update annotation")
async def update_annotation(
    annotation_id: str, body: AnnotationUpdate, db: AsyncSession = Depends(get_db)
):
    """Apply a partial update to an annotation; 404 if missing."""
    record = (
        await db.execute(select(Annotation).where(Annotation.id == annotation_id))
    ).scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="Annotation not found")

    # Copy over only the fields the caller actually supplied.
    for attr in ("text", "severity", "tag", "highlight_color"):
        value = getattr(body, attr)
        if value is not None:
            setattr(record, attr, value)

    await db.flush()
    return AnnotationResponse(
        id=record.id, row_id=record.row_id, dataset_id=record.dataset_id,
        author_id=record.author_id, text=record.text, severity=record.severity,
        tag=record.tag, highlight_color=record.highlight_color,
        created_at=record.created_at.isoformat(), updated_at=record.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@ann_router.delete("/{annotation_id}", summary="Delete annotation")
async def delete_annotation(annotation_id: str, db: AsyncSession = Depends(get_db)):
    """Delete an annotation by id; 404 if it does not exist."""
    record = (
        await db.execute(select(Annotation).where(Annotation.id == annotation_id))
    ).scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="Annotation not found")
    await db.delete(record)
    return {"message": "Annotation deleted", "id": annotation_id}
|
||||
|
||||
|
||||
# ── Hypothesis routes ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
hyp_router = APIRouter(prefix="/api/hypotheses")
|
||||
|
||||
|
||||
@hyp_router.post("", response_model=HypothesisResponse, summary="Create hypothesis")
async def create_hypothesis(body: HypothesisCreate, db: AsyncSession = Depends(get_db)):
    """Persist a new hypothesis and return its serialized form."""
    record = Hypothesis(
        hunt_id=body.hunt_id,
        title=body.title,
        description=body.description,
        mitre_technique=body.mitre_technique,
        status=body.status,
    )
    db.add(record)
    # Flush so server-side defaults (id, timestamps) are populated.
    await db.flush()
    return HypothesisResponse(
        id=record.id,
        hunt_id=record.hunt_id,
        title=record.title,
        description=record.description,
        mitre_technique=record.mitre_technique,
        status=record.status,
        evidence_row_ids=record.evidence_row_ids,
        evidence_notes=record.evidence_notes,
        created_at=record.created_at.isoformat(),
        updated_at=record.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@hyp_router.get("", response_model=HypothesisListResponse, summary="List hypotheses")
async def list_hypotheses(
    hunt_id: str | None = Query(None),
    status: str | None = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List hypotheses, most recently updated first, with optional filters.

    Bug fix: the ``total`` count previously ignored the ``status``
    filter, so status-filtered listings reported a wrong total; the
    count query now applies the same conditions as the page query.
    """
    conditions = []
    if hunt_id:
        conditions.append(Hypothesis.hunt_id == hunt_id)
    if status:
        conditions.append(Hypothesis.status == status)

    stmt = select(Hypothesis).order_by(Hypothesis.updated_at.desc())
    if conditions:
        stmt = stmt.where(*conditions)
    stmt = stmt.limit(limit).offset(offset)
    result = await db.execute(stmt)
    hyps = result.scalars().all()

    # Count with the SAME filters so total matches the filtered view.
    count_stmt = select(func.count(Hypothesis.id))
    if conditions:
        count_stmt = count_stmt.where(*conditions)
    total = (await db.execute(count_stmt)).scalar_one()

    return HypothesisListResponse(
        hypotheses=[
            HypothesisResponse(
                id=h.id, hunt_id=h.hunt_id, title=h.title,
                description=h.description, mitre_technique=h.mitre_technique,
                status=h.status, evidence_row_ids=h.evidence_row_ids,
                evidence_notes=h.evidence_notes,
                created_at=h.created_at.isoformat(), updated_at=h.updated_at.isoformat(),
            )
            for h in hyps
        ],
        total=total,
    )
|
||||
|
||||
|
||||
@hyp_router.get("/{hypothesis_id}", response_model=HypothesisResponse, summary="Get hypothesis")
async def get_hypothesis(hypothesis_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single hypothesis by id; 404 if missing."""
    record = (
        await db.execute(select(Hypothesis).where(Hypothesis.id == hypothesis_id))
    ).scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="Hypothesis not found")
    return HypothesisResponse(
        id=record.id,
        hunt_id=record.hunt_id,
        title=record.title,
        description=record.description,
        mitre_technique=record.mitre_technique,
        status=record.status,
        evidence_row_ids=record.evidence_row_ids,
        evidence_notes=record.evidence_notes,
        created_at=record.created_at.isoformat(),
        updated_at=record.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@hyp_router.put("/{hypothesis_id}", response_model=HypothesisResponse, summary="Update hypothesis")
async def update_hypothesis(
    hypothesis_id: str, body: HypothesisUpdate, db: AsyncSession = Depends(get_db)
):
    """Apply a partial update to a hypothesis; 404 if missing."""
    record = (
        await db.execute(select(Hypothesis).where(Hypothesis.id == hypothesis_id))
    ).scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="Hypothesis not found")

    # Copy over only the fields the caller actually supplied.
    for attr in (
        "title",
        "description",
        "mitre_technique",
        "status",
        "evidence_row_ids",
        "evidence_notes",
    ):
        value = getattr(body, attr)
        if value is not None:
            setattr(record, attr, value)

    await db.flush()
    return HypothesisResponse(
        id=record.id, hunt_id=record.hunt_id, title=record.title,
        description=record.description, mitre_technique=record.mitre_technique,
        status=record.status, evidence_row_ids=record.evidence_row_ids,
        evidence_notes=record.evidence_notes,
        created_at=record.created_at.isoformat(), updated_at=record.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@hyp_router.delete("/{hypothesis_id}", summary="Delete hypothesis")
async def delete_hypothesis(hypothesis_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a hypothesis by id; 404 if it does not exist."""
    record = (
        await db.execute(select(Hypothesis).where(Hypothesis.id == hypothesis_id))
    ).scalar_one_or_none()
    if record is None:
        raise HTTPException(status_code=404, detail="Hypothesis not found")
    await db.delete(record)
    return {"message": "Hypothesis deleted", "id": hypothesis_id}
|
||||
197
backend/app/api/routes/auth.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""API routes for authentication — register, login, refresh, profile."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field, EmailStr
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import User
|
||||
from app.services.auth import (
|
||||
hash_password,
|
||||
verify_password,
|
||||
create_token_pair,
|
||||
decode_token,
|
||||
get_current_user,
|
||||
TokenPair,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
    """Payload for POST /register."""

    username: str = Field(..., min_length=3, max_length=64)
    # NOTE(review): typed as plain str, not EmailStr, although EmailStr is
    # imported above — format is not validated here.
    email: str = Field(..., max_length=256)
    password: str = Field(..., min_length=8, max_length=128)
    # Falls back to username when omitted (see register()).
    display_name: str | None = None
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
    """Credentials for POST /login."""

    username: str
    password: str
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
    """Payload for POST /refresh: the long-lived refresh token."""

    refresh_token: str
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
    """Public view of a user account returned by auth endpoints."""

    id: str
    username: str
    email: str
    display_name: str | None
    # Role string; new registrations default to "analyst" (see register()).
    role: str
    is_active: bool
    # ISO-8601 creation timestamp rendered from the ORM datetime.
    created_at: str
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
    """User profile plus the freshly issued access/refresh token pair."""

    user: UserResponse
    tokens: TokenPair
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/register",
    response_model=AuthResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Register a new user",
)
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
    """Create a new user account and return the profile plus a token pair.

    Raises:
        HTTPException 409: when the username or email is already in use.

    Fix: the registration log line used an eager f-string; it now uses
    lazy %-style logging args so formatting only happens when the record
    is actually emitted.
    """
    # Reject duplicate usernames.
    result = await db.execute(select(User).where(User.username == body.username))
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Username already taken",
        )

    # Reject duplicate emails.
    result = await db.execute(select(User).where(User.email == body.email))
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Email already registered",
        )

    user = User(
        username=body.username,
        email=body.email,
        password_hash=hash_password(body.password),
        display_name=body.display_name or body.username,
        role="analyst",  # Default role
    )
    db.add(user)
    # Flush so server-side defaults (id, created_at) are populated.
    await db.flush()

    tokens = create_token_pair(user.id, user.role)

    # Lazy %-args: no string formatting unless the log record is emitted.
    logger.info("New user registered: %s (%s)", user.username, user.id)

    return AuthResponse(
        user=UserResponse(
            id=user.id,
            username=user.username,
            email=user.email,
            display_name=user.display_name,
            role=user.role,
            is_active=user.is_active,
            created_at=user.created_at.isoformat(),
        ),
        tokens=tokens,
    )
|
||||
|
||||
|
||||
@router.post(
    "/login",
    response_model=AuthResponse,
    summary="Login with username and password",
)
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
    """Authenticate a user and issue an access/refresh token pair."""
    user = (
        await db.execute(select(User).where(User.username == body.username))
    ).scalar_one_or_none()

    # Same error for unknown user and wrong password (no user enumeration).
    # Short-circuit keeps verify_password from running without a hash.
    credentials_ok = (
        user is not None
        and bool(user.password_hash)
        and verify_password(body.password, user.password_hash)
    )
    if not credentials_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid username or password",
        )

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Account is disabled",
        )

    tokens = create_token_pair(user.id, user.role)

    return AuthResponse(
        user=UserResponse(
            id=user.id,
            username=user.username,
            email=user.email,
            display_name=user.display_name,
            role=user.role,
            is_active=user.is_active,
            created_at=user.created_at.isoformat(),
        ),
        tokens=tokens,
    )
|
||||
|
||||
|
||||
@router.post(
    "/refresh",
    response_model=TokenPair,
    summary="Refresh access token",
)
async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
    """Exchange a valid refresh token for a fresh token pair."""
    token_data = decode_token(body.refresh_token)

    # Access tokens must not be usable for refresh.
    if token_data.type != "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type — use refresh token",
        )

    user = (
        await db.execute(select(User).where(User.id == token_data.sub))
    ).scalar_one_or_none()
    if user is None or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid user",
        )

    return create_token_pair(user.id, user.role)
|
||||
|
||||
|
||||
@router.get(
    "/me",
    response_model=UserResponse,
    summary="Get current user profile",
)
async def get_profile(user: User = Depends(get_current_user)):
    """Return the authenticated user's own profile."""
    created = user.created_at
    # created_at may already be a string depending on how the user was loaded.
    created_str = created.isoformat() if hasattr(created, 'isoformat') else str(created)
    return UserResponse(
        id=user.id,
        username=user.username,
        email=user.email,
        display_name=user.display_name,
        role=user.role,
        is_active=user.is_active,
        created_at=created_str,
    )
|
||||
296
backend/app/api/routes/cases.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""API routes for case management — CRUD for cases, tasks, and activity logs."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Case, CaseTask, ActivityLog, _new_id, _utcnow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/cases", tags=["cases"])
|
||||
|
||||
|
||||
# ── Pydantic models ──────────────────────────────────────────────────
|
||||
|
||||
class CaseCreate(BaseModel):
    """Payload for creating a case."""

    title: str
    description: Optional[str] = None
    severity: str = "medium"
    # Traffic Light Protocol / Permissible Actions Protocol markings.
    tlp: str = "amber"
    pap: str = "amber"
    priority: int = 2
    assignee: Optional[str] = None
    tags: Optional[list[str]] = None
    # Hunt this case originated from, if any.
    hunt_id: Optional[str] = None
    mitre_techniques: Optional[list[str]] = None
    # Indicators of compromise as free-form dicts.
    iocs: Optional[list[dict]] = None
|
||||
|
||||
|
||||
class CaseUpdate(BaseModel):
    """Partial update for a case; None fields are left unchanged."""

    title: Optional[str] = None
    description: Optional[str] = None
    severity: Optional[str] = None
    tlp: Optional[str] = None
    pap: Optional[str] = None
    # Setting "in-progress" stamps started_at; "resolved"/"closed" stamp
    # resolved_at (see update_case()).
    status: Optional[str] = None
    priority: Optional[int] = None
    assignee: Optional[str] = None
    tags: Optional[list[str]] = None
    mitre_techniques: Optional[list[str]] = None
    iocs: Optional[list[dict]] = None
|
||||
|
||||
|
||||
class TaskCreate(BaseModel):
    """Payload for adding a task to a case."""

    title: str
    description: Optional[str] = None
    assignee: Optional[str] = None
|
||||
|
||||
|
||||
class TaskUpdate(BaseModel):
    """Partial update for a case task; None fields are left unchanged."""

    title: Optional[str] = None
    description: Optional[str] = None
    status: Optional[str] = None
    assignee: Optional[str] = None
    # Display ordering within the case's task list.
    order: Optional[int] = None
|
||||
|
||||
|
||||
# ── Helper: log activity ─────────────────────────────────────────────
|
||||
|
||||
async def _log_activity(
    db: AsyncSession,
    entity_type: str,
    entity_id: str,
    action: str,
    details: dict | None = None,
):
    """Queue an ActivityLog row on the session (caller commits)."""
    db.add(
        ActivityLog(
            entity_type=entity_type,
            entity_id=entity_id,
            action=action,
            details=details,
            created_at=_utcnow(),
        )
    )
|
||||
|
||||
|
||||
# ── Case CRUD ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("", summary="Create a case")
async def create_case(body: CaseCreate, db: AsyncSession = Depends(get_db)):
    """Create a case, record a "created" activity entry, and return it."""
    timestamp = _utcnow()
    case = Case(
        id=_new_id(),
        title=body.title,
        description=body.description,
        severity=body.severity,
        tlp=body.tlp,
        pap=body.pap,
        priority=body.priority,
        assignee=body.assignee,
        tags=body.tags,
        hunt_id=body.hunt_id,
        mitre_techniques=body.mitre_techniques,
        iocs=body.iocs,
        created_at=timestamp,
        updated_at=timestamp,
    )
    db.add(case)
    # Activity row rides in the same transaction as the case itself.
    await _log_activity(db, "case", case.id, "created", {"title": body.title})
    await db.commit()
    await db.refresh(case)
    return _case_to_dict(case)
|
||||
|
||||
|
||||
@router.get("", summary="List cases")
async def list_cases(
    status: Optional[str] = Query(None),
    hunt_id: Optional[str] = Query(None),
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List cases, most recently updated first, with optional filters."""
    conditions = []
    if status:
        conditions.append(Case.status == status)
    if hunt_id:
        conditions.append(Case.hunt_id == hunt_id)

    page_q = select(Case).order_by(desc(Case.updated_at))
    if conditions:
        page_q = page_q.where(*conditions)
    page_q = page_q.offset(offset).limit(limit)
    cases = (await db.execute(page_q)).scalars().all()

    count_q = select(func.count(Case.id))
    if conditions:
        count_q = count_q.where(*conditions)
    total = (await db.execute(count_q)).scalar() or 0

    return {"cases": [_case_to_dict(c) for c in cases], "total": total}
|
||||
|
||||
|
||||
@router.get("/{case_id}", summary="Get case detail")
async def get_case(case_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single case by id; 404 if missing."""
    case = await db.get(Case, case_id)
    if case is None:
        raise HTTPException(status_code=404, detail="Case not found")
    return _case_to_dict(case)
|
||||
|
||||
|
||||
@router.put("/{case_id}", summary="Update a case")
async def update_case(case_id: str, body: CaseUpdate, db: AsyncSession = Depends(get_db)):
    """Apply a partial update, stamp lifecycle timestamps, log the diff."""
    case = await db.get(Case, case_id)
    if case is None:
        raise HTTPException(status_code=404, detail="Case not found")

    mutable_fields = (
        "title", "description", "severity", "tlp", "pap", "status",
        "priority", "assignee", "tags", "mitre_techniques", "iocs",
    )
    # Record old/new values for the activity log while applying updates.
    changes: dict[str, dict] = {}
    for name in mutable_fields:
        new_value = getattr(body, name)
        if new_value is None:
            continue
        changes[name] = {"old": getattr(case, name), "new": new_value}
        setattr(case, name, new_value)

    # Lifecycle timestamps driven by status transitions.
    status_change = changes.get("status")
    if status_change is not None:
        if status_change["new"] == "in-progress" and not case.started_at:
            case.started_at = _utcnow()
        if status_change["new"] in ("resolved", "closed"):
            case.resolved_at = _utcnow()

    case.updated_at = _utcnow()
    await _log_activity(db, "case", case.id, "updated", changes)
    await db.commit()
    await db.refresh(case)
    return _case_to_dict(case)
|
||||
|
||||
|
||||
@router.delete("/{case_id}", summary="Delete a case")
async def delete_case(case_id: str, db: AsyncSession = Depends(get_db)):
    """Permanently remove a case; 404 when it does not exist."""
    existing = await db.get(Case, case_id)
    if existing is None:
        raise HTTPException(status_code=404, detail="Case not found")
    await db.delete(existing)
    await db.commit()
    return {"deleted": True}
|
||||
|
||||
|
||||
# ── Task CRUD ─────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/{case_id}/tasks", summary="Add task to case")
async def create_task(case_id: str, body: TaskCreate, db: AsyncSession = Depends(get_db)):
    """Create a new task under an existing case and log the event."""
    parent = await db.get(Case, case_id)
    if parent is None:
        raise HTTPException(status_code=404, detail="Case not found")

    timestamp = _utcnow()
    new_task = CaseTask(
        id=_new_id(),
        case_id=case_id,
        title=body.title,
        description=body.description,
        assignee=body.assignee,
        created_at=timestamp,
        updated_at=timestamp,
    )
    db.add(new_task)
    await _log_activity(db, "case", case_id, "task_created", {"title": body.title})
    await db.commit()
    await db.refresh(new_task)
    return _task_to_dict(new_task)
|
||||
|
||||
|
||||
@router.put("/{case_id}/tasks/{task_id}", summary="Update a task")
async def update_task(case_id: str, task_id: str, body: TaskUpdate, db: AsyncSession = Depends(get_db)):
    """Partially update a task, verifying it belongs to the case in the URL."""
    task = await db.get(CaseTask, task_id)
    # The task must both exist and be attached to this case.
    if task is None or task.case_id != case_id:
        raise HTTPException(status_code=404, detail="Task not found")

    for name in ("title", "description", "status", "assignee", "order"):
        value = getattr(body, name)
        if value is not None:
            setattr(task, name, value)

    task.updated_at = _utcnow()
    await _log_activity(db, "case", case_id, "task_updated", {"task_id": task_id})
    await db.commit()
    await db.refresh(task)
    return _task_to_dict(task)
|
||||
|
||||
|
||||
@router.delete("/{case_id}/tasks/{task_id}", summary="Delete a task")
async def delete_task(case_id: str, task_id: str, db: AsyncSession = Depends(get_db)):
    """Remove a task, verifying it belongs to the case in the URL."""
    victim = await db.get(CaseTask, task_id)
    if victim is None or victim.case_id != case_id:
        raise HTTPException(status_code=404, detail="Task not found")
    await db.delete(victim)
    await db.commit()
    return {"deleted": True}
|
||||
|
||||
|
||||
# ── Activity Log ──────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{case_id}/activity", summary="Get case activity log")
async def get_activity(
    case_id: str,
    limit: int = Query(50, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
):
    """Return the most recent activity-log entries for a case, newest first."""
    stmt = (
        select(ActivityLog)
        .where(ActivityLog.entity_type == "case", ActivityLog.entity_id == case_id)
        .order_by(desc(ActivityLog.created_at))
        .limit(limit)
    )
    entries = (await db.execute(stmt)).scalars().all()

    def serialize(entry) -> dict:
        # Timestamps are emitted as ISO-8601 strings, or None when unset.
        return {
            "id": entry.id,
            "action": entry.action,
            "details": entry.details,
            "user_id": entry.user_id,
            "created_at": entry.created_at.isoformat() if entry.created_at else None,
        }

    return {"logs": [serialize(e) for e in entries]}
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _case_to_dict(c: Case) -> dict:
    """Serialize a Case ORM object (including its tasks) to a JSON-safe dict."""
    def iso(dt):
        # ISO-8601 string for set timestamps, None otherwise.
        return dt.isoformat() if dt else None

    return {
        "id": c.id,
        "title": c.title,
        "description": c.description,
        "severity": c.severity,
        "tlp": c.tlp,
        "pap": c.pap,
        "status": c.status,
        "priority": c.priority,
        "assignee": c.assignee,
        "tags": c.tags or [],
        "hunt_id": c.hunt_id,
        "owner_id": c.owner_id,
        "mitre_techniques": c.mitre_techniques or [],
        "iocs": c.iocs or [],
        "started_at": iso(c.started_at),
        "resolved_at": iso(c.resolved_at),
        "created_at": iso(c.created_at),
        "updated_at": iso(c.updated_at),
        "tasks": [_task_to_dict(t) for t in (c.tasks or [])],
    }
|
||||
|
||||
|
||||
def _task_to_dict(t: CaseTask) -> dict:
    """Serialize a CaseTask ORM object to a JSON-safe dict."""
    def iso(dt):
        # ISO-8601 string for set timestamps, None otherwise.
        return dt.isoformat() if dt else None

    return {
        "id": t.id,
        "case_id": t.case_id,
        "title": t.title,
        "description": t.description,
        "status": t.status,
        "assignee": t.assignee,
        "order": t.order,
        "created_at": iso(t.created_at),
        "updated_at": iso(t.updated_at),
    }
|
||||
83
backend/app/api/routes/correlation.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""API routes for cross-hunt correlation analysis."""
|
||||
|
||||
import logging
|
||||
from dataclasses import asdict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.services.correlation import correlation_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/correlation", tags=["correlation"])
|
||||
|
||||
|
||||
class CorrelateRequest(BaseModel):
    """Request body for POST /analyze: the set of hunts to correlate (2–20 IDs)."""

    hunt_ids: list[str] = Field(
        ...,
        min_length=2,
        max_length=20,
        description="List of hunt IDs to correlate",
    )
|
||||
|
||||
|
||||
@router.post(
    "/analyze",
    summary="Run correlation analysis across hunts",
    description="Find shared IOCs, overlapping time windows, common MITRE techniques, "
    "and host patterns across the specified hunts.",
)
async def correlate_hunts(
    body: CorrelateRequest,
    db: AsyncSession = Depends(get_db),
):
    """Run the correlation engine over the requested hunts and serialize the report."""
    report = await correlation_engine.correlate_hunts(body.hunt_ids, db)
    payload = {
        "hunt_ids": report.hunt_ids,
        "summary": report.summary,
        "total_correlations": report.total_correlations,
    }
    # Overlap dataclasses are flattened to plain dicts for JSON serialization.
    payload["ioc_overlaps"] = [asdict(item) for item in report.ioc_overlaps]
    payload["time_overlaps"] = [asdict(item) for item in report.time_overlaps]
    payload["technique_overlaps"] = [asdict(item) for item in report.technique_overlaps]
    payload["host_overlaps"] = report.host_overlaps
    return payload
|
||||
|
||||
|
||||
@router.get(
    "/all",
    summary="Correlate all hunts",
    description="Run correlation across all hunts in the system.",
)
async def correlate_all(db: AsyncSession = Depends(get_db)):
    """Correlate every hunt, truncating the overlap lists to bound the response size."""
    report = await correlation_engine.correlate_all(db)
    return {
        "hunt_ids": report.hunt_ids,
        "summary": report.summary,
        "total_correlations": report.total_correlations,
        # The caps below keep the payload small on large deployments.
        "ioc_overlaps": [asdict(item) for item in report.ioc_overlaps[:20]],
        "time_overlaps": [asdict(item) for item in report.time_overlaps[:10]],
        "technique_overlaps": [asdict(item) for item in report.technique_overlaps[:10]],
        "host_overlaps": report.host_overlaps[:10],
    }
|
||||
|
||||
|
||||
@router.get(
    "/ioc/{ioc_value}",
    summary="Find IOC across all hunts",
    description="Search for a specific IOC value across all datasets and hunts.",
)
async def find_ioc(
    ioc_value: str,
    db: AsyncSession = Depends(get_db),
):
    """Locate every occurrence of an IOC value and summarize hunt coverage."""
    hits = await correlation_engine.find_ioc_across_hunts(ioc_value, db)
    # Occurrences without a hunt_id are excluded from the unique-hunt count.
    distinct_hunts = {hit["hunt_id"] for hit in hits if hit.get("hunt_id")}
    return {
        "ioc_value": ioc_value,
        "occurrences": hits,
        "total": len(hits),
        "unique_hunts": len(distinct_hunts),
    }
|
||||
322
backend/app/api/routes/datasets.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""API routes for dataset upload, listing, and management."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
from app.db.repositories.datasets import DatasetRepository
|
||||
from app.services.csv_parser import parse_csv_bytes, infer_column_types
|
||||
from app.services.normalizer import (
|
||||
normalize_columns,
|
||||
normalize_rows,
|
||||
detect_ioc_columns,
|
||||
detect_time_range,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/datasets", tags=["datasets"])
|
||||
|
||||
ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"}
|
||||
|
||||
|
||||
# ── Response models ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class DatasetSummary(BaseModel):
    """Serialized view of a stored dataset: metadata only, no row payloads."""

    id: str
    name: str
    filename: str
    source_tool: str | None = None
    row_count: int
    column_schema: dict | None = None
    normalized_columns: dict | None = None
    ioc_columns: dict | None = None
    file_size_bytes: int
    encoding: str | None = None
    delimiter: str | None = None
    # Time range endpoints are ISO-8601 strings, absent when no time column was found.
    time_range_start: str | None = None
    time_range_end: str | None = None
    hunt_id: str | None = None
    created_at: str
|
||||
|
||||
|
||||
class DatasetListResponse(BaseModel):
    """Paged dataset listing; `total` is the unfiltered-by-paging count."""

    datasets: list[DatasetSummary]
    total: int
||||
|
||||
|
||||
class RowsResponse(BaseModel):
    """One page of dataset rows, echoing the paging parameters used."""

    rows: list[dict]
    total: int
    offset: int
    limit: int
|
||||
|
||||
|
||||
class UploadResponse(BaseModel):
    """Result of a successful dataset upload, including detected structure."""

    id: str
    name: str
    row_count: int
    columns: list[str]
    column_types: dict
    normalized_columns: dict
    ioc_columns: dict
    message: str
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/upload",
    response_model=UploadResponse,
    summary="Upload a CSV dataset",
    description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, "
    "IOCs auto-detected, and rows stored in the database.",
)
async def upload_dataset(
    file: UploadFile = File(...),
    name: str | None = Query(None, description="Display name for the dataset"),
    source_tool: str | None = Query(None, description="Source tool (e.g., velociraptor)"),
    hunt_id: str | None = Query(None, description="Hunt ID to associate with"),
    db: AsyncSession = Depends(get_db),
):
    """Upload and parse a CSV dataset.

    Pipeline: validate extension and size → parse bytes → normalize column
    names → auto-detect IOC columns and the event-time range → persist the
    dataset metadata and its rows.

    Raises:
        HTTPException: 400 for a missing filename, disallowed extension, or
            empty file; 413 when the upload exceeds the configured limit;
            422 when the CSV cannot be parsed or contains no data rows.
    """
    # ── Validate the incoming file ─────────────────────────────────────
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    ext = Path(file.filename).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"File type '{ext}' not allowed. Accepted: {', '.join(ALLOWED_EXTENSIONS)}",
        )

    raw_bytes = await file.read()
    if len(raw_bytes) == 0:
        raise HTTPException(status_code=400, detail="File is empty")

    # NOTE(review): the limit is read as `settings.max_upload_bytes` but the
    # message uses `settings.MAX_UPLOAD_SIZE_MB` — presumably both exist on
    # the settings object; confirm they stay in sync.
    if len(raw_bytes) > settings.max_upload_bytes:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Max size: {settings.MAX_UPLOAD_SIZE_MB} MB",
        )

    # ── Parse CSV ──────────────────────────────────────────────────────
    try:
        rows, metadata = parse_csv_bytes(raw_bytes)
    except Exception as e:
        # Chain the original parse error so tracebacks show the root cause;
        # lazy %-formatting avoids building the message when not logged.
        logger.error("CSV parse error: %s", e)
        raise HTTPException(
            status_code=422,
            detail=f"Failed to parse CSV: {str(e)}. Check encoding and format.",
        ) from e

    if not rows:
        raise HTTPException(status_code=422, detail="CSV file contains no data rows")

    columns: list[str] = metadata["columns"]
    column_types: dict = metadata["column_types"]

    # ── Normalize columns, detect IOCs and the time range ──────────────
    column_mapping = normalize_columns(columns)
    normalized = normalize_rows(rows, column_mapping)
    ioc_columns = detect_ioc_columns(columns, column_types, column_mapping)
    time_start, time_end = detect_time_range(rows, column_mapping)

    # ── Persist dataset metadata, then its rows ────────────────────────
    repo = DatasetRepository(db)
    dataset = await repo.create_dataset(
        name=name or Path(file.filename).stem,
        filename=file.filename,
        source_tool=source_tool,
        row_count=len(rows),
        column_schema=column_types,
        normalized_columns=column_mapping,
        ioc_columns=ioc_columns,
        file_size_bytes=len(raw_bytes),
        encoding=metadata["encoding"],
        delimiter=metadata["delimiter"],
        time_range_start=time_start,
        time_range_end=time_end,
        hunt_id=hunt_id,
    )

    await repo.bulk_insert_rows(
        dataset_id=dataset.id,
        rows=rows,
        normalized_rows=normalized,
    )

    logger.info(
        "Uploaded dataset '%s': %d rows, %d columns, %d IOC columns detected",
        dataset.name, len(rows), len(columns), len(ioc_columns),
    )

    return UploadResponse(
        id=dataset.id,
        name=dataset.name,
        row_count=len(rows),
        columns=columns,
        column_types=column_types,
        normalized_columns=column_mapping,
        ioc_columns=ioc_columns,
        message=f"Successfully uploaded {len(rows)} rows with {len(ioc_columns)} IOC columns detected",
    )
|
||||
|
||||
|
||||
@router.get(
    "",
    response_model=DatasetListResponse,
    summary="List datasets",
)
async def list_datasets(
    hunt_id: str | None = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """Page through stored datasets, optionally filtered by hunt."""
    repo = DatasetRepository(db)
    page = await repo.list_datasets(hunt_id=hunt_id, limit=limit, offset=offset)
    count = await repo.count_datasets(hunt_id=hunt_id)

    def summarize(ds) -> DatasetSummary:
        # Datetimes become ISO strings; time-range endpoints may be absent.
        return DatasetSummary(
            id=ds.id,
            name=ds.name,
            filename=ds.filename,
            source_tool=ds.source_tool,
            row_count=ds.row_count,
            column_schema=ds.column_schema,
            normalized_columns=ds.normalized_columns,
            ioc_columns=ds.ioc_columns,
            file_size_bytes=ds.file_size_bytes,
            encoding=ds.encoding,
            delimiter=ds.delimiter,
            time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
            time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
            hunt_id=ds.hunt_id,
            created_at=ds.created_at.isoformat(),
        )

    return DatasetListResponse(
        datasets=[summarize(ds) for ds in page],
        total=count,
    )
|
||||
|
||||
|
||||
@router.get(
    "/{dataset_id}",
    response_model=DatasetSummary,
    summary="Get dataset details",
)
async def get_dataset(
    dataset_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Fetch one dataset's metadata; 404 when the ID is unknown."""
    ds = await DatasetRepository(db).get_dataset(dataset_id)
    if ds is None:
        raise HTTPException(status_code=404, detail="Dataset not found")

    # Assemble the response fields; datetimes become ISO strings.
    fields = {
        "id": ds.id,
        "name": ds.name,
        "filename": ds.filename,
        "source_tool": ds.source_tool,
        "row_count": ds.row_count,
        "column_schema": ds.column_schema,
        "normalized_columns": ds.normalized_columns,
        "ioc_columns": ds.ioc_columns,
        "file_size_bytes": ds.file_size_bytes,
        "encoding": ds.encoding,
        "delimiter": ds.delimiter,
        "time_range_start": ds.time_range_start.isoformat() if ds.time_range_start else None,
        "time_range_end": ds.time_range_end.isoformat() if ds.time_range_end else None,
        "hunt_id": ds.hunt_id,
        "created_at": ds.created_at.isoformat(),
    }
    return DatasetSummary(**fields)
|
||||
|
||||
|
||||
@router.get(
    "/{dataset_id}/rows",
    response_model=RowsResponse,
    summary="Get dataset rows",
)
async def get_dataset_rows(
    dataset_id: str,
    limit: int = Query(1000, ge=1, le=10000),
    offset: int = Query(0, ge=0),
    normalized: bool = Query(False, description="Return normalized column names"),
    db: AsyncSession = Depends(get_db),
):
    """Page through a dataset's rows, optionally using normalized column names."""
    repo = DatasetRepository(db)
    if await repo.get_dataset(dataset_id) is None:
        raise HTTPException(status_code=404, detail="Dataset not found")

    stored = await repo.get_rows(dataset_id, limit=limit, offset=offset)
    row_total = await repo.count_rows(dataset_id)

    payload = []
    for record in stored:
        # Fall back to the raw data when no normalized copy exists for a row.
        if normalized and record.normalized_data:
            payload.append(record.normalized_data)
        else:
            payload.append(record.data)

    return RowsResponse(rows=payload, total=row_total, offset=offset, limit=limit)
|
||||
|
||||
|
||||
@router.delete(
    "/{dataset_id}",
    summary="Delete a dataset",
)
async def delete_dataset(
    dataset_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Delete a dataset by ID; 404 when it does not exist."""
    removed = await DatasetRepository(db).delete_dataset(dataset_id)
    if not removed:
        raise HTTPException(status_code=404, detail="Dataset not found")
    return {"message": "Dataset deleted", "id": dataset_id}
|
||||
|
||||
|
||||
@router.post(
    "/rescan-ioc",
    summary="Re-scan IOC columns for all datasets",
)
async def rescan_ioc_columns(
    db: AsyncSession = Depends(get_db),
):
    """Re-run detect_ioc_columns on every dataset using current detection logic."""
    repo = DatasetRepository(db)
    everything = await repo.list_datasets(limit=10000)
    changed = 0
    for ds in everything:
        schema = ds.column_schema or {}
        columns = list(schema.keys())
        if not columns:
            # Nothing to scan without a recorded column schema.
            continue
        fresh = detect_ioc_columns(columns, schema, ds.normalized_columns or {})
        if fresh != (ds.ioc_columns or {}):
            # Reassigning the attribute (rather than mutating in place) marks
            # the row dirty so the commit below persists it.
            ds.ioc_columns = fresh
            changed += 1
    await db.commit()
    return {"message": f"Rescanned {len(everything)} datasets, updated {changed}"}
|
||||
220
backend/app/api/routes/enrichment.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""API routes for IOC enrichment."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.services.enrichment import (
|
||||
enrichment_engine,
|
||||
IOCType,
|
||||
Verdict,
|
||||
EnrichmentResultData,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/enrichment", tags=["enrichment"])
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class EnrichIOCRequest(BaseModel):
    """Request body for enriching a single IOC across all providers."""

    ioc_value: str = Field(..., max_length=2048, description="IOC value to enrich")
    ioc_type: str = Field(..., description="IOC type: ip, domain, hash_md5, hash_sha1, hash_sha256, url")
    # When True, providers are queried even if a cached result exists.
    skip_cache: bool = False
|
||||
|
||||
|
||||
class EnrichBatchRequest(BaseModel):
    """Request body for batch enrichment (capped at 50 entries)."""

    iocs: list[dict] = Field(
        ...,
        description="List of {value, type} pairs",
        max_length=50,
    )
|
||||
|
||||
|
||||
class EnrichmentResultResponse(BaseModel):
    """One provider's verdict and metadata for a single IOC."""

    ioc_value: str
    ioc_type: str
    source: str
    verdict: str
    score: float
    # NOTE: Pydantic copies field defaults per instance, so the mutable
    # [] / {} defaults below are safe here (unlike plain Python defaults).
    tags: list[str] = []
    country: str = ""
    asn: str = ""
    org: str = ""
    last_seen: str = ""
    raw_data: dict = {}
    error: str = ""
    latency_ms: int = 0
|
||||
|
||||
|
||||
class EnrichIOCResponse(BaseModel):
    """Aggregated enrichment response: per-provider results plus an overall call."""

    ioc_value: str
    ioc_type: str
    results: list[EnrichmentResultResponse]
    overall_verdict: str
    overall_score: float
|
||||
|
||||
|
||||
class EnrichBatchResponse(BaseModel):
    """Batch enrichment response, keyed by IOC value."""

    results: dict[str, list[EnrichmentResultResponse]]
    total_enriched: int
|
||||
|
||||
|
||||
def _to_response(r: EnrichmentResultData) -> EnrichmentResultResponse:
    """Convert an internal enrichment result into its API response model."""
    passthrough = (
        "ioc_value", "source", "score", "tags", "country", "asn",
        "org", "last_seen", "raw_data", "error", "latency_ms",
    )
    fields = {attr: getattr(r, attr) for attr in passthrough}
    # Enum members are flattened to their string values for the wire format.
    fields["ioc_type"] = r.ioc_type.value
    fields["verdict"] = r.verdict.value
    return EnrichmentResultResponse(**fields)
|
||||
|
||||
|
||||
def _compute_overall(results: list[EnrichmentResultData]) -> tuple[str, float]:
    """Compute overall verdict from multiple provider results."""
    if not results:
        return Verdict.UNKNOWN.value, 0.0

    # Errored providers do not get a vote; if every provider errored, say so.
    usable = [r.verdict for r in results if r.verdict != Verdict.ERROR]
    if not usable:
        return Verdict.ERROR.value, 0.0

    # The worst verdict wins, carrying the highest score seen.
    top_score = max(r.score for r in results)
    for severity in (Verdict.MALICIOUS, Verdict.SUSPICIOUS):
        if severity in usable:
            return severity.value, top_score
    if Verdict.CLEAN in usable:
        return Verdict.CLEAN.value, 0.0
    return Verdict.UNKNOWN.value, 0.0
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/ioc",
    response_model=EnrichIOCResponse,
    summary="Enrich a single IOC",
    description="Query all configured providers for an IOC (IP, hash, domain, URL).",
)
async def enrich_ioc(
    body: EnrichIOCRequest,
    db: AsyncSession = Depends(get_db),
):
    """Enrich one IOC across every configured provider.

    Raises:
        HTTPException: 400 when ``body.ioc_type`` is not a valid ``IOCType``.
    """
    try:
        ioc_type = IOCType(body.ioc_type)
    except ValueError as exc:
        # Chain the cause so the invalid-enum error is preserved in tracebacks.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid IOC type: {body.ioc_type}. Valid: {[t.value for t in IOCType]}",
        ) from exc

    results = await enrichment_engine.enrich_ioc(
        body.ioc_value, ioc_type, db=db, skip_cache=body.skip_cache,
    )

    overall_verdict, overall_score = _compute_overall(results)

    return EnrichIOCResponse(
        ioc_value=body.ioc_value,
        ioc_type=body.ioc_type,
        results=[_to_response(r) for r in results],
        overall_verdict=overall_verdict,
        overall_score=overall_score,
    )
|
||||
|
||||
|
||||
@router.post(
    "/batch",
    response_model=EnrichBatchResponse,
    summary="Enrich a batch of IOCs",
    description="Enrich up to 50 IOCs at once across all providers.",
)
async def enrich_batch(
    body: EnrichBatchRequest,
    db: AsyncSession = Depends(get_db),
):
    """Validate and enrich a batch of IOCs, silently skipping malformed entries."""
    parsed = []
    for entry in body.iocs:
        try:
            parsed.append((entry["value"], IOCType(entry["type"])))
        except (KeyError, ValueError):
            # Entries missing keys or with unknown types are dropped, not fatal.
            continue

    if not parsed:
        raise HTTPException(status_code=400, detail="No valid IOCs provided")

    outcome = await enrichment_engine.enrich_batch(parsed, db=db)

    serialized = {
        key: [_to_response(r) for r in group]
        for key, group in outcome.items()
    }
    return EnrichBatchResponse(results=serialized, total_enriched=len(outcome))
|
||||
|
||||
|
||||
@router.post(
    "/dataset/{dataset_id}",
    summary="Auto-enrich IOCs in a dataset",
    description="Automatically extract and enrich IOCs from a dataset's IOC columns.",
)
async def enrich_dataset(
    dataset_id: str,
    max_iocs: int = Query(50, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
):
    """Extract values from a dataset's detected IOC columns and enrich them."""
    # NOTE(review): imported locally (as in the original), presumably to avoid
    # an import cycle with the repository module — confirm.
    from app.db.repositories.datasets import DatasetRepository

    repo = DatasetRepository(db)
    dataset = await repo.get_dataset(dataset_id)
    if dataset is None:
        raise HTTPException(status_code=404, detail="Dataset not found")

    if not dataset.ioc_columns:
        return {"message": "No IOC columns detected in this dataset", "results": {}}

    # Only the first 1000 rows are sampled for IOC extraction.
    sample = await repo.get_rows(dataset_id, limit=1000)
    enriched = await enrichment_engine.enrich_dataset_iocs(
        rows=[record.data for record in sample],
        ioc_columns=dataset.ioc_columns,
        db=db,
        max_iocs=max_iocs,
    )

    return {
        "dataset_id": dataset_id,
        "dataset_name": dataset.name,
        "ioc_columns": dataset.ioc_columns,
        "results": {
            key: [_to_response(r) for r in group]
            for key, group in enriched.items()
        },
        "total_enriched": len(enriched),
    }
|
||||
|
||||
|
||||
@router.get(
    "/status",
    summary="Enrichment engine status",
    description="Check which providers are configured and available.",
)
async def enrichment_status():
    """Return the enrichment engine's provider availability/configuration report."""
    return enrichment_engine.status()
|
||||
158
backend/app/api/routes/hunts.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""API routes for hunt management."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Hunt, Conversation, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/hunts", tags=["hunts"])
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class HuntCreate(BaseModel):
    """Request body for creating a hunt."""

    name: str = Field(..., max_length=256)
    description: str | None = None
|
||||
|
||||
|
||||
class HuntUpdate(BaseModel):
    """Partial-update body for a hunt; None fields are left untouched."""

    name: str | None = None
    description: str | None = None
    status: str | None = None  # active | closed | archived
|
||||
|
||||
|
||||
class HuntResponse(BaseModel):
    """Serialized hunt, with dataset/hypothesis counts when available."""

    id: str
    name: str
    description: str | None
    status: str
    owner_id: str | None
    created_at: str
    updated_at: str
    dataset_count: int = 0
    hypothesis_count: int = 0
|
||||
|
||||
|
||||
class HuntListResponse(BaseModel):
    """Paged hunt listing; `total` reflects the status filter, not the page."""

    hunts: list[HuntResponse]
    total: int
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("", response_model=HuntResponse, summary="Create a new hunt")
async def create_hunt(body: HuntCreate, db: AsyncSession = Depends(get_db)):
    """Create a hunt and return its serialized form (flushed, not committed here)."""
    new_hunt = Hunt(name=body.name, description=body.description)
    db.add(new_hunt)
    # Flush so generated values (id, timestamps) are populated before serializing.
    await db.flush()
    return HuntResponse(
        id=new_hunt.id,
        name=new_hunt.name,
        description=new_hunt.description,
        status=new_hunt.status,
        owner_id=new_hunt.owner_id,
        created_at=new_hunt.created_at.isoformat(),
        updated_at=new_hunt.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@router.get("", response_model=HuntListResponse, summary="List hunts")
async def list_hunts(
    status: str | None = Query(None),
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List hunts newest-first, optionally filtered by status, with a total count."""
    stmt = select(Hunt).order_by(Hunt.updated_at.desc())
    if status:
        stmt = stmt.where(Hunt.status == status)
    stmt = stmt.limit(limit).offset(offset)
    result = await db.execute(stmt)
    hunts = result.scalars().all()

    # Separate count query so `total` reflects the filter, not just this page.
    count_stmt = select(func.count(Hunt.id))
    if status:
        count_stmt = count_stmt.where(Hunt.status == status)
    total = (await db.execute(count_stmt)).scalar_one()

    return HuntListResponse(
        hunts=[
            HuntResponse(
                id=h.id,
                name=h.name,
                description=h.description,
                status=h.status,
                owner_id=h.owner_id,
                created_at=h.created_at.isoformat(),
                updated_at=h.updated_at.isoformat(),
                # NOTE(review): h.datasets / h.hypotheses are relationship
                # accesses; under an AsyncSession these only work if eager
                # loading (e.g. lazy="selectin") is configured on the model —
                # confirm, otherwise this raises at runtime.
                dataset_count=len(h.datasets) if h.datasets else 0,
                hypothesis_count=len(h.hypotheses) if h.hypotheses else 0,
            )
            for h in hunts
        ],
        total=total,
    )
|
||||
|
||||
|
||||
@router.get("/{hunt_id}", response_model=HuntResponse, summary="Get hunt details")
async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one hunt with its dataset/hypothesis counts; 404 when unknown."""
    result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    hunt = result.scalar_one_or_none()
    if not hunt:
        raise HTTPException(status_code=404, detail="Hunt not found")
    return HuntResponse(
        id=hunt.id,
        name=hunt.name,
        description=hunt.description,
        status=hunt.status,
        owner_id=hunt.owner_id,
        created_at=hunt.created_at.isoformat(),
        updated_at=hunt.updated_at.isoformat(),
        # NOTE(review): relationship access under AsyncSession — see list_hunts;
        # requires eager loading to be configured on the model. Confirm.
        dataset_count=len(hunt.datasets) if hunt.datasets else 0,
        hypothesis_count=len(hunt.hypotheses) if hunt.hypotheses else 0,
    )
|
||||
|
||||
|
||||
@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
async def update_hunt(
    hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)
):
    """Partially update a hunt's name, description, and/or status."""
    lookup = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    hunt = lookup.scalar_one_or_none()
    if hunt is None:
        raise HTTPException(status_code=404, detail="Hunt not found")

    # Only explicitly-provided fields are overwritten.
    for attr in ("name", "description", "status"):
        incoming = getattr(body, attr)
        if incoming is not None:
            setattr(hunt, attr, incoming)

    await db.flush()
    return HuntResponse(
        id=hunt.id,
        name=hunt.name,
        description=hunt.description,
        status=hunt.status,
        owner_id=hunt.owner_id,
        created_at=hunt.created_at.isoformat(),
        updated_at=hunt.updated_at.isoformat(),
    )
|
||||
|
||||
|
||||
@router.delete("/{hunt_id}", summary="Delete a hunt")
async def delete_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a hunt by ID; 404 when it does not exist."""
    lookup = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    found = lookup.scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=404, detail="Hunt not found")
    await db.delete(found)
    return {"message": "Hunt deleted", "id": hunt_id}
|
||||
257
backend/app/api/routes/keywords.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""API routes for AUP keyword themes, keyword CRUD, and scanning."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import KeywordTheme, Keyword
|
||||
from app.services.scanner import KeywordScanner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/keywords", tags=["keywords"])
|
||||
|
||||
|
||||
# ── Pydantic schemas ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ThemeCreate(BaseModel):
    """Request body for creating a keyword theme."""

    name: str = Field(..., min_length=1, max_length=128)
    # Hex color used by the UI; defaults to a neutral grey.
    color: str = Field(default="#9e9e9e", max_length=16)
    enabled: bool = True
|
||||
|
||||
|
||||
class ThemeUpdate(BaseModel):
    """Partial-update body for a theme; None fields are left untouched."""

    name: str | None = None
    color: str | None = None
    enabled: bool | None = None
|
||||
|
||||
|
||||
class KeywordOut(BaseModel):
    """Serialized keyword belonging to a theme."""

    id: int
    theme_id: str
    value: str
    is_regex: bool
    created_at: str
|
||||
|
||||
|
||||
class ThemeOut(BaseModel):
    """Serialized keyword theme, embedding all of its keywords."""

    id: str
    name: str
    color: str
    enabled: bool
    is_builtin: bool
    created_at: str
    keyword_count: int
    keywords: list[KeywordOut]
|
||||
|
||||
|
||||
class ThemeListResponse(BaseModel):
    """Listing of all themes with a total count."""

    themes: list[ThemeOut]
    total: int
|
||||
|
||||
|
||||
class KeywordCreate(BaseModel):
    """Request body for adding a single keyword to a theme."""

    value: str = Field(..., min_length=1, max_length=256)
    is_regex: bool = False
|
||||
|
||||
|
||||
class KeywordBulkCreate(BaseModel):
    """Request body for adding many keywords to a theme in one call.

    Uses the Pydantic v2 ``min_length`` constraint for the list size;
    ``min_items`` is the v1 spelling and is no longer accepted by v2
    (this codebase is v2: see the ``model_config`` dict / pydantic-settings
    usage in app.config).
    """

    values: list[str] = Field(..., min_length=1)
    # Applied to every keyword in the batch.
    is_regex: bool = False
|
||||
|
||||
|
||||
class ScanRequest(BaseModel):
    """Scan configuration; ``None`` selectors mean "all"."""

    dataset_ids: list[str] | None = None  # None → all datasets
    theme_ids: list[str] | None = None  # None → all enabled themes
    scan_hunts: bool = True
    scan_annotations: bool = True
    scan_messages: bool = True
|
||||
|
||||
|
||||
class ScanHit(BaseModel):
    """A single keyword match found during a scan."""

    theme_name: str
    theme_color: str
    keyword: str
    source_type: str  # dataset_row | hunt | annotation | message
    source_id: str | int
    # Name of the field/column the match occurred in.
    field: str
    matched_value: str
    # presumably populated only for dataset_row hits — confirm against scanner
    row_index: int | None = None
    dataset_name: str | None = None
|
||||
|
||||
|
||||
class ScanResponse(BaseModel):
    """Aggregate result of a keyword scan."""

    total_hits: int
    hits: list[ScanHit]
    themes_scanned: int
    keywords_scanned: int
    rows_scanned: int
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _theme_to_out(t: KeywordTheme) -> ThemeOut:
    """Serialize a KeywordTheme ORM row (and its loaded keywords) to ThemeOut."""
    keyword_models = [
        KeywordOut(
            id=kw.id,
            theme_id=kw.theme_id,
            value=kw.value,
            is_regex=kw.is_regex,
            created_at=kw.created_at.isoformat(),
        )
        for kw in t.keywords
    ]
    return ThemeOut(
        id=t.id,
        name=t.name,
        color=t.color,
        enabled=t.enabled,
        is_builtin=t.is_builtin,
        created_at=t.created_at.isoformat(),
        keyword_count=len(keyword_models),
        keywords=keyword_models,
    )
|
||||
|
||||
|
||||
# ── Theme CRUD ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/themes", response_model=ThemeListResponse)
async def list_themes(db: AsyncSession = Depends(get_db)):
    """List all keyword themes with their keywords."""
    rows = (
        await db.execute(select(KeywordTheme).order_by(KeywordTheme.name))
    ).scalars().all()
    return ThemeListResponse(
        themes=[_theme_to_out(row) for row in rows],
        total=len(rows),
    )
|
||||
|
||||
|
||||
@router.post("/themes", response_model=ThemeOut, status_code=201)
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
    """Create a new keyword theme; 409 when the name is already taken."""
    duplicate = await db.scalar(
        select(KeywordTheme.id).where(KeywordTheme.name == body.name)
    )
    if duplicate:
        raise HTTPException(409, f"Theme '{body.name}' already exists")

    new_theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
    db.add(new_theme)
    # Flush to assign server-side defaults, then refresh to read them back.
    await db.flush()
    await db.refresh(new_theme)
    return _theme_to_out(new_theme)
|
||||
|
||||
|
||||
@router.put("/themes/{theme_id}", response_model=ThemeOut)
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
    """Update theme name, color, or enabled status (partial update)."""
    theme = await db.get(KeywordTheme, theme_id)
    if theme is None:
        raise HTTPException(404, "Theme not found")

    if body.name is not None:
        # A rename must not collide with another theme's name.
        clash = await db.scalar(
            select(KeywordTheme.id).where(
                KeywordTheme.name == body.name, KeywordTheme.id != theme_id
            )
        )
        if clash:
            raise HTTPException(409, f"Theme '{body.name}' already exists")
        theme.name = body.name
    if body.color is not None:
        theme.color = body.color
    if body.enabled is not None:
        theme.enabled = body.enabled

    await db.flush()
    await db.refresh(theme)
    return _theme_to_out(theme)
|
||||
|
||||
|
||||
@router.delete("/themes/{theme_id}", status_code=204)
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a theme and all its keywords."""
    target = await db.get(KeywordTheme, theme_id)
    if target is None:
        raise HTTPException(404, "Theme not found")
    # 204: no response body; the request-scoped session commits the delete.
    await db.delete(target)
|
||||
|
||||
|
||||
# ── Keyword CRUD ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
    """Add a single keyword to a theme."""
    if await db.get(KeywordTheme, theme_id) is None:
        raise HTTPException(404, "Theme not found")

    new_kw = Keyword(theme_id=theme_id, value=body.value, is_regex=body.is_regex)
    db.add(new_kw)
    await db.flush()
    await db.refresh(new_kw)
    return KeywordOut(
        id=new_kw.id,
        theme_id=new_kw.theme_id,
        value=new_kw.value,
        is_regex=new_kw.is_regex,
        created_at=new_kw.created_at.isoformat(),
    )
|
||||
|
||||
|
||||
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
    """Add multiple keywords to a theme at once; blank values are skipped."""
    if await db.get(KeywordTheme, theme_id) is None:
        raise HTTPException(404, "Theme not found")

    cleaned = [v for v in (raw.strip() for raw in body.values) if v]
    for value in cleaned:
        db.add(Keyword(theme_id=theme_id, value=value, is_regex=body.is_regex))
    await db.flush()
    return {"added": len(cleaned), "theme_id": theme_id}
|
||||
|
||||
|
||||
@router.delete("/keywords/{keyword_id}", status_code=204)
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
    """Delete a single keyword."""
    record = await db.get(Keyword, keyword_id)
    if record is None:
        raise HTTPException(404, "Keyword not found")
    await db.delete(record)
|
||||
|
||||
|
||||
# ── Scan endpoints ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/scan", response_model=ScanResponse)
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
    """Run AUP keyword scan across selected data sources."""
    return await KeywordScanner(db).scan(
        dataset_ids=body.dataset_ids,
        theme_ids=body.theme_ids,
        scan_hunts=body.scan_hunts,
        scan_annotations=body.scan_annotations,
        scan_messages=body.scan_messages,
    )
|
||||
|
||||
|
||||
@router.get("/scan/quick", response_model=ScanResponse)
async def quick_scan(
    dataset_id: str = Query(..., description="Dataset to scan"),
    db: AsyncSession = Depends(get_db),
):
    """Quick scan a single dataset with all enabled themes."""
    # Defaults of KeywordScanner.scan apply: all enabled themes, all sources.
    return await KeywordScanner(db).scan(dataset_ids=[dataset_id])
|
||||
69
backend/app/api/routes/network.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""API routes for Network Picture — deduplicated host inventory."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.services.network_inventory import build_network_picture
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/network", tags=["network"])
|
||||
|
||||
|
||||
# ── Response models ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class HostEntry(BaseModel):
    """One deduplicated host with everything observed about it."""

    hostname: str
    ips: list[str] = Field(default_factory=list)
    users: list[str] = Field(default_factory=list)
    os: list[str] = Field(default_factory=list)
    mac_addresses: list[str] = Field(default_factory=list)
    protocols: list[str] = Field(default_factory=list)
    open_ports: list[str] = Field(default_factory=list)
    remote_targets: list[str] = Field(default_factory=list)
    # Names of the datasets this host appeared in.
    datasets: list[str] = Field(default_factory=list)
    connection_count: int = 0
    # NOTE(review): timestamp string format depends on build_network_picture — confirm.
    first_seen: str | None = None
    last_seen: str | None = None
|
||||
|
||||
|
||||
class PictureSummary(BaseModel):
    """Aggregate counts for the network picture."""

    total_hosts: int = 0
    total_connections: int = 0
    total_unique_ips: int = 0
    datasets_scanned: int = 0
|
||||
|
||||
|
||||
class NetworkPictureResponse(BaseModel):
    """Envelope for GET /api/network/picture."""

    hosts: list[HostEntry]
    summary: PictureSummary
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
    "/picture",
    response_model=NetworkPictureResponse,
    summary="Build deduplicated host inventory for a hunt",
    description=(
        "Scans all datasets in the specified hunt, extracts host-identifying "
        "fields (hostname, IP, username, OS, MAC, ports), deduplicates by "
        "hostname, and returns a clean one-row-per-host network picture."
    ),
)
async def get_network_picture(
    hunt_id: str = Query(..., description="Hunt ID to scan"),
    db: AsyncSession = Depends(get_db),
):
    """Return a deduplicated network picture for a hunt."""
    # Query(...) makes the param required, but an empty string would still
    # arrive here — reject it explicitly.
    if not hunt_id:
        raise HTTPException(status_code=400, detail="hunt_id is required")
    return await build_network_picture(db, hunt_id)
|
||||
360
backend/app/api/routes/notebooks.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""API routes for investigation notebooks and playbooks."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Notebook, PlaybookRun, _new_id, _utcnow
|
||||
from app.services.playbook import (
|
||||
get_builtin_playbooks,
|
||||
get_playbook_template,
|
||||
validate_notebook_cells,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/notebooks", tags=["notebooks"])
|
||||
|
||||
|
||||
# ── Pydantic models ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class NotebookCreate(BaseModel):
    """Request body for creating a notebook."""

    title: str
    description: Optional[str] = None
    # Raw cell dicts; normalized server-side by validate_notebook_cells.
    cells: Optional[list[dict]] = None
    hunt_id: Optional[str] = None
    case_id: Optional[str] = None
    tags: Optional[list[str]] = None
|
||||
|
||||
|
||||
class NotebookUpdate(BaseModel):
    """Partial notebook update — only non-None fields are applied."""

    title: Optional[str] = None
    description: Optional[str] = None
    cells: Optional[list[dict]] = None
    tags: Optional[list[str]] = None
|
||||
|
||||
|
||||
class CellUpdate(BaseModel):
    """Update a single cell or add a new one."""
    # ID of the cell to update; a new cell is appended when no match exists.
    cell_id: str
    # Only non-None fields below are applied to an existing cell.
    cell_type: Optional[str] = None
    source: Optional[str] = None
    output: Optional[str] = None
    metadata: Optional[dict] = None
|
||||
|
||||
|
||||
class PlaybookStart(BaseModel):
    """Request body for starting a playbook run."""

    # Must match a built-in template name (see get_playbook_template).
    playbook_name: str
    hunt_id: Optional[str] = None
    case_id: Optional[str] = None
    started_by: Optional[str] = None
|
||||
|
||||
|
||||
class StepComplete(BaseModel):
    """Outcome recorded when completing the current playbook step."""

    notes: Optional[str] = None
    status: str = "completed"  # completed | skipped
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _notebook_to_dict(nb: Notebook) -> dict:
    """Serialize a Notebook ORM row to a JSON-safe dict."""
    cell_list = nb.cells or []
    return {
        "id": nb.id,
        "title": nb.title,
        "description": nb.description,
        "cells": cell_list,
        "hunt_id": nb.hunt_id,
        "case_id": nb.case_id,
        "owner_id": nb.owner_id,
        "tags": nb.tags or [],
        "cell_count": len(cell_list),
        # Timestamps may be NULL on freshly constructed rows.
        "created_at": nb.created_at.isoformat() if nb.created_at else None,
        "updated_at": nb.updated_at.isoformat() if nb.updated_at else None,
    }
|
||||
|
||||
|
||||
def _run_to_dict(run: PlaybookRun) -> dict:
    """Serialize a PlaybookRun ORM row to a JSON-safe dict."""
    def iso(ts):
        # completed_at (and others) may be NULL; map to None instead of crashing.
        return ts.isoformat() if ts else None

    return {
        "id": run.id,
        "playbook_name": run.playbook_name,
        "status": run.status,
        "current_step": run.current_step,
        "total_steps": run.total_steps,
        "step_results": run.step_results or [],
        "hunt_id": run.hunt_id,
        "case_id": run.case_id,
        "started_by": run.started_by,
        "created_at": iso(run.created_at),
        "updated_at": iso(run.updated_at),
        "completed_at": iso(run.completed_at),
    }
|
||||
|
||||
|
||||
# ── Notebook CRUD ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("", summary="List notebooks")
async def list_notebooks(
    hunt_id: str | None = Query(None),
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    """List notebooks, most recently updated first, optionally filtered by hunt."""
    base_query = select(Notebook)
    count_query = select(func.count(Notebook.id))
    if hunt_id:
        base_query = base_query.where(Notebook.hunt_id == hunt_id)
        count_query = count_query.where(Notebook.hunt_id == hunt_id)

    total = (await db.execute(count_query)).scalar() or 0
    page_query = base_query.order_by(desc(Notebook.updated_at)).offset(offset).limit(limit)
    notebooks = (await db.execute(page_query)).scalars().all()

    return {"notebooks": [_notebook_to_dict(n) for n in notebooks], "total": total}
|
||||
|
||||
|
||||
@router.get("/{notebook_id}", summary="Get notebook")
async def get_notebook(notebook_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one notebook by ID; 404 when missing."""
    notebook = await db.get(Notebook, notebook_id)
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found")
    return _notebook_to_dict(notebook)
|
||||
|
||||
|
||||
@router.post("", summary="Create notebook")
async def create_notebook(body: NotebookCreate, db: AsyncSession = Depends(get_db)):
    """Create a notebook; seeds one markdown cell when none are supplied."""
    validated_cells = validate_notebook_cells(body.cells or [])
    if not validated_cells:
        # Start with a default markdown cell
        validated_cells = [{
            "id": "cell-0",
            "cell_type": "markdown",
            "source": "# Investigation Notes\n\nStart documenting your findings here.",
            "output": None,
            "metadata": {},
        }]

    record = Notebook(
        id=_new_id(),
        title=body.title,
        description=body.description,
        cells=validated_cells,
        hunt_id=body.hunt_id,
        case_id=body.case_id,
        tags=body.tags,
    )
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return _notebook_to_dict(record)
|
||||
|
||||
|
||||
@router.put("/{notebook_id}", summary="Update notebook")
async def update_notebook(
    notebook_id: str, body: NotebookUpdate, db: AsyncSession = Depends(get_db)
):
    """Partially update a notebook — only non-None body fields are applied."""
    notebook = await db.get(Notebook, notebook_id)
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found")

    if body.title is not None:
        notebook.title = body.title
    if body.description is not None:
        notebook.description = body.description
    if body.cells is not None:
        # Cells are normalized/validated before persisting.
        notebook.cells = validate_notebook_cells(body.cells)
    if body.tags is not None:
        notebook.tags = body.tags

    await db.commit()
    await db.refresh(notebook)
    return _notebook_to_dict(notebook)
|
||||
|
||||
|
||||
@router.post("/{notebook_id}/cells", summary="Add or update a cell")
async def upsert_cell(
    notebook_id: str, body: CellUpdate, db: AsyncSession = Depends(get_db)
):
    """Update the cell matching ``body.cell_id`` or append it as a new cell.

    Only the non-None fields of the body are applied to an existing cell.
    The matched cell is replaced with a fresh copy rather than mutated in
    place: ``nb.cells`` holds the ORM-loaded JSON value, and editing its
    nested dicts directly would alter the loaded state even if the commit
    later fails (and in-place edits of nested JSON are not tracked by
    SQLAlchemy's default change detection — only the list reassignment is).
    """
    nb = await db.get(Notebook, notebook_id)
    if not nb:
        raise HTTPException(status_code=404, detail="Notebook not found")

    cells = list(nb.cells or [])
    # Index of the first cell with a matching id (original behavior: first match wins).
    match_idx = next(
        (i for i, c in enumerate(cells) if c.get("id") == body.cell_id), None
    )

    if match_idx is None:
        # No such cell — append a new one with sensible defaults.
        cells.append({
            "id": body.cell_id,
            "cell_type": body.cell_type or "markdown",
            "source": body.source or "",
            "output": body.output,
            "metadata": body.metadata or {},
        })
    else:
        updated = dict(cells[match_idx])  # copy-on-write; don't mutate loaded JSON
        if body.cell_type is not None:
            updated["cell_type"] = body.cell_type
        if body.source is not None:
            updated["source"] = body.source
        if body.output is not None:
            updated["output"] = body.output
        if body.metadata is not None:
            updated["metadata"] = body.metadata
        cells[match_idx] = updated

    # Reassigning the attribute (new list object) marks the JSON column dirty.
    nb.cells = cells
    await db.commit()
    await db.refresh(nb)
    return _notebook_to_dict(nb)
|
||||
|
||||
|
||||
@router.delete("/{notebook_id}/cells/{cell_id}", summary="Delete a cell")
async def delete_cell(
    notebook_id: str, cell_id: str, db: AsyncSession = Depends(get_db)
):
    """Remove one cell by ID; succeeds even when no cell matches."""
    notebook = await db.get(Notebook, notebook_id)
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found")

    remaining = [c for c in (notebook.cells or []) if c.get("id") != cell_id]
    notebook.cells = remaining
    await db.commit()
    return {"ok": True, "remaining_cells": len(remaining)}
|
||||
|
||||
|
||||
@router.delete("/{notebook_id}", summary="Delete notebook")
async def delete_notebook(notebook_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a notebook by ID."""
    notebook = await db.get(Notebook, notebook_id)
    if notebook is None:
        raise HTTPException(status_code=404, detail="Notebook not found")
    await db.delete(notebook)
    await db.commit()
    return {"ok": True}
|
||||
|
||||
|
||||
# ── Playbooks ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/playbooks/templates", summary="List built-in playbook templates")
async def list_playbook_templates():
    """Summarize every built-in playbook (metadata only, no step bodies)."""
    summaries = []
    for tpl in get_builtin_playbooks():
        summaries.append({
            "name": tpl["name"],
            "description": tpl["description"],
            "category": tpl["category"],
            "tags": tpl["tags"],
            "step_count": len(tpl["steps"]),
        })
    return {"templates": summaries}
|
||||
|
||||
|
||||
@router.get("/playbooks/templates/{name}", summary="Get playbook template detail")
async def get_playbook_template_detail(name: str):
    """Return the full template (including steps) for one built-in playbook."""
    tpl = get_playbook_template(name)
    if not tpl:
        raise HTTPException(status_code=404, detail="Playbook template not found")
    return tpl
|
||||
|
||||
|
||||
@router.post("/playbooks/start", summary="Start a playbook run")
async def start_playbook(body: PlaybookStart, db: AsyncSession = Depends(get_db)):
    """Start a new run of a built-in playbook, positioned at step 1."""
    template = get_playbook_template(body.playbook_name)
    if not template:
        raise HTTPException(status_code=404, detail="Playbook template not found")

    new_run = PlaybookRun(
        id=_new_id(),
        playbook_name=body.playbook_name,
        status="in-progress",
        current_step=1,
        total_steps=len(template["steps"]),
        step_results=[],
        hunt_id=body.hunt_id,
        case_id=body.case_id,
        started_by=body.started_by,
    )
    db.add(new_run)
    await db.commit()
    await db.refresh(new_run)
    return _run_to_dict(new_run)
|
||||
|
||||
|
||||
@router.get("/playbooks/runs", summary="List playbook runs")
async def list_playbook_runs(
    status: str | None = Query(None),
    hunt_id: str | None = Query(None),
    limit: int = Query(50, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
):
    """List playbook runs, newest first, with optional status/hunt filters."""
    query = select(PlaybookRun)
    if status:
        query = query.where(PlaybookRun.status == status)
    if hunt_id:
        query = query.where(PlaybookRun.hunt_id == hunt_id)

    rows = (
        await db.execute(query.order_by(desc(PlaybookRun.created_at)).limit(limit))
    ).scalars().all()
    return {"runs": [_run_to_dict(r) for r in rows]}
|
||||
|
||||
|
||||
@router.get("/playbooks/runs/{run_id}", summary="Get playbook run detail")
async def get_playbook_run(run_id: str, db: AsyncSession = Depends(get_db)):
    """Return run state plus the template's step definitions."""
    run = await db.get(PlaybookRun, run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")

    payload = _run_to_dict(run)
    # Attach the step definitions so the client can render progress.
    template = get_playbook_template(run.playbook_name)
    payload["steps"] = template["steps"] if template else []
    return payload
|
||||
|
||||
|
||||
@router.post("/playbooks/runs/{run_id}/complete-step", summary="Complete current playbook step")
async def complete_step(
    run_id: str, body: StepComplete, db: AsyncSession = Depends(get_db)
):
    """Record the outcome of the current step and advance (or finish) the run."""
    run = await db.get(PlaybookRun, run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")
    if run.status != "in-progress":
        raise HTTPException(status_code=400, detail="Run is not in progress")

    # Reassign (not mutate) the JSON column so the change is flushed.
    run.step_results = [
        *(run.step_results or []),
        {
            "step": run.current_step,
            "status": body.status,
            "notes": body.notes,
            "completed_at": _utcnow().isoformat(),
        },
    ]

    if run.current_step >= run.total_steps:
        # That was the final step — close out the run.
        run.status = "completed"
        run.completed_at = _utcnow()
    else:
        run.current_step += 1

    await db.commit()
    await db.refresh(run)
    return _run_to_dict(run)
|
||||
|
||||
|
||||
@router.post("/playbooks/runs/{run_id}/abort", summary="Abort a playbook run")
async def abort_run(run_id: str, db: AsyncSession = Depends(get_db)):
    """Mark a run as aborted and stamp its completion time.

    Note: aborting is allowed from any status, including already-completed runs.
    """
    run = await db.get(PlaybookRun, run_id)
    if not run:
        raise HTTPException(status_code=404, detail="Run not found")
    run.status = "aborted"
    run.completed_at = _utcnow()
    await db.commit()
    return _run_to_dict(run)
|
||||
67
backend/app/api/routes/reports.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""API routes for report generation and export."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import HTMLResponse, PlainTextResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.services.reports import report_generator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/reports", tags=["reports"])
|
||||
|
||||
|
||||
@router.get(
    "/hunt/{hunt_id}",
    summary="Generate hunt investigation report",
    description="Generate a comprehensive report for a hunt. Supports JSON, HTML, and CSV formats.",
)
async def generate_hunt_report(
    hunt_id: str,
    format: str = Query("json", description="Report format: json, html, csv"),
    include_rows: bool = Query(False, description="Include raw data rows"),
    max_rows: int = Query(500, ge=0, le=5000, description="Max rows to include"),
    db: AsyncSession = Depends(get_db),
):
    """Render a hunt report as JSON (default), HTML page, or CSV download."""
    report = await report_generator.generate_hunt_report(
        hunt_id, db, format=format,
        include_rows=include_rows, max_rows=max_rows,
    )

    # The generator signals "not found" via an error key instead of raising.
    if isinstance(report, dict) and report.get("error"):
        raise HTTPException(status_code=404, detail=report["error"])

    if format == "html":
        return HTMLResponse(content=report, headers={
            "Content-Disposition": f"inline; filename=threathunt_report_{hunt_id}.html",
        })
    if format == "csv":
        return PlainTextResponse(content=report, media_type="text/csv", headers={
            "Content-Disposition": f"attachment; filename=threathunt_report_{hunt_id}.csv",
        })
    return report
|
||||
|
||||
|
||||
@router.get(
    "/hunt/{hunt_id}/summary",
    summary="Quick hunt summary",
    description="Get a lightweight summary of the hunt for dashboard display.",
)
async def hunt_summary(
    hunt_id: str,
    db: AsyncSession = Depends(get_db),
):
    """Return only the hunt header and summary sections of the JSON report."""
    report = await report_generator.generate_hunt_report(
        hunt_id, db, format="json", include_rows=False,
    )
    if isinstance(report, dict) and report.get("error"):
        raise HTTPException(status_code=404, detail=report["error"])

    return {"hunt": report.get("hunt"), "summary": report.get("summary")}
|
||||
121
backend/app/config.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Application configuration — single source of truth for all settings.
|
||||
|
||||
Loads from environment variables with sensible defaults for local dev.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class AppConfig(BaseSettings):
    """Central configuration for the entire ThreatHunt application.

    Values load from the environment (prefix ``TH_``) and an optional ``.env``
    file; unknown env vars are ignored (see ``model_config`` below).
    """

    # ── General ────────────────────────────────────────────────────────
    APP_NAME: str = "ThreatHunt"
    APP_VERSION: str = "0.4.0"
    DEBUG: bool = Field(default=False, description="Enable debug mode")

    # ── Database ───────────────────────────────────────────────────────
    DATABASE_URL: str = Field(
        default="sqlite+aiosqlite:///./threathunt.db",
        description="Async SQLAlchemy database URL. "
        "Use sqlite+aiosqlite:///./threathunt.db for local dev, "
        "postgresql+asyncpg://user:pass@host/db for production.",
    )

    # ── CORS ───────────────────────────────────────────────────────────
    ALLOWED_ORIGINS: str = Field(
        default="http://localhost:3000,http://localhost:8000",
        description="Comma-separated list of allowed CORS origins",
    )

    # ── File uploads ───────────────────────────────────────────────────
    MAX_UPLOAD_SIZE_MB: int = Field(default=500, description="Max CSV upload in MB")
    UPLOAD_DIR: str = Field(default="./uploads", description="Directory for uploaded files")

    # ── LLM Cluster — Wile & Roadrunner ────────────────────────────────
    OPENWEBUI_URL: str = Field(
        default="https://ai.guapo613.beer",
        description="Open WebUI cluster endpoint (OpenAI-compatible API)",
    )
    OPENWEBUI_API_KEY: str = Field(
        default="",
        description="API key for Open WebUI (if required)",
    )
    WILE_HOST: str = Field(
        default="100.110.190.12",
        description="Tailscale IP for Wile (heavy models)",
    )
    WILE_OLLAMA_PORT: int = Field(default=11434, description="Ollama port on Wile")
    ROADRUNNER_HOST: str = Field(
        default="100.110.190.11",
        description="Tailscale IP for Roadrunner (fast models + vision)",
    )
    ROADRUNNER_OLLAMA_PORT: int = Field(
        default=11434, description="Ollama port on Roadrunner"
    )

    # ── LLM Routing defaults ──────────────────────────────────────────
    DEFAULT_FAST_MODEL: str = Field(
        default="llama3.1:latest",
        description="Default model for quick chat / simple queries",
    )
    DEFAULT_HEAVY_MODEL: str = Field(
        default="llama3.1:70b-instruct-q4_K_M",
        description="Default model for deep analysis / debate",
    )
    DEFAULT_CODE_MODEL: str = Field(
        default="qwen2.5-coder:32b",
        description="Default model for code / script analysis",
    )
    DEFAULT_VISION_MODEL: str = Field(
        default="llama3.2-vision:11b",
        description="Default model for image / screenshot analysis",
    )
    DEFAULT_EMBEDDING_MODEL: str = Field(
        default="bge-m3:latest",
        description="Default embedding model",
    )

    # ── Agent behaviour ───────────────────────────────────────────────
    AGENT_MAX_TOKENS: int = Field(default=2048, description="Max tokens per agent response")
    AGENT_TEMPERATURE: float = Field(default=0.3, description="LLM temperature for guidance")
    AGENT_HISTORY_LENGTH: int = Field(default=10, description="Messages to keep in context")
    FILTER_SENSITIVE_DATA: bool = Field(default=True, description="Redact sensitive patterns")

    # ── Enrichment API keys ───────────────────────────────────────────
    VIRUSTOTAL_API_KEY: str = Field(default="", description="VirusTotal API key")
    ABUSEIPDB_API_KEY: str = Field(default="", description="AbuseIPDB API key")
    SHODAN_API_KEY: str = Field(default="", description="Shodan API key")

    # ── Auth ──────────────────────────────────────────────────────────
    # NOTE(review): insecure placeholder default — deployments must set TH_JWT_SECRET.
    JWT_SECRET: str = Field(
        default="CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET",
        description="Secret for JWT signing",
    )
    JWT_ACCESS_TOKEN_MINUTES: int = Field(default=60, description="Access token lifetime")
    JWT_REFRESH_TOKEN_DAYS: int = Field(default=7, description="Refresh token lifetime")

    # Pydantic-settings config: env prefix TH_, load .env, ignore extra vars.
    model_config = {"env_prefix": "TH_", "env_file": ".env", "extra": "ignore"}

    @property
    def cors_origins(self) -> list[str]:
        """ALLOWED_ORIGINS split on commas, whitespace-trimmed, blanks dropped."""
        return [o.strip() for o in self.ALLOWED_ORIGINS.split(",") if o.strip()]

    @property
    def wile_url(self) -> str:
        """Base URL for the Ollama API on Wile."""
        return f"http://{self.WILE_HOST}:{self.WILE_OLLAMA_PORT}"

    @property
    def roadrunner_url(self) -> str:
        """Base URL for the Ollama API on Roadrunner."""
        return f"http://{self.ROADRUNNER_HOST}:{self.ROADRUNNER_OLLAMA_PORT}"

    @property
    def max_upload_bytes(self) -> int:
        """MAX_UPLOAD_SIZE_MB converted to bytes."""
        return self.MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
|
||||
|
||||
settings = AppConfig()
|
||||
12
backend/app/db/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Database package."""
|
||||
|
||||
from .engine import Base, get_db, init_db, dispose_db, engine, async_session_factory
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"get_db",
|
||||
"init_db",
|
||||
"dispose_db",
|
||||
"engine",
|
||||
"async_session_factory",
|
||||
]
|
||||
87
backend/app/db/engine.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Database engine, session factory, and base model.
|
||||
|
||||
Uses async SQLAlchemy with aiosqlite for local dev and asyncpg for production PostgreSQL.
|
||||
"""
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config import settings
|
||||
|
||||
# True when the configured database is SQLite; drives the pool sizing below
# and the connection pragmas in _set_sqlite_pragmas.
_is_sqlite = settings.DATABASE_URL.startswith("sqlite")

_engine_kwargs: dict = dict(
    echo=settings.DEBUG,  # log emitted SQL when debug mode is on
    future=True,
)

if _is_sqlite:
    # Restrict the pool to a single connection (pool_size=1, max_overflow=0)
    # and pass a 30s connect timeout to the sqlite driver.
    _engine_kwargs["connect_args"] = {"timeout": 30}
    _engine_kwargs["pool_size"] = 1
    _engine_kwargs["max_overflow"] = 0

engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)
|
||||
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
def _set_sqlite_pragmas(dbapi_conn, connection_record):
    """Enable WAL mode and tune busy-timeout for SQLite connections.

    Registered for every new connection; a no-op on non-SQLite backends.
    """
    if not _is_sqlite:
        return
    cursor = dbapi_conn.cursor()
    for pragma in ("journal_mode=WAL", "busy_timeout=5000", "synchronous=NORMAL"):
        cursor.execute(f"PRAGMA {pragma}")
    cursor.close()
|
||||
|
||||
|
||||
# Session factory used by get_db(); expire_on_commit=False keeps already-loaded
# attributes usable after the request-scoped commit in get_db().
async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Base class for all ORM models (provides the shared metadata used by init_db)."""
    pass
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:  # type: ignore[misc]
    """FastAPI dependency that yields an async DB session.

    Request-scoped unit of work: commits once when the handler returns
    without raising; rolls back and re-raises otherwise.
    """
    async with async_session_factory() as session:
        try:
            yield session
            # Handler finished without raising — persist its work.
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            # The `async with` also closes the session; the explicit close is redundant but harmless.
            await session.close()
|
||||
|
||||
|
||||
async def init_db() -> None:
    """Create all tables (for dev / first-run). In production use Alembic."""
    from sqlalchemy import inspect as sa_inspect

    async with engine.begin() as conn:
        # Only create tables that don't already exist (safe alongside Alembic)
        def _create_missing(sync_conn):
            # Inspection requires a sync connection, hence run_sync below.
            inspector = sa_inspect(sync_conn)
            existing = set(inspector.get_table_names())
            tables_to_create = [
                t for t in Base.metadata.sorted_tables
                if t.name not in existing
            ]
            Base.metadata.create_all(sync_conn, tables=tables_to_create)

        await conn.run_sync(_create_missing)
|
||||
|
||||
|
||||
async def dispose_db() -> None:
    """Dispose of the engine connection pool (call on application shutdown)."""
    await engine.dispose()
|
||||
546
backend/app/db/models.py
Normal file
@@ -0,0 +1,546 @@
|
||||
"""SQLAlchemy ORM models for ThreatHunt.
|
||||
|
||||
All persistent entities: datasets, hunts, conversations, annotations,
|
||||
hypotheses, enrichment results, and users.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
JSON,
|
||||
Index,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from .engine import Base
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _new_id() -> str:
|
||||
return uuid.uuid4().hex
|
||||
|
||||
|
||||
# ── Users ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class User(Base):
    """Application user account (analyst, admin or viewer)."""
    __tablename__ = "users"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
    email: Mapped[str] = mapped_column(String(256), unique=True, nullable=False)
    # Only the password hash is stored, never the plaintext.
    hashed_password: Mapped[str] = mapped_column(String(256), nullable=False)
    role: Mapped[str] = mapped_column(String(16), default="analyst")  # analyst | admin | viewer
    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    # relationships
    hunts: Mapped[list["Hunt"]] = relationship(back_populates="owner", lazy="selectin")
    annotations: Mapped[list["Annotation"]] = relationship(back_populates="author", lazy="selectin")
|
||||
|
||||
|
||||
# ── Hunts ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Hunt(Base):
    """Threat-hunting engagement grouping datasets, conversations and hypotheses."""
    __tablename__ = "hunts"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    name: Mapped[str] = mapped_column(String(256), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    status: Mapped[str] = mapped_column(String(32), default="active")  # active | closed | archived
    owner_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("users.id"), nullable=True
    )
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    owner: Mapped[Optional["User"]] = relationship(back_populates="hunts", lazy="selectin")
    datasets: Mapped[list["Dataset"]] = relationship(back_populates="hunt", lazy="selectin")
    conversations: Mapped[list["Conversation"]] = relationship(back_populates="hunt", lazy="selectin")
    hypotheses: Mapped[list["Hypothesis"]] = relationship(back_populates="hunt", lazy="selectin")
|
||||
|
||||
|
||||
# ── Datasets ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Dataset(Base):
    """Uploaded CSV dataset plus its parsing metadata (schema, encoding,
    detected IOC columns, time range)."""
    __tablename__ = "datasets"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
    filename: Mapped[str] = mapped_column(String(512), nullable=False)
    source_tool: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)  # velociraptor, etc.
    row_count: Mapped[int] = mapped_column(Integer, default=0)
    column_schema: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    normalized_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    ioc_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # auto-detected IOC columns
    file_size_bytes: Mapped[int] = mapped_column(Integer, default=0)
    encoding: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    delimiter: Mapped[Optional[str]] = mapped_column(String(4), nullable=True)
    # Earliest/latest event timestamps observed in the data, if parsed.
    time_range_start: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    time_range_end: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)

    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    uploaded_by: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    # relationships
    hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="datasets", lazy="selectin")
    # lazy="noload": rows can be huge, so they are never loaded implicitly;
    # deleting a dataset cascades to its rows.
    rows: Mapped[list["DatasetRow"]] = relationship(
        back_populates="dataset", lazy="noload", cascade="all, delete-orphan"
    )

    __table_args__ = (
        Index("ix_datasets_hunt", "hunt_id"),
    )
|
||||
|
||||
|
||||
class DatasetRow(Base):
    """Individual row from a CSV dataset, stored as JSON blob."""
    __tablename__ = "dataset_rows"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # DB-level cascade: deleting the parent dataset removes its rows.
    dataset_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False
    )
    # Zero-based position of the row within the original file.
    row_index: Mapped[int] = mapped_column(Integer, nullable=False)
    data: Mapped[dict] = mapped_column(JSON, nullable=False)
    normalized_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)

    # relationships
    dataset: Mapped["Dataset"] = relationship(back_populates="rows")
    annotations: Mapped[list["Annotation"]] = relationship(
        back_populates="row", lazy="noload"
    )

    __table_args__ = (
        Index("ix_dataset_rows_dataset", "dataset_id"),
        # Composite index supports paging and point lookups by row_index.
        Index("ix_dataset_rows_dataset_idx", "dataset_id", "row_index"),
    )
|
||||
|
||||
|
||||
# ── Conversations ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Conversation(Base):
    """Agent chat conversation, optionally scoped to a hunt and/or dataset."""
    __tablename__ = "conversations"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    title: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    dataset_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("datasets.id"), nullable=True
    )
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="conversations", lazy="selectin")
    # Messages load eagerly in chronological order; deleting the
    # conversation deletes its messages.
    messages: Mapped[list["Message"]] = relationship(
        back_populates="conversation", lazy="selectin", cascade="all, delete-orphan",
        order_by="Message.created_at",
    )
|
||||
|
||||
|
||||
class Message(Base):
    """Single message in a conversation, with LLM routing/latency telemetry."""
    __tablename__ = "messages"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    conversation_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False
    )
    role: Mapped[str] = mapped_column(String(16), nullable=False)  # user | agent | system
    content: Mapped[str] = mapped_column(Text, nullable=False)
    model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)  # wile | roadrunner | cluster
    token_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    latency_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    # Free-form provider metadata (raw response details, tool calls, …).
    response_meta: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    # relationships
    conversation: Mapped["Conversation"] = relationship(back_populates="messages")

    __table_args__ = (
        Index("ix_messages_conversation", "conversation_id"),
    )
|
||||
|
||||
|
||||
# ── Annotations ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Annotation(Base):
    """Analyst note attached to a dataset row (or a dataset as a whole)."""
    __tablename__ = "annotations"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    # SET NULL: an annotation survives deletion of its row, detached.
    row_id: Mapped[Optional[int]] = mapped_column(
        Integer, ForeignKey("dataset_rows.id", ondelete="SET NULL"), nullable=True
    )
    dataset_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("datasets.id"), nullable=True
    )
    author_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("users.id"), nullable=True
    )
    text: Mapped[str] = mapped_column(Text, nullable=False)
    severity: Mapped[str] = mapped_column(
        String(16), default="info"
    )  # info | low | medium | high | critical
    tag: Mapped[Optional[str]] = mapped_column(
        String(32), nullable=True
    )  # suspicious | benign | needs-review
    highlight_color: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    row: Mapped[Optional["DatasetRow"]] = relationship(back_populates="annotations")
    author: Mapped[Optional["User"]] = relationship(back_populates="annotations")

    __table_args__ = (
        Index("ix_annotations_dataset", "dataset_id"),
        Index("ix_annotations_row", "row_id"),
    )
|
||||
|
||||
|
||||
# ── Hypotheses ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Hypothesis(Base):
    """Hunt hypothesis with optional MITRE mapping and supporting evidence."""
    __tablename__ = "hypotheses"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    title: Mapped[str] = mapped_column(String(256), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # MITRE ATT&CK technique ID, presumably e.g. "T1059" — confirm format with callers.
    mitre_technique: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    status: Mapped[str] = mapped_column(
        String(16), default="draft"
    )  # draft | active | confirmed | rejected
    # JSON array of DatasetRow ids cited as evidence.
    evidence_row_ids: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    evidence_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="hypotheses", lazy="selectin")

    __table_args__ = (
        Index("ix_hypotheses_hunt", "hunt_id"),
    )
|
||||
|
||||
|
||||
# ── Enrichment Results ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class EnrichmentResult(Base):
    """Cached IOC enrichment lookup from an external or AI source.

    Acts as a cache: `cached_at`/`expires_at` bound the validity window.
    """
    __tablename__ = "enrichment_results"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    ioc_value: Mapped[str] = mapped_column(String(512), nullable=False, index=True)
    ioc_type: Mapped[str] = mapped_column(
        String(32), nullable=False
    )  # ip | hash_md5 | hash_sha1 | hash_sha256 | domain | url
    source: Mapped[str] = mapped_column(String(32), nullable=False)  # virustotal | abuseipdb | shodan | ai
    verdict: Mapped[Optional[str]] = mapped_column(
        String(16), nullable=True
    )  # clean | suspicious | malicious | unknown
    confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    # Unmodified provider response, kept for auditability.
    raw_result: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    dataset_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("datasets.id"), nullable=True
    )
    cached_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)

    __table_args__ = (
        # Composite index for the cache lookup pattern (value + source).
        Index("ix_enrichment_ioc_source", "ioc_value", "source"),
    )
|
||||
|
||||
|
||||
# ── AUP Keyword Themes & Keywords ────────────────────────────────────
|
||||
|
||||
|
||||
class KeywordTheme(Base):
    """A named category of keywords for AUP scanning (e.g. gambling, gaming)."""
    __tablename__ = "keyword_themes"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    name: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
    color: Mapped[str] = mapped_column(String(16), default="#9e9e9e")  # hex chip color
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    is_builtin: Mapped[bool] = mapped_column(Boolean, default=False)  # seed-provided
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    # relationships
    # Deleting a theme deletes its keywords (ORM cascade).
    keywords: Mapped[list["Keyword"]] = relationship(
        back_populates="theme", lazy="selectin", cascade="all, delete-orphan"
    )
|
||||
|
||||
|
||||
class Keyword(Base):
    """Individual keyword / pattern belonging to a theme."""
    __tablename__ = "keywords"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # DB-level cascade mirrors the ORM cascade on KeywordTheme.keywords.
    theme_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("keyword_themes.id", ondelete="CASCADE"), nullable=False
    )
    value: Mapped[str] = mapped_column(String(256), nullable=False)
    # When True, `value` is interpreted as a regular expression by the scanner.
    is_regex: Mapped[bool] = mapped_column(Boolean, default=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    # relationships
    theme: Mapped["KeywordTheme"] = relationship(back_populates="keywords")

    __table_args__ = (
        Index("ix_keywords_theme", "theme_id"),
        Index("ix_keywords_value", "value"),
    )
|
||||
|
||||
|
||||
# ── Cases ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Case(Base):
    """Incident / investigation case, inspired by TheHive.

    TLP/PAP are traffic-light sharing/action markings; severity, status
    and priority drive the case board UI.
    """
    __tablename__ = "cases"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    title: Mapped[str] = mapped_column(String(512), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    severity: Mapped[str] = mapped_column(String(16), default="medium")  # info|low|medium|high|critical
    tlp: Mapped[str] = mapped_column(String(16), default="amber")  # white|green|amber|red
    pap: Mapped[str] = mapped_column(String(16), default="amber")  # white|green|amber|red
    status: Mapped[str] = mapped_column(String(24), default="open")  # open|in-progress|resolved|closed
    priority: Mapped[int] = mapped_column(Integer, default=2)  # 1(urgent)..4(low)
    assignee: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    tags: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    owner_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("users.id"), nullable=True
    )
    mitre_techniques: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    iocs: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)  # [{type, value, description}]
    started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    resolved_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    # Deleting a case deletes its Kanban tasks.
    tasks: Mapped[list["CaseTask"]] = relationship(
        back_populates="case", lazy="selectin", cascade="all, delete-orphan"
    )

    __table_args__ = (
        Index("ix_cases_hunt", "hunt_id"),
        Index("ix_cases_status", "status"),
    )
|
||||
|
||||
|
||||
class CaseTask(Base):
    """Task within a case (Kanban board item)."""
    __tablename__ = "case_tasks"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    case_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("cases.id", ondelete="CASCADE"), nullable=False
    )
    title: Mapped[str] = mapped_column(String(512), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    status: Mapped[str] = mapped_column(String(24), default="todo")  # todo|in-progress|done
    assignee: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    # Position within the Kanban column (lower = higher on the board).
    order: Mapped[int] = mapped_column(Integer, default=0)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    case: Mapped["Case"] = relationship(back_populates="tasks")

    __table_args__ = (
        Index("ix_case_tasks_case", "case_id"),
    )
|
||||
|
||||
|
||||
# ── Activity Log ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ActivityLog(Base):
    """Audit trail / activity log for cases and hunts.

    Generic (entity_type, entity_id) pairing instead of foreign keys, so
    any entity kind can be logged without schema changes.
    """
    __tablename__ = "activity_logs"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    entity_type: Mapped[str] = mapped_column(String(32), nullable=False)  # case|hunt|annotation
    entity_id: Mapped[str] = mapped_column(String(32), nullable=False)
    action: Mapped[str] = mapped_column(String(64), nullable=False)  # created|updated|status_changed|etc
    details: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    user_id: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    __table_args__ = (
        Index("ix_activity_entity", "entity_type", "entity_id"),
    )
|
||||
|
||||
|
||||
# ── Alerts ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Alert(Base):
    """Security alert generated by analyzers or rules.

    Can be linked to the hunt/dataset it fired on and escalated into a case.
    """
    __tablename__ = "alerts"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    title: Mapped[str] = mapped_column(String(512), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    severity: Mapped[str] = mapped_column(String(16), default="medium")  # critical|high|medium|low|info
    status: Mapped[str] = mapped_column(String(24), default="new")  # new|acknowledged|in-progress|resolved|false-positive
    analyzer: Mapped[str] = mapped_column(String(64), nullable=False)  # which analyzer produced it
    score: Mapped[float] = mapped_column(Float, default=0.0)
    evidence: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)  # [{row_index, field, value, ...}]
    mitre_technique: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    tags: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    dataset_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("datasets.id"), nullable=True
    )
    case_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("cases.id"), nullable=True
    )
    assignee: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    acknowledged_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    resolved_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    __table_args__ = (
        Index("ix_alerts_severity", "severity"),
        Index("ix_alerts_status", "status"),
        Index("ix_alerts_hunt", "hunt_id"),
        Index("ix_alerts_dataset", "dataset_id"),
    )
|
||||
|
||||
|
||||
class AlertRule(Base):
    """User-defined alert rule (triggers analyzers automatically on upload)."""
    __tablename__ = "alert_rules"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    name: Mapped[str] = mapped_column(String(256), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    analyzer: Mapped[str] = mapped_column(String(64), nullable=False)  # analyzer name
    config: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)  # analyzer config overrides
    # When set, overrides the severity the analyzer would assign.
    severity_override: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )  # None = global
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    __table_args__ = (
        Index("ix_alert_rules_analyzer", "analyzer"),
    )
|
||||
|
||||
|
||||
# ── Notebooks ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Notebook(Base):
    """Investigation notebook — cell-based document for analyst notes and queries."""
    __tablename__ = "notebooks"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    title: Mapped[str] = mapped_column(String(512), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # Whole notebook body stored denormalized as a JSON cell list.
    cells: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)  # [{id, cell_type, source, output, metadata}]
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    case_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("cases.id"), nullable=True
    )
    owner_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("users.id"), nullable=True
    )
    tags: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    __table_args__ = (
        Index("ix_notebooks_hunt", "hunt_id"),
    )
|
||||
|
||||
|
||||
# ── Playbook Runs ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class PlaybookRun(Base):
    """Record of a playbook execution (links a template to a hunt/case)."""
    __tablename__ = "playbook_runs"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    playbook_name: Mapped[str] = mapped_column(String(256), nullable=False)
    status: Mapped[str] = mapped_column(String(24), default="in-progress")  # in-progress | completed | aborted
    # Steps are 1-based; current_step advances toward total_steps.
    current_step: Mapped[int] = mapped_column(Integer, default=1)
    total_steps: Mapped[int] = mapped_column(Integer, default=0)
    step_results: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)  # [{step, status, notes, completed_at}]
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    case_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("cases.id"), nullable=True
    )
    started_by: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )
    completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)

    __table_args__ = (
        Index("ix_playbook_runs_hunt", "hunt_id"),
        Index("ix_playbook_runs_status", "status"),
    )
|
||||
1
backend/app/db/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Repositories package — typed CRUD operations for each model."""
|
||||
127
backend/app/db/repositories/datasets.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Dataset repository — CRUD operations for datasets and their rows."""
|
||||
|
||||
import logging
|
||||
from typing import Sequence
|
||||
|
||||
from sqlalchemy import select, func, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetRepository:
    """Typed CRUD layer over the Dataset and DatasetRow models.

    All mutating methods flush (not commit), leaving transaction
    boundaries to the caller / session dependency.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    # ── Dataset CRUD ──────────────────────────────────────────────────

    async def create_dataset(self, **kwargs) -> Dataset:
        """Create and flush a new Dataset; returns the persistent object."""
        dataset = Dataset(**kwargs)
        self.session.add(dataset)
        await self.session.flush()
        return dataset

    async def get_dataset(self, dataset_id: str) -> Dataset | None:
        """Fetch a dataset by primary key, or None when absent."""
        stmt = select(Dataset).where(Dataset.id == dataset_id)
        return (await self.session.execute(stmt)).scalar_one_or_none()

    async def list_datasets(
        self,
        hunt_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> Sequence[Dataset]:
        """List datasets newest-first, optionally filtered by hunt."""
        query = select(Dataset).order_by(Dataset.created_at.desc())
        if hunt_id:
            query = query.where(Dataset.hunt_id == hunt_id)
        page = await self.session.execute(query.limit(limit).offset(offset))
        return page.scalars().all()

    async def count_datasets(self, hunt_id: str | None = None) -> int:
        """Count datasets, optionally restricted to one hunt."""
        query = select(func.count(Dataset.id))
        if hunt_id:
            query = query.where(Dataset.hunt_id == hunt_id)
        return (await self.session.execute(query)).scalar_one()

    async def delete_dataset(self, dataset_id: str) -> bool:
        """Delete a dataset (rows cascade); returns False when not found."""
        dataset = await self.get_dataset(dataset_id)
        if not dataset:
            return False
        await self.session.delete(dataset)
        await self.session.flush()
        return True

    # ── Row CRUD ──────────────────────────────────────────────────────

    async def bulk_insert_rows(
        self,
        dataset_id: str,
        rows: list[dict],
        normalized_rows: list[dict] | None = None,
        batch_size: int = 500,
    ) -> int:
        """Insert rows in batches of `batch_size`; returns count inserted.

        `row_index` records each row's position in the original input.
        """
        inserted = 0
        for start in range(0, len(rows), batch_size):
            chunk = rows[start : start + batch_size]
            if normalized_rows:
                norm_chunk = normalized_rows[start : start + batch_size]
            else:
                norm_chunk = [None] * len(chunk)
            objects = []
            for offset, (raw, norm) in enumerate(zip(chunk, norm_chunk)):
                objects.append(
                    DatasetRow(
                        dataset_id=dataset_id,
                        row_index=start + offset,
                        data=raw,
                        normalized_data=norm,
                    )
                )
            self.session.add_all(objects)
            # Flush per batch to bound memory in the session.
            await self.session.flush()
            inserted += len(objects)
        return inserted

    async def get_rows(
        self,
        dataset_id: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> Sequence[DatasetRow]:
        """Page through a dataset's rows in row_index order."""
        query = (
            select(DatasetRow)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .limit(limit)
            .offset(offset)
        )
        return (await self.session.execute(query)).scalars().all()

    async def count_rows(self, dataset_id: str) -> int:
        """Count the rows stored for a dataset."""
        query = select(func.count(DatasetRow.id)).where(
            DatasetRow.dataset_id == dataset_id
        )
        return (await self.session.execute(query)).scalar_one()

    async def get_row_by_index(
        self, dataset_id: str, row_index: int
    ) -> DatasetRow | None:
        """Fetch one row by its (dataset_id, row_index) pair, or None."""
        query = select(DatasetRow).where(
            DatasetRow.dataset_id == dataset_id,
            DatasetRow.row_index == row_index,
        )
        return (await self.session.execute(query)).scalar_one_or_none()

    async def delete_rows(self, dataset_id: str) -> int:
        """Bulk-delete all rows of a dataset; returns the deleted count."""
        result = await self.session.execute(
            delete(DatasetRow).where(DatasetRow.dataset_id == dataset_id)
        )
        return result.rowcount  # type: ignore[return-value]
|
||||
102
backend/app/main.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""ThreatHunt backend application.
|
||||
|
||||
Wires together: database, CORS, agent routes, dataset routes, hunt routes,
|
||||
annotation/hypothesis routes. DB tables are auto-created on startup.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.config import settings
|
||||
from app.db import init_db, dispose_db
|
||||
from app.api.routes.agent_v2 import router as agent_router
|
||||
from app.api.routes.datasets import router as datasets_router
|
||||
from app.api.routes.hunts import router as hunts_router
|
||||
from app.api.routes.annotations import ann_router, hyp_router
|
||||
from app.api.routes.enrichment import router as enrichment_router
|
||||
from app.api.routes.correlation import router as correlation_router
|
||||
from app.api.routes.reports import router as reports_router
|
||||
from app.api.routes.auth import router as auth_router
|
||||
from app.api.routes.keywords import router as keywords_router
|
||||
from app.api.routes.network import router as network_router
|
||||
from app.api.routes.analysis import router as analysis_router
|
||||
from app.api.routes.cases import router as cases_router
|
||||
from app.api.routes.alerts import router as alerts_router
|
||||
from app.api.routes.notebooks import router as notebooks_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup / shutdown lifecycle.

    Startup: create DB tables, then seed the default AUP keyword themes.
    Shutdown: close the shared LLM HTTP client and the enrichment engine
    before disposing the database engine last.
    """
    logger.info("Starting ThreatHunt API …")
    await init_db()
    logger.info("Database initialised")
    # Seed default AUP keyword themes
    # Imported here rather than at module top — presumably to avoid an
    # import cycle at startup; confirm before hoisting to top-level imports.
    from app.db import async_session_factory
    from app.services.keyword_defaults import seed_defaults
    async with async_session_factory() as seed_db:
        await seed_defaults(seed_db)
    logger.info("AUP keyword defaults checked")
    yield  # application serves requests while suspended here
    logger.info("Shutting down …")
    from app.agents.providers_v2 import cleanup_client
    from app.services.enrichment import enrichment_engine
    await cleanup_client()
    await enrichment_engine.cleanup()
    # Dispose the DB engine last, after everything that might still use it.
    await dispose_db()
|
||||
|
||||
|
||||
# Create the FastAPI application
app = FastAPI(
    title="ThreatHunt API",
    description="Analyst-assist threat hunting platform powered by Wile & Roadrunner LLM cluster",
    version="0.3.0",
    lifespan=lifespan,
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount every feature router; tuple order mirrors the original explicit calls.
for _router in (
    auth_router,
    agent_router,
    datasets_router,
    hunts_router,
    ann_router,
    hyp_router,
    enrichment_router,
    correlation_router,
    reports_router,
    keywords_router,
    network_router,
    analysis_router,
    cases_router,
    alerts_router,
    notebooks_router,
):
    app.include_router(_router)
|
||||
|
||||
|
||||
@app.get("/", tags=["health"])
async def root():
    """API health check."""
    # NOTE(review): version comes from settings.APP_VERSION while the FastAPI
    # app above hard-codes "0.3.0" — confirm the two stay in sync.
    cluster_info = {
        "wile": settings.wile_url,
        "roadrunner": settings.roadrunner_url,
        "openwebui": settings.OPENWEBUI_URL,
    }
    return {
        "service": "ThreatHunt API",
        "version": settings.APP_VERSION,
        "status": "running",
        "docs": "/docs",
        "cluster": cluster_info,
    }
|
||||
1
backend/app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Services package."""
|
||||
464
backend/app/services/analyzers.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""Pluggable Analyzer Framework for ThreatHunt.
|
||||
|
||||
Each analyzer implements a simple protocol:
|
||||
- name / description properties
|
||||
- async analyze(rows, config) -> list[AlertCandidate]
|
||||
|
||||
The AnalyzerRegistry discovers and runs all enabled analyzers against
|
||||
a dataset, producing alert candidates that the alert system can persist.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter, defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Alert Candidate DTO ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class AlertCandidate:
    """A single finding from an analyzer, before it becomes a persisted Alert.

    Produced by BaseAnalyzer.analyze(); run_all_analyzers() returns candidates
    sorted by ``score`` descending.
    """
    analyzer: str  # name of the producing analyzer (BaseAnalyzer.name)
    title: str  # short human-readable headline
    severity: str  # critical | high | medium | low | info
    description: str  # longer explanation, usually with row/field context
    evidence: list[dict] = field(default_factory=list)  # [{row_index, field, value, ...}]
    mitre_technique: Optional[str] = None  # MITRE ATT&CK id, e.g. "T1027"
    tags: list[str] = field(default_factory=list)  # free-form labels for filtering
    score: float = 0.0  # 0-100
|
||||
|
||||
|
||||
# ── Base Analyzer ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class BaseAnalyzer(ABC):
    """Interface every analyzer must implement.

    Concrete subclasses in this module satisfy ``name``/``description`` with
    plain class attributes (a class attribute overrides the abstract property).
    """

    @property
    @abstractmethod
    def name(self) -> str: ...  # unique registry key, e.g. "entropy"

    @property
    @abstractmethod
    def description(self) -> str: ...  # one-line human-readable summary

    @abstractmethod
    async def analyze(
        self, rows: list[dict[str, Any]], config: dict[str, Any] | None = None
    ) -> list[AlertCandidate]: ...  # scan rows, return findings; persistence happens elsewhere
|
||||
|
||||
|
||||
# ── Built-in Analyzers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class EntropyAnalyzer(BaseAnalyzer):
    """Detects high-entropy strings (encoded payloads, obfuscated commands)."""

    name = "entropy"
    description = "Flags fields with high Shannon entropy (possible encoding/obfuscation)"

    # Row fields whose values are worth an entropy check.
    ENTROPY_FIELDS = [
        "command_line", "commandline", "process_command_line", "cmdline",
        "powershell_command", "script_block", "url", "uri", "path",
        "file_path", "target_filename", "query", "dns_query",
    ]
    DEFAULT_THRESHOLD = 4.5

    @staticmethod
    def _shannon(s: str) -> float:
        """Shannon entropy of *s* in bits; 0.0 for strings shorter than 8 chars."""
        if not s or len(s) < 8:
            return 0.0
        total = len(s)
        entropy = 0.0
        for occurrences in Counter(s).values():
            p = occurrences / total
            entropy -= p * math.log2(p)
        return entropy

    async def analyze(self, rows, config=None):
        """Flag any checked field whose entropy meets the configured threshold.

        Config keys: ``entropy_threshold`` (default 4.5), ``min_length``
        (default 20 — shorter values are skipped).
        """
        cfg = config or {}
        threshold = cfg.get("entropy_threshold", self.DEFAULT_THRESHOLD)
        min_length = cfg.get("min_length", 20)
        findings: list[AlertCandidate] = []

        for row_no, row in enumerate(rows):
            for column in self.ENTROPY_FIELDS:
                text = str(row.get(column, ""))
                if len(text) < min_length:
                    continue
                entropy = self._shannon(text)
                if entropy < threshold:
                    continue
                if entropy > 5.5:
                    severity = "critical"
                elif entropy > 5.0:
                    severity = "high"
                else:
                    severity = "medium"
                findings.append(AlertCandidate(
                    analyzer=self.name,
                    title=f"High-entropy string in {column}",
                    severity=severity,
                    description=f"Shannon entropy {entropy:.2f} (threshold {threshold}) in row {row_no}, field '{column}'",
                    evidence=[{"row_index": row_no, "field": column, "value": text[:200], "entropy": round(entropy, 3)}],
                    mitre_technique="T1027",  # Obfuscated Files or Information
                    tags=["obfuscation", "entropy"],
                    score=min(100, entropy * 18),
                ))
        return findings
|
||||
|
||||
|
||||
class SuspiciousCommandAnalyzer(BaseAnalyzer):
    """Detects known-bad command patterns (credential dumping, lateral movement, persistence)."""

    name = "suspicious_commands"
    description = "Flags processes executing known-suspicious command patterns"

    PATTERNS: list[tuple[str, str, str, str]] = [
        # (regex, title, severity, mitre_technique)
        (r"mimikatz|sekurlsa|lsadump|kerberos::list", "Mimikatz / Credential Dumping", "critical", "T1003"),
        (r"(?i)-enc\s+[A-Za-z0-9+/=]{40,}", "Encoded PowerShell command", "high", "T1059.001"),
        (r"(?i)invoke-(mimikatz|expression|webrequest|shellcode)", "Suspicious PowerShell Invoke", "high", "T1059.001"),
        (r"(?i)net\s+(user|localgroup|group)\s+/add", "Local account creation", "high", "T1136.001"),
        (r"(?i)schtasks\s+/create", "Scheduled task creation", "medium", "T1053.005"),
        (r"(?i)reg\s+add\s+.*\\run", "Registry Run key persistence", "high", "T1547.001"),
        (r"(?i)wmic\s+.*(process\s+call|shadowcopy\s+delete)", "WMI abuse / shadow copy deletion", "critical", "T1047"),
        (r"(?i)psexec|winrm|wmic\s+/node:", "Lateral movement tool", "high", "T1021"),
        (r"(?i)certutil\s+-urlcache", "Certutil download (LOLBin)", "high", "T1105"),
        (r"(?i)bitsadmin\s+/transfer", "BITSAdmin download", "medium", "T1197"),
        (r"(?i)vssadmin\s+delete\s+shadows", "VSS shadow deletion (ransomware)", "critical", "T1490"),
        (r"(?i)bcdedit.*recoveryenabled.*no", "Boot config tamper (ransomware)", "critical", "T1490"),
        (r"(?i)attrib\s+\+h\s+\+s", "Hidden file attribute set", "low", "T1564.001"),
        (r"(?i)netsh\s+advfirewall\s+.*disable", "Firewall disabled", "high", "T1562.004"),
        (r"(?i)whoami\s*/priv", "Privilege enumeration", "medium", "T1033"),
        (r"(?i)nltest\s+/dclist", "Domain controller enumeration", "medium", "T1018"),
        (r"(?i)dsquery|ldapsearch|adfind", "Active Directory enumeration", "medium", "T1087.002"),
        (r"(?i)procdump.*-ma\s+lsass", "LSASS memory dump", "critical", "T1003.001"),
        (r"(?i)rundll32.*comsvcs.*MiniDump", "LSASS dump via comsvcs", "critical", "T1003.001"),
    ]

    # Command-line fields probed on each row.
    CMD_FIELDS = [
        "command_line", "commandline", "process_command_line", "cmdline",
        "parent_command_line", "powershell_command",
    ]

    # Compiled once at class-creation time instead of on every analyze() call.
    # re.IGNORECASE makes the inline (?i) flags redundant but harmless.
    _COMPILED = [(re.compile(p, re.IGNORECASE), t, s, m) for p, t, s, m in PATTERNS]

    async def analyze(self, rows, config=None):
        """Scan every command-line field of every row against the pattern table.

        Args:
            rows: Flat row dicts; values are stringified before matching.
            config: Unused; accepted for interface compatibility.

        Returns:
            One AlertCandidate per (row, field, matching pattern).
        """
        alerts: list[AlertCandidate] = []

        for idx, row in enumerate(rows):
            for fld in self.CMD_FIELDS:
                val = str(row.get(fld, ""))
                if len(val) < 3:  # too short to match anything meaningful
                    continue
                for pattern, title, sev, mitre in self._COMPILED:
                    if pattern.search(val):
                        alerts.append(AlertCandidate(
                            analyzer=self.name,
                            title=title,
                            severity=sev,
                            description=f"Suspicious command pattern in row {idx}: {val[:200]}",
                            evidence=[{"row_index": idx, "field": fld, "value": val[:300]}],
                            mitre_technique=mitre,
                            tags=["command", "suspicious"],
                            score={"critical": 95, "high": 80, "medium": 60, "low": 30}.get(sev, 50),
                        ))
        return alerts
|
||||
|
||||
|
||||
class NetworkAnomalyAnalyzer(BaseAnalyzer):
    """Detects anomalous network patterns (beaconing, unusual ports, large transfers)."""

    name = "network_anomaly"
    description = "Flags anomalous network behavior (beaconing, unusual ports, large transfers)"

    SUSPICIOUS_PORTS = {4444, 5555, 6666, 8888, 9999, 1234, 31337, 12345, 54321, 1337}
    C2_PORTS = {443, 8443, 8080, 4443, 9443}  # NOTE(review): not referenced below — confirm intent

    async def analyze(self, rows, config=None):
        """Flag large transfers, beaconing, and connections on suspicious ports.

        Config keys: ``large_transfer_threshold`` (bytes, default 10 MB) and
        ``beacon_threshold`` (contact count per destination, default 20).
        """
        cfg = config or {}
        findings: list[AlertCandidate] = []

        contacts: dict[str, list[int]] = defaultdict(list)  # dst ip -> row indices
        odd_port_hits: list[tuple[int, str, int]] = []
        xfer_limit = cfg.get("large_transfer_threshold", 10_000_000)

        for row_no, row in enumerate(rows):
            # First present key wins, mirroring the nested-get fallbacks.
            target = str(row.get("dst_ip", row.get("destination_ip", row.get("dest_ip", ""))))
            raw_port = row.get("dst_port", row.get("destination_port", row.get("dest_port", "")))

            if target:
                contacts[target].append(row_no)

            if raw_port:
                try:
                    port = int(raw_port)
                except (ValueError, TypeError):
                    pass
                else:
                    if port in self.SUSPICIOUS_PORTS:
                        odd_port_hits.append((row_no, target, port))

            # Large transfer detection
            volume = row.get("bytes_sent", row.get("bytes_out", row.get("sent_bytes", 0)))
            try:
                too_big = int(volume or 0) > xfer_limit
            except (ValueError, TypeError):
                too_big = False
            if too_big:
                findings.append(AlertCandidate(
                    analyzer=self.name,
                    title="Large data transfer detected",
                    severity="medium",
                    description=f"Row {row_no}: {volume} bytes sent to {target}",
                    evidence=[{"row_index": row_no, "dst_ip": target, "bytes": str(volume)}],
                    mitre_technique="T1048",
                    tags=["exfiltration", "network"],
                    score=65,
                ))

        # Beaconing: destinations contacted at least beacon_threshold times.
        beacon_min = cfg.get("beacon_threshold", 20)
        for ip, hits in contacts.items():
            if len(hits) < beacon_min:
                continue
            findings.append(AlertCandidate(
                analyzer=self.name,
                title=f"Possible beaconing to {ip}",
                severity="high",
                description=f"Destination {ip} contacted {len(hits)} times (threshold: {beacon_min})",
                evidence=[{"dst_ip": ip, "contact_count": len(hits), "sample_rows": hits[:10]}],
                mitre_technique="T1071",
                tags=["beaconing", "c2", "network"],
                score=min(95, 50 + len(hits)),
            ))

        # Suspicious ports
        for row_no, ip, port in odd_port_hits:
            findings.append(AlertCandidate(
                analyzer=self.name,
                title=f"Connection on suspicious port {port}",
                severity="medium",
                description=f"Row {row_no}: connection to {ip}:{port}",
                evidence=[{"row_index": row_no, "dst_ip": ip, "dst_port": port}],
                mitre_technique="T1571",
                tags=["suspicious_port", "network"],
                score=55,
            ))

        return findings
|
||||
|
||||
|
||||
class FrequencyAnomalyAnalyzer(BaseAnalyzer):
    """Detects statistically rare values that may indicate anomalies."""

    name = "frequency_anomaly"
    description = "Flags statistically rare field values (potential anomalies)"

    # Categorical fields whose value distribution is profiled.
    FIELDS_TO_CHECK = [
        "process_name", "image_name", "parent_process_name",
        "user", "username", "user_name",
        "event_type", "action", "status",
    ]

    async def analyze(self, rows, config=None):
        """Flag values that occur in <= ``rarity_threshold`` of rows (max 3 times).

        Config keys: ``rarity_threshold`` (default 0.01 = 1%) and ``min_rows``
        (default 50 — below that, frequency statistics are meaningless).

        Fix vs. previous version: values and their row indices are collected
        in one pass per field.  The old code counted only truthy raw values
        but re-scanned *all* rows (unfiltered) per rare value to find indices,
        which was both O(n) per value and occasionally inconsistent with the
        reported counts.
        """
        config = config or {}
        rarity_threshold = config.get("rarity_threshold", 0.01)  # <1% occurrence
        min_rows = config.get("min_rows", 50)
        alerts: list[AlertCandidate] = []

        if len(rows) < min_rows:
            return alerts  # too few rows for meaningful statistics

        for fld in self.FIELDS_TO_CHECK:
            # Single pass: map each truthy value to the row indices holding it.
            index_map: dict[str, list[int]] = defaultdict(list)
            for i, row in enumerate(rows):
                if row.get(fld):
                    index_map[str(row.get(fld, ""))].append(i)
            if not index_map:
                continue
            total = sum(len(idxs) for idxs in index_map.values())

            for val, indices in index_map.items():
                cnt = len(indices)
                pct = cnt / total
                if pct <= rarity_threshold and cnt <= 3:
                    alerts.append(AlertCandidate(
                        analyzer=self.name,
                        title=f"Rare {fld}: {val[:80]}",
                        severity="low",
                        description=f"'{val}' appears {cnt}/{total} times ({pct:.2%}) in field '{fld}'",
                        evidence=[{"field": fld, "value": val[:200], "count": cnt, "total": total, "rows": indices[:5]}],
                        tags=["anomaly", "rare"],
                        score=max(20, 50 - (pct * 5000)),
                    ))

        return alerts
|
||||
|
||||
|
||||
class AuthAnomalyAnalyzer(BaseAnalyzer):
    """Detects authentication anomalies (brute force, unusual logon types)."""

    name = "auth_anomaly"
    description = "Flags authentication anomalies (failed logins, unusual logon types)"

    async def analyze(self, rows, config=None):
        """Flag brute-force candidates and remote logon types (3 / 10).

        Config key: ``brute_force_threshold`` (failed logins per user, default 5).
        """
        cfg = config or {}
        findings: list[AlertCandidate] = []

        failures: dict[str, list[int]] = defaultdict(list)      # user -> failed-login rows
        remote_logons: dict[str, list[int]] = defaultdict(list)  # logon type -> rows

        for row_no, row in enumerate(rows):
            kind = str(row.get("event_type", row.get("action", ""))).lower()
            outcome = str(row.get("status", row.get("result", ""))).lower()
            account = str(row.get("username", row.get("user", row.get("user_name", ""))))
            ltype = str(row.get("logon_type", ""))

            if "logon" in kind or "auth" in kind or "login" in kind:
                # Explicit failure status, or Windows event 4625 (failed logon).
                if "fail" in outcome or "4625" in str(row.get("event_id", "")):
                    if account:
                        failures[account].append(row_no)

            if ltype in ("3", "10"):  # Network / RemoteInteractive
                remote_logons[ltype].append(row_no)

        # Brute force: repeated failed logins for the same user.
        brute_min = cfg.get("brute_force_threshold", 5)
        for account, hits in failures.items():
            if len(hits) < brute_min:
                continue
            findings.append(AlertCandidate(
                analyzer=self.name,
                title=f"Possible brute force: {account}",
                severity="high",
                description=f"User '{account}' had {len(hits)} failed logins",
                evidence=[{"user": account, "failed_count": len(hits), "rows": hits[:10]}],
                mitre_technique="T1110",
                tags=["brute_force", "authentication"],
                score=min(90, 50 + len(hits) * 3),
            ))

        # Unusual logon types (only reported when they recur).
        for ltype, hits in remote_logons.items():
            if len(hits) < 3:
                continue
            label = "Network logon (Type 3)" if ltype == "3" else "Remote Desktop (Type 10)"
            findings.append(AlertCandidate(
                analyzer=self.name,
                title=f"{label} detected",
                severity="medium" if ltype == "3" else "high",
                description=f"{len(hits)} {label} events detected",
                evidence=[{"logon_type": ltype, "count": len(hits), "rows": hits[:10]}],
                mitre_technique="T1021",
                tags=["authentication", "lateral_movement"],
                score=55 if ltype == "3" else 70,
            ))

        return findings
|
||||
|
||||
|
||||
class PersistenceAnalyzer(BaseAnalyzer):
    """Detects persistence mechanisms (registry keys, services, scheduled tasks)."""

    name = "persistence"
    description = "Flags persistence mechanism installations"

    # (regex, title, MITRE technique) — patterns carry inline (?i) flags.
    REGISTRY_PATTERNS = [
        (r"(?i)\\CurrentVersion\\Run", "Run key persistence", "T1547.001"),
        (r"(?i)\\Services\\", "Service installation", "T1543.003"),
        (r"(?i)\\Winlogon\\", "Winlogon persistence", "T1547.004"),
        (r"(?i)\\Image File Execution Options\\", "IFEO debugger persistence", "T1546.012"),
        (r"(?i)\\Explorer\\Shell Folders", "Shell folder hijack", "T1547.001"),
    ]

    # Compiled once at class-creation time instead of on every analyze() call.
    _COMPILED = [(re.compile(p), t, m) for p, t, m in REGISTRY_PATTERNS]

    async def analyze(self, rows, config=None):
        """Scan registry paths and service-creation events for persistence.

        Args:
            rows: Flat row dicts; registry path read from 'registry_key',
                'target_object', or 'registry_path' (first present key wins).
            config: Unused; accepted for interface compatibility.
        """
        alerts: list[AlertCandidate] = []

        for idx, row in enumerate(rows):
            # Check registry paths
            reg_path = str(row.get("registry_key", row.get("target_object", row.get("registry_path", ""))))
            for pattern, title, mitre in self._COMPILED:
                if pattern.search(reg_path):
                    alerts.append(AlertCandidate(
                        analyzer=self.name,
                        title=title,
                        severity="high",
                        description=f"Row {idx}: {reg_path[:200]}",
                        evidence=[{"row_index": idx, "registry_key": reg_path[:300]}],
                        mitre_technique=mitre,
                        tags=["persistence", "registry"],
                        score=75,
                    ))

            # Check for service creation events
            event_type = str(row.get("event_type", "")).lower()
            if "service" in event_type and "creat" in event_type:
                svc_name = row.get("service_name", row.get("target_filename", "unknown"))
                alerts.append(AlertCandidate(
                    analyzer=self.name,
                    title=f"Service created: {svc_name}",
                    severity="medium",
                    description=f"Row {idx}: New service '{svc_name}' created",
                    evidence=[{"row_index": idx, "service_name": str(svc_name)}],
                    mitre_technique="T1543.003",
                    tags=["persistence", "service"],
                    score=60,
                ))

        return alerts
|
||||
|
||||
|
||||
# ── Analyzer Registry ────────────────────────────────────────────────
|
||||
|
||||
|
||||
# One shared instance per built-in analyzer; tuple order fixes the run order.
_ALL_ANALYZERS: list[BaseAnalyzer] = [
    analyzer_cls()
    for analyzer_cls in (
        EntropyAnalyzer,
        SuspiciousCommandAnalyzer,
        NetworkAnomalyAnalyzer,
        FrequencyAnomalyAnalyzer,
        AuthAnomalyAnalyzer,
        PersistenceAnalyzer,
    )
]
|
||||
|
||||
|
||||
def get_available_analyzers() -> list[dict[str, str]]:
    """Return name/description metadata for every registered analyzer."""
    return [
        {"name": analyzer.name, "description": analyzer.description}
        for analyzer in _ALL_ANALYZERS
    ]
|
||||
|
||||
|
||||
def get_analyzer(name: str) -> BaseAnalyzer | None:
    """Look up a registered analyzer by its ``name``; None when absent."""
    return next((a for a in _ALL_ANALYZERS if a.name == name), None)
|
||||
|
||||
|
||||
async def run_all_analyzers(
    rows: list[dict[str, Any]],
    enabled: list[str] | None = None,
    config: dict[str, Any] | None = None,
) -> list[AlertCandidate]:
    """Run all (or selected) analyzers and return combined alert candidates.

    Args:
        rows: Flat list of row dicts (normalized_data or data from DatasetRow).
        enabled: Optional list of analyzer names to run. Runs all if None.
        config: Optional config overrides passed to each analyzer.

    Returns:
        Combined list of AlertCandidate from all analyzers, sorted by score desc.
    """
    cfg = config or {}
    combined: list[AlertCandidate] = []

    for analyzer in _ALL_ANALYZERS:
        if enabled and analyzer.name not in enabled:
            continue  # caller restricted the run to a subset
        try:
            produced = await analyzer.analyze(rows, cfg)
        except Exception:
            # One broken analyzer must not abort the whole run.
            logger.exception("Analyzer %s failed", analyzer.name)
        else:
            combined.extend(produced)
            logger.info("Analyzer %s produced %d alerts", analyzer.name, len(produced))

    combined.sort(key=lambda candidate: candidate.score, reverse=True)
    return combined
|
||||
199
backend/app/services/anomaly_detector.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Embedding-based anomaly detection using Roadrunner's bge-m3 model.
|
||||
|
||||
Converts dataset rows to embeddings, clusters them, and flags outliers
|
||||
that deviate significantly from the cluster centroids. Uses cosine
|
||||
distance and simple k-means-like centroid computation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import async_session_factory
|
||||
from app.db.models import AnomalyResult, Dataset, DatasetRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Ollama embeddings endpoint on the Roadrunner node.
EMBED_URL = f"{settings.roadrunner_url}/api/embed"
# Embedding model requested from that endpoint.
EMBED_MODEL = "bge-m3"
BATCH_SIZE = 32  # rows per embedding batch
MAX_ROWS = 2000  # cap for anomaly detection
|
||||
|
||||
# --- math helpers (no numpy required) ---
|
||||
|
||||
def _dot(a: list[float], b: list[float]) -> float:
|
||||
return sum(x * y for x, y in zip(a, b))
|
||||
|
||||
|
||||
def _norm(v: list[float]) -> float:
|
||||
return math.sqrt(sum(x * x for x in v))
|
||||
|
||||
|
||||
def _cosine_distance(a: list[float], b: list[float]) -> float:
|
||||
na, nb = _norm(a), _norm(b)
|
||||
if na == 0 or nb == 0:
|
||||
return 1.0
|
||||
return 1.0 - _dot(a, b) / (na * nb)
|
||||
|
||||
|
||||
def _mean_vector(vectors: list[list[float]]) -> list[float]:
|
||||
if not vectors:
|
||||
return []
|
||||
dim = len(vectors[0])
|
||||
n = len(vectors)
|
||||
return [sum(v[i] for v in vectors) / n for i in range(dim)]
|
||||
|
||||
|
||||
def _row_to_text(data: dict) -> str:
|
||||
"""Flatten a row dict to a single string for embedding."""
|
||||
parts = []
|
||||
for k, v in data.items():
|
||||
sv = str(v).strip()
|
||||
if sv and sv.lower() not in ('none', 'null', ''):
|
||||
parts.append(f"{k}={sv}")
|
||||
return " | ".join(parts)[:2000] # cap length
|
||||
|
||||
|
||||
async def _embed_batch(texts: list[str], client: httpx.AsyncClient) -> list[list[float]]:
    """Get embeddings from Roadrunner's Ollama API.

    Posts the whole batch in a single request and raises on a non-2xx
    response (the caller handles per-batch failures).
    """
    resp = await client.post(
        EMBED_URL,
        json={"model": EMBED_MODEL, "input": texts},
        timeout=120.0,  # generous: large batches on a busy GPU node are slow
    )
    resp.raise_for_status()
    data = resp.json()
    # Ollama returns {"embeddings": [[...], ...]}
    return data.get("embeddings", [])
|
||||
|
||||
|
||||
def _simple_cluster(
    embeddings: list[list[float]],
    k: int = 3,
    max_iter: int = 20,
) -> tuple[list[int], list[list[float]]]:
    """Simple k-means clustering (no numpy dependency).

    Returns (assignments, centroids): assignments[i] is the centroid index
    for embeddings[i].  When there are no more points than clusters, every
    point becomes its own singleton cluster.
    """
    n = len(embeddings)
    if n <= k:
        # Degenerate case: one cluster per point (shallow copy so the caller
        # cannot mutate our centroid list through the input).
        return list(range(n)), embeddings[:]

    # Init centroids: evenly spaced indices
    step = max(n // k, 1)
    centroids = [embeddings[i * step % n] for i in range(k)]
    assignments = [0] * n

    for _ in range(max_iter):
        # Assign each embedding to its nearest centroid (cosine distance).
        new_assignments = []
        for emb in embeddings:
            dists = [_cosine_distance(emb, c) for c in centroids]
            new_assignments.append(dists.index(min(dists)))

        if new_assignments == assignments:
            break  # converged: no point changed cluster
        assignments = new_assignments

        # Recompute centroids as the mean of their members; an empty cluster
        # keeps its previous centroid.
        for ci in range(k):
            members = [embeddings[j] for j in range(n) if assignments[j] == ci]
            if members:
                centroids[ci] = _mean_vector(members)

    return assignments, centroids
|
||||
|
||||
|
||||
async def detect_anomalies(
    dataset_id: str,
    k: int = 3,
    outlier_threshold: float = 0.35,
) -> list[dict]:
    """Run embedding-based anomaly detection on a dataset.

    1. Load rows 2. Embed via bge-m3 3. Cluster 4. Flag outliers.

    Args:
        dataset_id: Dataset whose rows (capped at MAX_ROWS) are analyzed.
        k: Number of k-means clusters (clamped to the row count).
        outlier_threshold: Cosine distance from the assigned centroid above
            which a row is flagged as an outlier.

    Returns:
        Per-row result dicts sorted by anomaly_score descending; the same
        results are persisted as AnomalyResult records.  [] when the dataset
        has no rows or embedding fails entirely.
    """
    async with async_session_factory() as db:
        # Load rows
        result = await db.execute(
            select(DatasetRow.id, DatasetRow.row_index, DatasetRow.data)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .limit(MAX_ROWS)
        )
        rows = result.all()
        if not rows:
            logger.info("No rows for anomaly detection in dataset %s", dataset_id)
            return []

        # Parallel lists keep (id, index, text) aligned by position.
        row_ids = [r[0] for r in rows]
        row_indices = [r[1] for r in rows]
        texts = [_row_to_text(r[2]) for r in rows]

        logger.info("Anomaly detection: %d rows, embedding with %s", len(texts), EMBED_MODEL)

        # Embed in batches
        all_embeddings: list[list[float]] = []
        async with httpx.AsyncClient() as client:
            for i in range(0, len(texts), BATCH_SIZE):
                batch = texts[i : i + BATCH_SIZE]
                try:
                    embs = await _embed_batch(batch, client)
                    all_embeddings.extend(embs)
                except Exception as e:
                    logger.error("Embedding batch %d failed: %s", i, e)
                    # Fill with zeros so indices stay aligned
                    # (1024 presumably matches bge-m3's dimension — confirm).
                    all_embeddings.extend([[0.0] * 1024] * len(batch))

        if not all_embeddings or len(all_embeddings) != len(texts):
            logger.error("Embedding count mismatch")
            return []

        # Cluster
        actual_k = min(k, len(all_embeddings))
        assignments, centroids = _simple_cluster(all_embeddings, k=actual_k)

        # Compute distances from centroid
        anomalies: list[dict] = []
        for idx, (emb, cluster_id) in enumerate(zip(all_embeddings, assignments)):
            dist = _cosine_distance(emb, centroids[cluster_id])
            is_outlier = dist > outlier_threshold
            anomalies.append({
                "row_id": row_ids[idx],
                "row_index": row_indices[idx],
                "anomaly_score": round(dist, 4),
                "distance_from_centroid": round(dist, 4),
                "cluster_id": cluster_id,
                "is_outlier": is_outlier,
            })

        # Save to DB
        outlier_count = 0
        for a in anomalies:
            ar = AnomalyResult(
                dataset_id=dataset_id,
                row_id=a["row_id"],
                anomaly_score=a["anomaly_score"],
                distance_from_centroid=a["distance_from_centroid"],
                cluster_id=a["cluster_id"],
                is_outlier=a["is_outlier"],
            )
            db.add(ar)
            if a["is_outlier"]:
                outlier_count += 1

        await db.commit()
        logger.info(
            "Anomaly detection complete: %d rows, %d outliers (threshold=%.2f)",
            len(anomalies), outlier_count, outlier_threshold,
        )

        return sorted(anomalies, key=lambda x: x["anomaly_score"], reverse=True)
|
||||
81
backend/app/services/artifact_classifier.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Artifact classifier - identify Velociraptor artifact types from CSV headers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# (required_columns, artifact_type)
|
||||
# (required_columns, artifact_type) — order matters: the first fingerprint
# whose required columns are a subset of the CSV header wins.
FINGERPRINTS: list[tuple[set[str], str]] = [
    ({"Pid", "Name", "CommandLine", "Exe"}, "Windows.System.Pslist"),
    ({"Pid", "Name", "Ppid", "CommandLine"}, "Windows.System.Pslist"),
    ({"Laddr.IP", "Raddr.IP", "Status", "Pid"}, "Windows.Network.Netstat"),
    ({"Laddr", "Raddr", "Status", "Pid"}, "Windows.Network.Netstat"),
    ({"FamilyString", "TypeString", "Status", "Pid"}, "Windows.Network.Netstat"),
    ({"ServiceName", "DisplayName", "StartMode", "PathName"}, "Windows.System.Services"),
    ({"DisplayName", "PathName", "ServiceDll", "StartMode"}, "Windows.System.Services"),
    ({"OSPath", "Size", "Mtime", "Hash"}, "Windows.Search.FileFinder"),
    ({"FullPath", "Size", "Mtime"}, "Windows.Search.FileFinder"),
    ({"PrefetchFileName", "RunCount", "LastRunTimes"}, "Windows.Forensics.Prefetch"),
    ({"Executable", "RunCount", "LastRunTimes"}, "Windows.Forensics.Prefetch"),
    ({"KeyPath", "Type", "Data"}, "Windows.Registry.Finder"),
    ({"Key", "Type", "Value"}, "Windows.Registry.Finder"),
    ({"EventTime", "Channel", "EventID", "EventData"}, "Windows.EventLogs.EvtxHunter"),
    ({"TimeCreated", "Channel", "EventID", "Provider"}, "Windows.EventLogs.EvtxHunter"),
    ({"Entry", "Category", "Profile", "Launch String"}, "Windows.Sys.Autoruns"),
    ({"Entry", "Category", "LaunchString"}, "Windows.Sys.Autoruns"),
    ({"Name", "Record", "Type", "TTL"}, "Windows.Network.DNS"),
    ({"QueryName", "QueryType", "QueryResults"}, "Windows.Network.DNS"),
    ({"Path", "MD5", "SHA1", "SHA256"}, "Windows.Analysis.Hash"),
    ({"Md5", "Sha256", "FullPath"}, "Windows.Analysis.Hash"),
    ({"Name", "Actions", "NextRunTime", "Path"}, "Windows.System.TaskScheduler"),
    ({"Name", "Uid", "Gid", "Description"}, "Windows.Sys.Users"),
    ({"os_info.hostname", "os_info.system"}, "Server.Information.Client"),
    ({"ClientId", "os_info.fqdn"}, "Server.Information.Client"),
    ({"Pid", "Name", "Cmdline", "Exe"}, "Linux.Sys.Pslist"),
    ({"Laddr", "Raddr", "Status", "FamilyString"}, "Linux.Network.Netstat"),
    ({"Namespace", "ClassName", "PropertyName"}, "Windows.System.WMI"),
    ({"RemoteAddress", "RemoteMACAddress", "InterfaceAlias"}, "Windows.Network.ArpCache"),
    ({"URL", "Title", "VisitCount", "LastVisitTime"}, "Windows.Applications.BrowserHistory"),
    ({"Url", "Title", "Visits"}, "Windows.Applications.BrowserHistory"),
]

# Metadata columns Velociraptor adds to every export; their presence marks a
# Velociraptor CSV even when no specific fingerprint matches.
VELOCIRAPTOR_META = {"_Source", "ClientId", "FlowId", "Fqdn", "HuntId"}

# Substring of the artifact type name -> coarse analyst-facing category.
CATEGORY_MAP = {
    "Pslist": "process",
    "Netstat": "network",
    "Services": "persistence",
    "FileFinder": "filesystem",
    "Prefetch": "execution",
    "Registry": "persistence",
    "EvtxHunter": "eventlog",
    "EventLogs": "eventlog",
    "Autoruns": "persistence",
    "DNS": "network",
    "Hash": "filesystem",
    "TaskScheduler": "persistence",
    "Users": "account",
    "Client": "system",
    "WMI": "persistence",
    "ArpCache": "network",
    "BrowserHistory": "application",
}


def classify_artifact(columns: list[str]) -> str:
    """Identify the Velociraptor artifact type that produced *columns*.

    Falls back to "Velociraptor.Unknown" when only the generic metadata
    columns match, and "Unknown" otherwise.
    """
    present = set(columns)
    matched = next(
        (artifact for needed, artifact in FINGERPRINTS if needed.issubset(present)),
        None,
    )
    if matched is not None:
        return matched
    if not VELOCIRAPTOR_META.isdisjoint(present):
        return "Velociraptor.Unknown"
    return "Unknown"


def get_artifact_category(artifact_type: str) -> str:
    """Map an artifact type name to a coarse category; "unknown" if none fits."""
    lowered = artifact_type.lower()
    for marker, category in CATEGORY_MAP.items():
        if marker.lower() in lowered:
            return category
    return "unknown"
|
||||
201
backend/app/services/auth.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""Authentication & security — JWT tokens, password hashing, role-based access.
|
||||
|
||||
Provides:
|
||||
- Password hashing (bcrypt via passlib)
|
||||
- JWT access/refresh token creation and verification
|
||||
- FastAPI dependency for protecting routes
|
||||
- Role-based enforcement (analyst, admin, viewer)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
from app.db.models import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Password hashing ─────────────────────────────────────────────────

# Shared bcrypt hashing context; "deprecated=auto" marks non-default schemes
# as deprecated so passlib can flag stored hashes that need re-hashing.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
    """Return a bcrypt hash of *password* suitable for storage."""
    return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain: str, hashed: str) -> bool:
    """Check *plain* against a stored *hashed* value; True on match."""
    return pwd_context.verify(plain, hashed)
|
||||
|
||||
|
||||
# ── JWT tokens ────────────────────────────────────────────────────────

# Symmetric HMAC signing algorithm used for every token this module issues.
ALGORITHM = "HS256"

# auto_error=False: a missing Authorization header yields None instead of an
# immediate error, letting get_current_user/get_optional_user set the policy.
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
class TokenPair(BaseModel):
    """Access + refresh token pair returned to clients."""

    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    expires_in: int  # access-token lifetime in seconds
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
    """Decoded JWT claims as validated by decode_token()."""

    sub: str  # user_id the token was issued for
    role: str  # role claim copied into the token at creation time
    exp: datetime  # expiry claim (jose validates it during decode)
    type: str  # "access" or "refresh"
|
||||
|
||||
|
||||
def create_access_token(user_id: str, role: str) -> str:
    """Issue a short-lived signed JWT access token for *user_id*."""
    now = datetime.now(timezone.utc)
    claims = {
        "sub": user_id,
        "role": role,
        "exp": now + timedelta(minutes=settings.JWT_ACCESS_TOKEN_MINUTES),
        "type": "access",
    }
    return jwt.encode(claims, settings.JWT_SECRET, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(user_id: str, role: str) -> str:
    """Issue a long-lived signed JWT refresh token for *user_id*."""
    lifetime = timedelta(days=settings.JWT_REFRESH_TOKEN_DAYS)
    claims = {
        "sub": user_id,
        "role": role,
        "exp": datetime.now(timezone.utc) + lifetime,
        "type": "refresh",
    }
    return jwt.encode(claims, settings.JWT_SECRET, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def create_token_pair(user_id: str, role: str) -> TokenPair:
    """Mint a matching access/refresh token pair for a user."""
    access = create_access_token(user_id, role)
    refresh = create_refresh_token(user_id, role)
    return TokenPair(
        access_token=access,
        refresh_token=refresh,
        expires_in=settings.JWT_ACCESS_TOKEN_MINUTES * 60,
    )
|
||||
|
||||
|
||||
def decode_token(token: str) -> TokenPayload:
    """Decode and validate a JWT token.

    Signature and expiry are checked by jose during decode; any failure is
    surfaced to the caller as an HTTP 401 rather than a raw JWTError.

    Raises:
        HTTPException: 401 when the token is malformed, tampered, or expired.
    """
    try:
        payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[ALGORITHM])
        return TokenPayload(**payload)
    except JWTError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid token: {e}",
            headers={"WWW-Authenticate": "Bearer"},
        )
|
||||
|
||||
|
||||
# ── FastAPI dependencies ──────────────────────────────────────────────
|
||||
|
||||
|
||||
async def get_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Extract and validate the current user from JWT.

    When AUTH is disabled (no JWT secret configured), returns a default analyst user.

    Raises:
        HTTPException: 401 when credentials are missing/invalid or the user
            does not exist; 403 when the account is disabled.
    """
    # If auth is disabled (dev mode), return a default user.
    # NOTE(review): dev mode is detected by comparing JWT_SECRET against the
    # shipped placeholder string — deploying with the default secret silently
    # disables auth. An explicit settings flag would be safer; confirm intent.
    if settings.JWT_SECRET == "CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET":
        return User(
            id="dev-user",
            username="analyst",
            email="analyst@local",
            role="analyst",
            display_name="Dev Analyst",
        )

    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # decode_token raises 401 on signature/expiry failure.
    token_data = decode_token(credentials.credentials)

    # Refresh tokens must not be usable in place of access tokens.
    if token_data.type != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type — use access token",
        )

    result = await db.execute(select(User).where(User.id == token_data.sub))
    user = result.scalar_one_or_none()

    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
        )

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User account is disabled",
        )

    return user
|
||||
|
||||
|
||||
async def get_optional_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> Optional[User]:
    """Like get_current_user, but returns None instead of raising if no token.

    In dev mode (placeholder JWT secret) a default analyst user is returned
    even without credentials, mirroring get_current_user.
    """
    if not credentials:
        if settings.JWT_SECRET == "CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET":
            return User(
                id="dev-user",
                username="analyst",
                email="analyst@local",
                role="analyst",
                display_name="Dev Analyst",
            )
        return None

    try:
        return await get_current_user(credentials, db)
    except HTTPException:
        # An invalid/expired token is treated the same as no token here.
        return None
|
||||
|
||||
|
||||
def require_role(*roles: str):
    """Dependency factory that requires the current user to have one of the specified roles.

    Usage: `Depends(require_role("admin"))`. Authentication failures (401)
    are raised by get_current_user; this only adds the 403 role gate.
    """

    async def _check(user: User = Depends(get_current_user)) -> User:
        # Exact string membership — no role hierarchy is implied here.
        if user.role not in roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Requires one of roles: {', '.join(roles)}. You have: {user.role}",
            )
        return user

    return _check
|
||||
|
||||
|
||||
# Convenience dependencies built on require_role
require_analyst = require_role("analyst", "admin")  # analyst or admin
require_admin = require_role("admin")  # admin only
|
||||
400
backend/app/services/correlation.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""Cross-hunt correlation engine — find IOC overlaps, timeline patterns, and shared TTPs.
|
||||
|
||||
Identifies connections between hunts by analyzing:
|
||||
1. Shared IOC values across datasets
|
||||
2. Overlapping time ranges and temporal proximity
|
||||
3. Common MITRE ATT&CK techniques across hypotheses
|
||||
4. Host-to-host lateral movement patterns
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import Counter, defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import Dataset, DatasetRow, Hunt, Hypothesis, EnrichmentResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class IOCOverlap:
    """Shared IOC value seen in datasets belonging to two or more hunts."""

    ioc_value: str  # literal indicator value as it appears in the data
    ioc_type: str  # type hint taken from the dataset's ioc_columns mapping
    datasets: list[dict] = field(default_factory=list)  # [{dataset_id, hunt_id, name}]
    hunt_ids: list[str] = field(default_factory=list)  # sorted unique hunt ids
    count: int = 0  # total appearances across all datasets
    enrichment_verdict: str = ""  # e.g. "malicious" when an EnrichmentResult exists
|
||||
|
||||
|
||||
@dataclass
class TimeOverlap:
    """Overlapping time window between two datasets from different hunts."""

    dataset_a: dict = field(default_factory=dict)  # {id, name, hunt_id, start, end}
    dataset_b: dict = field(default_factory=dict)  # same shape as dataset_a
    overlap_start: str = ""  # ISO-8601 start of the shared window
    overlap_end: str = ""  # ISO-8601 end of the shared window
    overlap_hours: float = 0.0  # window length, rounded to 2 decimals
|
||||
|
||||
|
||||
@dataclass
class TechniqueOverlap:
    """Shared MITRE ATT&CK technique across hunts."""

    technique_id: str  # Hypothesis.mitre_technique value, verbatim
    technique_name: str = ""  # NOTE(review): never populated by CorrelationEngine
    hypotheses: list[dict] = field(default_factory=list)  # {hypothesis_id, hypothesis_title, hunt_id, status}
    hunt_ids: list[str] = field(default_factory=list)  # sorted unique hunt ids
|
||||
|
||||
|
||||
@dataclass
class CorrelationResult:
    """Complete correlation analysis result produced by CorrelationEngine."""

    hunt_ids: list[str]  # hunts that were analyzed
    ioc_overlaps: list[IOCOverlap] = field(default_factory=list)
    time_overlaps: list[TimeOverlap] = field(default_factory=list)
    technique_overlaps: list[TechniqueOverlap] = field(default_factory=list)
    host_overlaps: list[dict] = field(default_factory=list)  # {hostname, hunt_ids, dataset_count, datasets}
    summary: str = ""  # human-readable report built by _build_summary
    total_correlations: int = 0  # sum of the four overlap list lengths
|
||||
|
||||
|
||||
class CorrelationEngine:
    """Engine for finding correlations across hunts and datasets.

    All methods are read-only with respect to the database; results are
    returned as dataclasses/dicts and never persisted here.
    """

    async def correlate_hunts(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> CorrelationResult:
        """Run full correlation analysis across specified hunts."""
        result = CorrelationResult(hunt_ids=hunt_ids)

        # Run all four correlation types; each returns an already-sorted,
        # size-capped list.
        result.ioc_overlaps = await self._find_ioc_overlaps(hunt_ids, db)
        result.time_overlaps = await self._find_time_overlaps(hunt_ids, db)
        result.technique_overlaps = await self._find_technique_overlaps(hunt_ids, db)
        result.host_overlaps = await self._find_host_overlaps(hunt_ids, db)

        result.total_correlations = (
            len(result.ioc_overlaps)
            + len(result.time_overlaps)
            + len(result.technique_overlaps)
            + len(result.host_overlaps)
        )

        result.summary = self._build_summary(result)
        return result

    async def correlate_all(self, db: AsyncSession) -> CorrelationResult:
        """Correlate across ALL hunts in the system."""
        stmt = select(Hunt.id)
        result = await db.execute(stmt)
        hunt_ids = [row[0] for row in result.fetchall()]

        # Correlation is meaningless for fewer than two hunts.
        if len(hunt_ids) < 2:
            return CorrelationResult(
                hunt_ids=hunt_ids,
                summary="Need at least 2 hunts for correlation analysis.",
            )

        return await self.correlate_hunts(hunt_ids, db)

    async def find_ioc_across_hunts(
        self,
        ioc_value: str,
        db: AsyncSession,
    ) -> list[dict]:
        """Find all occurrences of a specific IOC across all datasets/hunts.

        NOTE(review): rows are scanned in Python over at most the first 5000
        joined rows — there is no DB-side value filter, so matches beyond
        that window are missed. Confirm whether a JSON-search query was the
        intended implementation.
        """
        stmt = select(DatasetRow, Dataset).join(
            Dataset, DatasetRow.dataset_id == Dataset.id
        )
        result = await db.execute(stmt.limit(5000))
        rows = result.all()

        occurrences = []
        for row, dataset in rows:
            data = row.data or {}
            normalized = row.normalized_data or {}

            # Search both raw and normalized data; normalized values shadow
            # raw ones on key collision because they are merged last.
            for col, val in {**data, **normalized}.items():
                if str(val) == ioc_value:
                    occurrences.append({
                        "dataset_id": dataset.id,
                        "dataset_name": dataset.name,
                        "hunt_id": dataset.hunt_id,
                        "row_index": row.row_index,
                        "column": col,
                    })
                    break  # one hit per row is enough

        return occurrences

    # ── IOC overlap detection ─────────────────────────────────────────

    async def _find_ioc_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[IOCOverlap]:
        """Find IOC values that appear in datasets from different hunts."""
        # Get all datasets for the specified hunts
        stmt = select(Dataset).where(Dataset.hunt_id.in_(hunt_ids))
        result = await db.execute(stmt)
        datasets = result.scalars().all()

        if len(datasets) < 2:
            return []

        # Build IOC value → list-of-appearances mapping
        ioc_map: dict[str, list[dict]] = defaultdict(list)

        for dataset in datasets:
            # Datasets without tagged IOC columns cannot contribute.
            if not dataset.ioc_columns:
                continue

            ioc_cols = list(dataset.ioc_columns.keys())
            # Cap the per-dataset row scan to bound memory and latency.
            rows_stmt = select(DatasetRow).where(
                DatasetRow.dataset_id == dataset.id
            ).limit(2000)
            rows_result = await db.execute(rows_stmt)
            rows = rows_result.scalars().all()

            for row in rows:
                data = row.data or {}
                for col in ioc_cols:
                    val = data.get(col, "")
                    if val and str(val).strip():
                        ioc_map[str(val).strip()].append({
                            "dataset_id": dataset.id,
                            "dataset_name": dataset.name,
                            "hunt_id": dataset.hunt_id,
                            "column": col,
                            "ioc_type": dataset.ioc_columns.get(col, "unknown"),
                        })

        # Filter to IOCs appearing in multiple hunts
        overlaps = []
        for ioc_value, appearances in ioc_map.items():
            hunt_set = set(a["hunt_id"] for a in appearances if a["hunt_id"])
            if len(hunt_set) >= 2:
                # Attach an enrichment verdict when one exists for this IOC.
                enrich_stmt = select(EnrichmentResult).where(
                    EnrichmentResult.ioc_value == ioc_value
                ).limit(1)
                enrich_result = await db.execute(enrich_stmt)
                enrichment = enrich_result.scalar_one_or_none()

                overlaps.append(IOCOverlap(
                    ioc_value=ioc_value,
                    ioc_type=appearances[0].get("ioc_type", "unknown"),
                    datasets=appearances,
                    hunt_ids=sorted(hunt_set),
                    count=len(appearances),
                    enrichment_verdict=enrichment.verdict if enrichment else "",
                ))

        # Most widely shared IOCs first
        overlaps.sort(key=lambda x: x.count, reverse=True)
        return overlaps[:100]  # Limit results

    # ── Time window overlap ───────────────────────────────────────────

    async def _find_time_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[TimeOverlap]:
        """Find datasets across hunts with overlapping time ranges."""
        stmt = select(Dataset).where(
            Dataset.hunt_id.in_(hunt_ids),
            Dataset.time_range_start.isnot(None),
            Dataset.time_range_end.isnot(None),
        )
        result = await db.execute(stmt)
        datasets = result.scalars().all()

        overlaps = []
        # Pairwise O(n^2) comparison over datasets with known time ranges.
        for i, ds_a in enumerate(datasets):
            for ds_b in datasets[i + 1:]:
                if ds_a.hunt_id == ds_b.hunt_id:
                    continue  # Same hunt, skip

                try:
                    a_start = datetime.fromisoformat(ds_a.time_range_start)
                    a_end = datetime.fromisoformat(ds_a.time_range_end)
                    b_start = datetime.fromisoformat(ds_b.time_range_start)
                    b_end = datetime.fromisoformat(ds_b.time_range_end)
                except (ValueError, TypeError):
                    continue  # unparsable timestamps: skip the pair

                # Intersection of the two [start, end] windows
                overlap_start = max(a_start, b_start)
                overlap_end = min(a_end, b_end)

                if overlap_start < overlap_end:
                    hours = (overlap_end - overlap_start).total_seconds() / 3600
                    overlaps.append(TimeOverlap(
                        dataset_a={
                            "id": ds_a.id,
                            "name": ds_a.name,
                            "hunt_id": ds_a.hunt_id,
                            "start": ds_a.time_range_start,
                            "end": ds_a.time_range_end,
                        },
                        dataset_b={
                            "id": ds_b.id,
                            "name": ds_b.name,
                            "hunt_id": ds_b.hunt_id,
                            "start": ds_b.time_range_start,
                            "end": ds_b.time_range_end,
                        },
                        overlap_start=overlap_start.isoformat(),
                        overlap_end=overlap_end.isoformat(),
                        overlap_hours=round(hours, 2),
                    ))

        # Longest overlaps first
        overlaps.sort(key=lambda x: x.overlap_hours, reverse=True)
        return overlaps[:50]

    # ── MITRE technique overlap ───────────────────────────────────────

    async def _find_technique_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[TechniqueOverlap]:
        """Find MITRE ATT&CK techniques shared across hunts."""
        stmt = select(Hypothesis).where(
            Hypothesis.hunt_id.in_(hunt_ids),
            Hypothesis.mitre_technique.isnot(None),
        )
        result = await db.execute(stmt)
        hypotheses = result.scalars().all()

        # technique id → hypotheses that reference it
        technique_map: dict[str, list[dict]] = defaultdict(list)
        for hyp in hypotheses:
            technique = hyp.mitre_technique.strip()
            if technique:
                technique_map[technique].append({
                    "hypothesis_id": hyp.id,
                    "hypothesis_title": hyp.title,
                    "hunt_id": hyp.hunt_id,
                    "status": hyp.status,
                })

        overlaps = []
        for technique, hyps in technique_map.items():
            hunt_set = set(h["hunt_id"] for h in hyps if h["hunt_id"])
            if len(hunt_set) >= 2:
                overlaps.append(TechniqueOverlap(
                    technique_id=technique,
                    hypotheses=hyps,
                    hunt_ids=sorted(hunt_set),
                ))

        return overlaps

    # ── Host overlap ──────────────────────────────────────────────────

    async def _find_host_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[dict]:
        """Find hostnames that appear in datasets from different hunts.

        Useful for detecting lateral movement patterns.
        """
        stmt = select(Dataset).where(Dataset.hunt_id.in_(hunt_ids))
        result = await db.execute(stmt)
        datasets = result.scalars().all()

        host_map: dict[str, list[dict]] = defaultdict(list)

        for dataset in datasets:
            norm_cols = dataset.normalized_columns or {}
            # Look for columns whose canonical name marks a hostname.
            hostname_cols = [
                orig for orig, canon in norm_cols.items()
                if canon in ("hostname", "host", "computer_name", "src_host", "dst_host")
            ]
            if not hostname_cols:
                continue

            rows_stmt = select(DatasetRow).where(
                DatasetRow.dataset_id == dataset.id
            ).limit(2000)
            rows_result = await db.execute(rows_stmt)
            rows = rows_result.scalars().all()

            for row in rows:
                data = row.data or {}
                for col in hostname_cols:
                    val = data.get(col, "")
                    if val and str(val).strip():
                        # Uppercase to fold case differences between tools.
                        host_name = str(val).strip().upper()
                        host_map[host_name].append({
                            "dataset_id": dataset.id,
                            "dataset_name": dataset.name,
                            "hunt_id": dataset.hunt_id,
                        })

        # Filter to hosts appearing in multiple hunts
        overlaps = []
        for host, appearances in host_map.items():
            hunt_set = set(a["hunt_id"] for a in appearances if a["hunt_id"])
            if len(hunt_set) >= 2:
                overlaps.append({
                    "hostname": host,
                    "hunt_ids": sorted(hunt_set),
                    "dataset_count": len(appearances),
                    "datasets": appearances[:10],
                })

        overlaps.sort(key=lambda x: x["dataset_count"], reverse=True)
        return overlaps[:50]

    # ── Summary builder ───────────────────────────────────────────────

    def _build_summary(self, result: CorrelationResult) -> str:
        """Build a human-readable summary of correlations."""
        parts = [f"Correlation analysis across {len(result.hunt_ids)} hunts:"]

        if result.ioc_overlaps:
            malicious = [o for o in result.ioc_overlaps if o.enrichment_verdict == "malicious"]
            parts.append(
                f" - {len(result.ioc_overlaps)} shared IOCs "
                f"({len(malicious)} flagged malicious)"
            )
        else:
            parts.append(" - No shared IOCs found")

        if result.time_overlaps:
            parts.append(f" - {len(result.time_overlaps)} overlapping time windows")

        if result.technique_overlaps:
            parts.append(
                f" - {len(result.technique_overlaps)} shared MITRE techniques"
            )

        if result.host_overlaps:
            parts.append(
                f" - {len(result.host_overlaps)} hosts appearing in multiple hunts "
                "(potential lateral movement)"
            )

        if result.total_correlations == 0:
            parts.append(" No significant correlations detected.")

        return "\n".join(parts)
|
||||
|
||||
|
||||
# Module-level singleton; import `correlation_engine` rather than
# instantiating CorrelationEngine (the class holds no state, so sharing is safe).
correlation_engine = CorrelationEngine()
|
||||
165
backend/app/services/csv_parser.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""CSV parsing engine with encoding detection, delimiter sniffing, and streaming.
|
||||
|
||||
Handles large Velociraptor CSV exports with resilience to encoding issues,
|
||||
varied delimiters, and malformed rows.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator
|
||||
|
||||
import chardet
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Reasonable defaults
MAX_FIELD_SIZE = 1024 * 1024  # 1 MB per field
# Raise the csv module's per-field cap so oversized forensic fields
# (long command lines, embedded blobs) don't abort parsing mid-file.
csv.field_size_limit(MAX_FIELD_SIZE)
|
||||
|
||||
|
||||
def detect_encoding(file_bytes: bytes, sample_size: int = 65536) -> str:
    """Detect the text encoding of *file_bytes* from its first *sample_size* bytes.

    Only a bounded prefix is handed to chardet so very large uploads stay
    cheap. Low-confidence guesses (< 0.5) and undetectable inputs fall back
    to UTF-8.

    Args:
        file_bytes: Raw file content (only the prefix is inspected).
        sample_size: Number of leading bytes given to chardet.

    Returns:
        An encoding name usable with bytes.decode().
    """
    result = chardet.detect(file_bytes[:sample_size])
    encoding = result.get("encoding") or "utf-8"
    # `or 0.0` guards against an explicit None confidence, which would
    # crash the "%.2f" formatting below.
    confidence = result.get("confidence") or 0.0
    # Lazy %-args keep formatting work off the path when INFO is disabled.
    logger.info("Detected encoding: %s (confidence: %.2f)", encoding, confidence)
    # Fall back to utf-8 if confidence is very low — chardet is unreliable
    # there, and downstream decodes with errors="replace" anyway.
    if confidence < 0.5:
        encoding = "utf-8"
    return encoding
|
||||
|
||||
|
||||
def detect_delimiter(text_sample: str) -> str:
    """Guess the field delimiter of a CSV sample, defaulting to a comma."""
    try:
        dialect = csv.Sniffer().sniff(text_sample, delimiters=",\t;|")
    except csv.Error:
        # Sniffer gives up on ambiguous or degenerate samples; comma is the
        # overwhelmingly common case for exported CSVs.
        return ","
    return dialect.delimiter
|
||||
|
||||
|
||||
def infer_column_types(rows: list[dict], sample_size: int = 100) -> dict[str, str]:
    """Infer a type hint for each column from a sample of rows.

    Returns a mapping of column_name → type_hint where type_hint is one of:
    timestamp, integer, float, ip, hash_md5, hash_sha1, hash_sha256, domain, path, string
    """
    import re

    # Ordered checks: the first pattern that matches a value wins for it.
    checks = [
        ("ip", re.compile(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$")),
        ("hash_md5", re.compile(r"^[a-fA-F0-9]{32}$")),
        ("hash_sha1", re.compile(r"^[a-fA-F0-9]{40}$")),
        ("hash_sha256", re.compile(r"^[a-fA-F0-9]{64}$")),
        ("integer", re.compile(r"^-?\d+$")),
        ("float", re.compile(r"^-?\d+\.\d+$")),
        ("timestamp", re.compile(r"^\d{4}[-/]\d{2}[-/]\d{2}[T ]\d{2}:\d{2}")),
        ("domain", re.compile(r"^[a-zA-Z0-9]([a-zA-Z0-9-]*[a-zA-Z0-9])?(\.[a-zA-Z]{2,})+$")),
        ("path", re.compile(r"^([A-Z]:\\|/)", re.IGNORECASE)),
    ]

    votes: dict[str, dict[str, int]] = {}
    for row in rows[:sample_size]:
        for col, raw in row.items():
            tally = votes.setdefault(col, {})
            text = str(raw).strip()
            if not text:
                continue  # blank values carry no type information
            hit = next((name for name, rx in checks if rx.match(text)), "string")
            tally[hit] = tally.get(hit, 0) + 1

    # Majority vote per column; a column with only blank values is "string".
    return {
        col: (max(tally, key=tally.get) if tally else "string")  # type: ignore[arg-type]
        for col, tally in votes.items()
    }
|
||||
|
||||
|
||||
def parse_csv_bytes(
    raw_bytes: bytes,
    max_rows: int | None = None,
) -> tuple[list[dict], dict]:
    """Parse a CSV file from raw bytes.

    Args:
        raw_bytes: Complete file content.
        max_rows: Optional cap on the number of row dicts materialized.

    Returns:
        (rows, metadata) where metadata contains encoding, delimiter, columns,
        inferred column_types, row_count (rows returned), and
        total_rows_in_file (all data rows present, even when truncated).
    """
    encoding = detect_encoding(raw_bytes)

    try:
        text = raw_bytes.decode(encoding, errors="replace")
    except (UnicodeDecodeError, LookupError):
        # Detection produced an unknown/bogus codec name: retry as UTF-8.
        text = raw_bytes.decode("utf-8", errors="replace")
        encoding = "utf-8"

    # Detect delimiter from the first few KB only; sniffing the whole file
    # would be wasted work.
    delimiter = detect_delimiter(text[:8192])

    reader = csv.DictReader(io.StringIO(text), delimiter=delimiter)
    columns = reader.fieldnames or []

    rows: list[dict] = []
    truncated = False
    for i, row in enumerate(reader):
        if max_rows is not None and i >= max_rows:
            truncated = True
            break
        rows.append(dict(row))

    # Bug fix: total_rows_in_file used to report len(rows), under-counting
    # whenever max_rows truncated the read. Drain the reader (the text is
    # already in memory) so the metadata reflects the real file size.
    total_in_file = len(rows)
    if truncated:
        # +1 for the row consumed by the break iteration itself.
        total_in_file += 1 + sum(1 for _ in reader)

    column_types = infer_column_types(rows) if rows else {}

    metadata = {
        "encoding": encoding,
        "delimiter": delimiter,
        "columns": columns,
        "column_types": column_types,
        "row_count": len(rows),
        "total_rows_in_file": total_in_file,
    }

    return rows, metadata
|
||||
|
||||
|
||||
async def parse_csv_streaming(
    file_path: Path,
    chunk_size: int = 8192,
) -> AsyncIterator[tuple[int, dict]]:
    """Stream-parse a CSV file yielding (row_index, row_dict) tuples.

    Memory-efficient for large files.

    NOTE(review): despite the name, the current implementation reads the
    whole decoded file into memory (`await f.read()`) because csv.DictReader
    needs a synchronous iterator, and `chunk_size` is unused. Only the
    *yielding* is incremental — confirm before relying on this for files
    that don't fit in RAM.
    """
    import aiofiles  # type: ignore[import-untyped]

    # Read a bounded sample (64 KB) for encoding/delimiter detection.
    with open(file_path, "rb") as f:
        sample_bytes = f.read(65536)

    encoding = detect_encoding(sample_bytes)
    text_sample = sample_bytes.decode(encoding, errors="replace")
    delimiter = detect_delimiter(text_sample[:8192])

    # Now stream-read
    async with aiofiles.open(file_path, mode="r", encoding=encoding, errors="replace") as f:
        content = await f.read()  # For DictReader compatibility

    reader = csv.DictReader(io.StringIO(content), delimiter=delimiter)
    for i, row in enumerate(reader):
        yield i, dict(row)
|
||||
238
backend/app/services/data_query.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Natural-language data query service with SSE streaming.
|
||||
|
||||
Lets analysts ask questions about dataset rows in plain English.
|
||||
Routes to fast model (Roadrunner) for quick queries, heavy model (Wile)
|
||||
for deep analysis. Supports streaming via OllamaProvider.generate_stream().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import AsyncIterator
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import async_session_factory
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum rows to include in context window
|
||||
MAX_CONTEXT_ROWS = 60
|
||||
MAX_ROW_TEXT_CHARS = 300
|
||||
|
||||
|
||||
def _rows_to_text(rows: list[dict], columns: list[str]) -> str:
|
||||
"""Convert dataset rows to a compact text table for the LLM context."""
|
||||
if not rows:
|
||||
return "(no rows)"
|
||||
# Header
|
||||
header = " | ".join(columns[:20]) # cap columns to avoid overflow
|
||||
lines = [header, "-" * min(len(header), 120)]
|
||||
for row in rows[:MAX_CONTEXT_ROWS]:
|
||||
vals = []
|
||||
for c in columns[:20]:
|
||||
v = str(row.get(c, ""))
|
||||
if len(v) > 80:
|
||||
v = v[:77] + "..."
|
||||
vals.append(v)
|
||||
line = " | ".join(vals)
|
||||
if len(line) > MAX_ROW_TEXT_CHARS:
|
||||
line = line[:MAX_ROW_TEXT_CHARS] + "..."
|
||||
lines.append(line)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
QUERY_SYSTEM_PROMPT = """You are a cybersecurity data analyst assistant for ThreatHunt.
|
||||
You have been given a sample of rows from a forensic artifact dataset (Velociraptor, etc.).
|
||||
|
||||
Your job:
|
||||
- Answer the analyst's question about this data accurately and concisely
|
||||
- Point out suspicious patterns, anomalies, or indicators of compromise
|
||||
- Reference MITRE ATT&CK techniques when relevant
|
||||
- Suggest follow-up queries or pivots
|
||||
- If you cannot answer from the data provided, say so clearly
|
||||
|
||||
Rules:
|
||||
- Be factual - only reference data you can see
|
||||
- Use forensic terminology appropriate for SOC/DFIR analysts
|
||||
- Format your answer with clear sections using markdown
|
||||
- If the data seems benign, say so - do not fabricate threats"""
|
||||
|
||||
|
||||
async def _load_dataset_context(
    dataset_id: str,
    db: AsyncSession,
    sample_size: int = MAX_CONTEXT_ROWS,
) -> tuple[dict, str, int]:
    """Load dataset metadata + sample rows for LLM context.

    Half of the sample is taken from the start of the dataset and half from
    the middle (or simply the next rows when the dataset is small), so the
    LLM sees more than just the file's head.

    Returns:
        (metadata_dict, rows_text, total_row_count)

    Raises:
        ValueError: if the dataset does not exist.
    """
    ds = await db.get(Dataset, dataset_id)
    if not ds:
        raise ValueError(f"Dataset {dataset_id} not found")

    # Get total row count for the dataset
    count_q = await db.execute(
        select(func.count()).where(DatasetRow.dataset_id == dataset_id)
    )
    total = count_q.scalar() or 0

    # First half of the sample: rows from the start of the dataset.
    half = sample_size // 2
    result = await db.execute(
        select(DatasetRow)
        .where(DatasetRow.dataset_id == dataset_id)
        .order_by(DatasetRow.row_index)
        .limit(half)
    )
    first_rows = result.scalars().all()

    # Second half: from the middle of a large dataset, otherwise just the
    # rows immediately after the first batch.
    middle_rows = []
    if total > sample_size:
        mid_offset = total // 2
        result2 = await db.execute(
            select(DatasetRow)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(mid_offset)
            .limit(sample_size - half)
        )
        middle_rows = result2.scalars().all()
    else:
        result2 = await db.execute(
            select(DatasetRow)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(half)
            .limit(sample_size - half)
        )
        middle_rows = result2.scalars().all()

    all_rows = first_rows + middle_rows
    # Defensive: tolerate rows whose stored payload is not a dict.
    row_dicts = [r.data if isinstance(r.data, dict) else {} for r in all_rows]

    # Prefer the stored schema for column order; fall back to the first row.
    columns = list(ds.column_schema.keys()) if ds.column_schema else []
    if not columns and row_dicts:
        columns = list(row_dicts[0].keys())

    rows_text = _rows_to_text(row_dicts, columns)

    metadata = {
        "name": ds.name,
        "filename": ds.filename,
        "source_tool": ds.source_tool,
        "artifact_type": getattr(ds, "artifact_type", None),
        "row_count": total,
        "columns": columns[:30],
        "sample_rows_shown": len(all_rows),
    }
    return metadata, rows_text, total
|
||||
|
||||
|
||||
async def query_dataset(
    dataset_id: str,
    question: str,
    mode: str = "quick",
) -> str:
    """Non-streaming query: returns the full answer text.

    Args:
        dataset_id: Dataset whose rows are sampled into the LLM context.
        question: The analyst's plain-English question.
        mode: "deep" routes to the heavy model on Wile with a larger token
            budget; any other value uses the fast model on Roadrunner.

    Raises:
        ValueError: if the dataset does not exist.
    """
    # Local import — NOTE(review): presumably avoids an import cycle with
    # app.agents; confirm before hoisting to module level.
    from app.agents.providers_v2 import OllamaProvider, Node

    async with async_session_factory() as db:
        meta, rows_text, total = await _load_dataset_context(dataset_id, db)

    prompt = _build_prompt(question, meta, rows_text, total)

    if mode == "deep":
        provider = OllamaProvider(settings.DEFAULT_HEAVY_MODEL, Node.WILE)
        max_tokens = 4096
    else:
        provider = OllamaProvider(settings.DEFAULT_FAST_MODEL, Node.ROADRUNNER)
        max_tokens = 2048

    # Low temperature: factual analysis, not creative writing.
    result = await provider.generate(
        prompt,
        system=QUERY_SYSTEM_PROMPT,
        max_tokens=max_tokens,
        temperature=0.3,
    )
    return result.get("response", "No response generated.")
|
||||
|
||||
|
||||
async def query_dataset_stream(
    dataset_id: str,
    question: str,
    mode: str = "quick",
) -> AsyncIterator[str]:
    """Streaming query: yields SSE-formatted events.

    Event sequence: `status` events, a `metadata` event with the dataset
    summary, then `token` events carrying answer fragments, and a final
    `done` event with token count, elapsed_ms, model and node. If the LLM
    stream fails, an `error` event is emitted in-band and `done` still
    follows, keeping the SSE stream well-formed.
    """
    # Local import — NOTE(review): presumably avoids an import cycle with
    # app.agents; confirm before hoisting to module level.
    from app.agents.providers_v2 import OllamaProvider, Node

    start = time.monotonic()

    # Send initial metadata event
    yield f"data: {json.dumps({'type': 'status', 'message': 'Loading dataset...'})}\n\n"

    async with async_session_factory() as db:
        meta, rows_text, total = await _load_dataset_context(dataset_id, db)

    yield f"data: {json.dumps({'type': 'metadata', 'dataset': meta})}\n\n"
    yield f"data: {json.dumps({'type': 'status', 'message': f'Querying LLM ({mode} mode)...'})}\n\n"

    prompt = _build_prompt(question, meta, rows_text, total)

    # Same routing as query_dataset: "deep" → heavy model on Wile.
    if mode == "deep":
        provider = OllamaProvider(settings.DEFAULT_HEAVY_MODEL, Node.WILE)
        max_tokens = 4096
        model_name = settings.DEFAULT_HEAVY_MODEL
        node_name = "wile"
    else:
        provider = OllamaProvider(settings.DEFAULT_FAST_MODEL, Node.ROADRUNNER)
        max_tokens = 2048
        model_name = settings.DEFAULT_FAST_MODEL
        node_name = "roadrunner"

    # Stream tokens; failures are reported in-band so the stream stays valid.
    token_count = 0
    try:
        async for token in provider.generate_stream(
            prompt,
            system=QUERY_SYSTEM_PROMPT,
            max_tokens=max_tokens,
            temperature=0.3,
        ):
            token_count += 1
            yield f"data: {json.dumps({'type': 'token', 'content': token})}\n\n"
    except Exception as e:
        logger.error(f"Streaming error: {e}")
        yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"

    elapsed_ms = int((time.monotonic() - start) * 1000)
    yield f"data: {json.dumps({'type': 'done', 'tokens': token_count, 'elapsed_ms': elapsed_ms, 'model': model_name, 'node': node_name})}\n\n"
|
||||
|
||||
|
||||
def _build_prompt(question: str, meta: dict, rows_text: str, total: int) -> str:
|
||||
"""Construct the full prompt with data context."""
|
||||
parts = [
|
||||
f"## Dataset: {meta['name']}",
|
||||
f"- Source: {meta.get('source_tool', 'unknown')}",
|
||||
f"- Artifact type: {meta.get('artifact_type', 'unknown')}",
|
||||
f"- Total rows: {total}",
|
||||
f"- Columns: {', '.join(meta.get('columns', []))}",
|
||||
f"- Showing {meta['sample_rows_shown']} sample rows below",
|
||||
"",
|
||||
"## Sample Data",
|
||||
"```",
|
||||
rows_text,
|
||||
"```",
|
||||
"",
|
||||
f"## Analyst Question",
|
||||
question,
|
||||
]
|
||||
return "\n".join(parts)
|
||||
655
backend/app/services/enrichment.py
Normal file
@@ -0,0 +1,655 @@
|
||||
"""IOC Enrichment Engine — VirusTotal, AbuseIPDB, Shodan integrations.
|
||||
|
||||
Provides automated IOC enrichment with caching and rate limiting.
|
||||
Enriches IPs, hashes, domains with threat intelligence verdicts.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db.models import EnrichmentResult as EnrichmentDB
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IOCType(str, Enum):
    """Supported indicator-of-compromise types.

    str-valued so members serialize directly into JSON payloads and the
    EnrichmentResult DB column.
    """
    IP = "ip"
    DOMAIN = "domain"
    HASH_MD5 = "hash_md5"
    HASH_SHA1 = "hash_sha1"
    HASH_SHA256 = "hash_sha256"
    URL = "url"
|
||||
|
||||
|
||||
class Verdict(str, Enum):
    """Normalized threat verdict shared across all providers."""
    CLEAN = "clean"
    SUSPICIOUS = "suspicious"
    MALICIOUS = "malicious"
    UNKNOWN = "unknown"  # provider had no data for the IOC (e.g. HTTP 404)
    ERROR = "error"      # lookup failed: missing API key, HTTP/network error
|
||||
|
||||
|
||||
@dataclass
class EnrichmentResultData:
    """Enrichment result from a provider (one instance per provider call)."""
    ioc_value: str    # raw indicator that was looked up
    ioc_type: IOCType
    source: str       # provider name, e.g. "virustotal"
    verdict: Verdict
    score: float = 0.0  # 0-100 normalized threat score
    raw_data: dict = field(default_factory=dict)  # provider-specific payload excerpt
    tags: list[str] = field(default_factory=list)
    country: str = ""
    asn: str = ""
    org: str = ""        # ISP / AS owner when the provider reports one
    last_seen: str = ""  # provider-reported timestamp; format varies per source
    error: str = ""      # populated only when verdict == Verdict.ERROR
    latency_ms: int = 0  # wall-clock duration of the provider API call
|
||||
|
||||
|
||||
# ── Rate limiter ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class RateLimiter:
    """Serialize calls so at most ``calls_per_minute`` happen per minute.

    Implementation note: this is a minimum-spacing limiter rather than a
    full token bucket — consecutive callers are held until at least
    ``60 / calls_per_minute`` seconds have passed since the previous call.
    Safe for concurrent use: an asyncio lock serializes acquirers.
    """

    def __init__(self, calls_per_minute: int = 4):
        self.calls_per_minute = calls_per_minute
        self.interval = 60.0 / calls_per_minute
        self._last_call: float = 0.0
        self._lock = asyncio.Lock()

    async def acquire(self):
        """Block (asynchronously) until the caller may proceed."""
        async with self._lock:
            wait = self.interval - (time.monotonic() - self._last_call)
            if wait > 0:
                await asyncio.sleep(wait)
            self._last_call = time.monotonic()
|
||||
|
||||
|
||||
# ── Provider base ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class EnrichmentProvider:
    """Base class for enrichment providers.

    Subclasses set ``name``, supply their API key and rate limit via
    ``__init__``, and implement :meth:`enrich`.
    """

    name: str = "base"  # provider identifier, overridden by subclasses

    def __init__(self, api_key: str = "", rate_limit: int = 4):
        self.api_key = api_key
        self.rate_limiter = RateLimiter(rate_limit)
        # HTTP client is created lazily and reused across calls.
        self._client: httpx.AsyncClient | None = None

    def _get_client(self) -> httpx.AsyncClient:
        # Recreate the client if it was never opened or has been closed.
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(connect=10, read=30, write=10, pool=5),
            )
        return self._client

    async def cleanup(self):
        """Close the underlying HTTP client (safe to call repeatedly)."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()

    @property
    def is_configured(self) -> bool:
        # A provider is usable only when an API key was supplied.
        return bool(self.api_key)

    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        """Look up *ioc_value* with this provider.

        Concrete implementations in this module return an
        EnrichmentResultData with verdict=ERROR on failure instead of
        raising; only this abstract base raises.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
# ── VirusTotal ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class VirusTotalProvider(EnrichmentProvider):
    """VirusTotal v3 API provider.

    Supports all IOCType values. Free-tier friendly: rate-limited to
    4 calls/minute.
    """

    name = "virustotal"
    BASE_URL = "https://www.virustotal.com/api/v3"

    def __init__(self):
        super().__init__(api_key=settings.VIRUSTOTAL_API_KEY, rate_limit=4)

    def _headers(self) -> dict:
        # VT v3 authenticates via the x-apikey header.
        return {"x-apikey": self.api_key}

    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        """Query VT and normalize the response; never raises (errors become ERROR verdicts)."""
        if not self.is_configured:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="VirusTotal API key not configured",
            )

        await self.rate_limiter.acquire()
        start = time.monotonic()

        try:
            endpoint = self._get_endpoint(ioc_value, ioc_type)
            if not endpoint:
                return EnrichmentResultData(
                    ioc_value=ioc_value, ioc_type=ioc_type,
                    source=self.name, verdict=Verdict.ERROR,
                    error=f"Unsupported IOC type: {ioc_type}",
                )

            client = self._get_client()
            resp = await client.get(endpoint, headers=self._headers())
            latency_ms = int((time.monotonic() - start) * 1000)

            # 404 means VT has never seen this IOC — not an error.
            if resp.status_code == 404:
                return EnrichmentResultData(
                    ioc_value=ioc_value, ioc_type=ioc_type,
                    source=self.name, verdict=Verdict.UNKNOWN,
                    latency_ms=latency_ms,
                )

            resp.raise_for_status()
            data = resp.json()
            attrs = data.get("data", {}).get("attributes", {})
            stats = attrs.get("last_analysis_stats", {})

            malicious = stats.get("malicious", 0)
            suspicious = stats.get("suspicious", 0)
            total = sum(stats.values()) if stats else 0

            # Determine verdict: >3 engines flagging malicious is treated as
            # confirmed; any malicious hit or >2 suspicious is suspicious.
            if malicious > 3:
                verdict = Verdict.MALICIOUS
            elif malicious > 0 or suspicious > 2:
                verdict = Verdict.SUSPICIOUS
            elif total > 0:
                verdict = Verdict.CLEAN
            else:
                verdict = Verdict.UNKNOWN

            # Score = fraction of engines that flagged malicious, as 0-100.
            score = (malicious / total * 100) if total > 0 else 0

            # NOTE(review): this mutates the list object inside attrs; harmless
            # here since attrs is a throwaway parse, but worth being aware of.
            tags = attrs.get("tags", [])
            if attrs.get("type_description"):
                tags.append(attrs["type_description"])

            return EnrichmentResultData(
                ioc_value=ioc_value,
                ioc_type=ioc_type,
                source=self.name,
                verdict=verdict,
                score=round(score, 1),
                raw_data={
                    "stats": stats,
                    "reputation": attrs.get("reputation", 0),
                    "type_description": attrs.get("type_description", ""),
                    "names": attrs.get("names", [])[:5],
                },
                tags=tags[:10],
                country=attrs.get("country", ""),
                asn=str(attrs.get("asn", "")),
                org=attrs.get("as_owner", ""),
                # NOTE(review): VT reports last_analysis_date as an epoch int;
                # it lands in a str-typed field unconverted — confirm consumers.
                last_seen=attrs.get("last_analysis_date", ""),
                latency_ms=latency_ms,
            )

        except httpx.HTTPStatusError as e:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=f"HTTP {e.response.status_code}",
                latency_ms=int((time.monotonic() - start) * 1000),
            )
        except Exception as e:
            # Network / JSON / unexpected failures all become ERROR results.
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=str(e),
                latency_ms=int((time.monotonic() - start) * 1000),
            )

    def _get_endpoint(self, ioc_value: str, ioc_type: IOCType) -> str | None:
        """Return the VT v3 endpoint for an IOC type, or None if unsupported."""
        if ioc_type == IOCType.IP:
            return f"{self.BASE_URL}/ip_addresses/{ioc_value}"
        elif ioc_type == IOCType.DOMAIN:
            return f"{self.BASE_URL}/domains/{ioc_value}"
        elif ioc_type in (IOCType.HASH_MD5, IOCType.HASH_SHA1, IOCType.HASH_SHA256):
            return f"{self.BASE_URL}/files/{ioc_value}"
        elif ioc_type == IOCType.URL:
            # VT accepts the SHA-256 of the URL as the /urls/{id} identifier.
            url_id = hashlib.sha256(ioc_value.encode()).hexdigest()
            return f"{self.BASE_URL}/urls/{url_id}"
        return None
|
||||
|
||||
|
||||
# ── AbuseIPDB ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class AbuseIPDBProvider(EnrichmentProvider):
    """AbuseIPDB API provider — IP reputation. IP lookups only."""

    name = "abuseipdb"
    BASE_URL = "https://api.abuseipdb.com/api/v2"

    def __init__(self):
        super().__init__(api_key=settings.ABUSEIPDB_API_KEY, rate_limit=10)

    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        """Check an IP against AbuseIPDB; never raises (errors become ERROR verdicts)."""
        if ioc_type != IOCType.IP:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="AbuseIPDB only supports IP lookups",
            )

        if not self.is_configured:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="AbuseIPDB API key not configured",
            )

        await self.rate_limiter.acquire()
        start = time.monotonic()

        try:
            client = self._get_client()
            # verbose=true includes individual reports (needed for tags below).
            resp = await client.get(
                f"{self.BASE_URL}/check",
                params={"ipAddress": ioc_value, "maxAgeInDays": 90, "verbose": "true"},
                headers={"Key": self.api_key, "Accept": "application/json"},
            )
            latency_ms = int((time.monotonic() - start) * 1000)
            resp.raise_for_status()
            data = resp.json().get("data", {})

            abuse_score = data.get("abuseConfidenceScore", 0)
            total_reports = data.get("totalReports", 0)

            # The confidence score is already 0-100, so it maps directly
            # onto the verdict thresholds and the result score.
            if abuse_score >= 75:
                verdict = Verdict.MALICIOUS
            elif abuse_score >= 25 or total_reports > 5:
                verdict = Verdict.SUSPICIOUS
            elif total_reports == 0:
                verdict = Verdict.UNKNOWN
            else:
                verdict = Verdict.CLEAN

            # Translate numeric report categories into human-readable tags,
            # de-duplicated, from at most the first 10 reports.
            categories = data.get("reports", [])
            tags = []
            for report in categories[:10]:
                for cat_id in report.get("categories", []):
                    tag = self._category_name(cat_id)
                    if tag and tag not in tags:
                        tags.append(tag)

            return EnrichmentResultData(
                ioc_value=ioc_value,
                ioc_type=ioc_type,
                source=self.name,
                verdict=verdict,
                score=float(abuse_score),
                raw_data={
                    "abuse_confidence_score": abuse_score,
                    "total_reports": total_reports,
                    "is_whitelisted": data.get("isWhitelisted"),
                    "is_tor": data.get("isTor", False),
                    "usage_type": data.get("usageType", ""),
                    "isp": data.get("isp", ""),
                },
                tags=tags[:10],
                country=data.get("countryCode", ""),
                org=data.get("isp", ""),
                last_seen=data.get("lastReportedAt", ""),
                latency_ms=latency_ms,
            )

        except Exception as e:
            # Network / HTTP / JSON failures all become ERROR results.
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=str(e),
                latency_ms=int((time.monotonic() - start) * 1000),
            )

    @staticmethod
    def _category_name(cat_id: int) -> str:
        """Map an AbuseIPDB numeric report category to its display name ('' if unknown)."""
        categories = {
            1: "DNS Compromise", 2: "DNS Poisoning", 3: "Fraud Orders",
            4: "DDoS Attack", 5: "FTP Brute-Force", 6: "Ping of Death",
            7: "Phishing", 8: "Fraud VoIP", 9: "Open Proxy",
            10: "Web Spam", 11: "Email Spam", 12: "Blog Spam",
            13: "VPN IP", 14: "Port Scan", 15: "Hacking",
            16: "SQL Injection", 17: "Spoofing", 18: "Brute-Force",
            19: "Bad Web Bot", 20: "Exploited Host", 21: "Web App Attack",
            22: "SSH", 23: "IoT Targeted",
        }
        return categories.get(cat_id, "")
|
||||
|
||||
|
||||
# ── Shodan ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ShodanProvider(EnrichmentProvider):
    """Shodan API provider — infrastructure intelligence. IP lookups only."""

    name = "shodan"
    BASE_URL = "https://api.shodan.io"

    def __init__(self):
        # Conservative 1 call/minute to stay within free-plan limits.
        super().__init__(api_key=settings.SHODAN_API_KEY, rate_limit=1)

    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        """Look up host exposure on Shodan; never raises (errors become ERROR verdicts)."""
        if ioc_type != IOCType.IP:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="Shodan only supports IP lookups",
            )

        if not self.is_configured:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="Shodan API key not configured",
            )

        await self.rate_limiter.acquire()
        start = time.monotonic()

        try:
            client = self._get_client()
            # minify=true trims banner data from the response.
            resp = await client.get(
                f"{self.BASE_URL}/shodan/host/{ioc_value}",
                params={"key": self.api_key, "minify": "true"},
            )
            latency_ms = int((time.monotonic() - start) * 1000)

            # 404: Shodan has no record of this host — not an error.
            if resp.status_code == 404:
                return EnrichmentResultData(
                    ioc_value=ioc_value, ioc_type=ioc_type,
                    source=self.name, verdict=Verdict.UNKNOWN,
                    latency_ms=latency_ms,
                )

            resp.raise_for_status()
            data = resp.json()

            ports = data.get("ports", [])
            vulns = data.get("vulns", [])
            tags_raw = data.get("tags", [])

            # Determine verdict based on open ports and vulns: known CVEs
            # score 15 points each (capped at 100); an unusually wide port
            # footprint (>20) is flagged suspicious on its own.
            if vulns:
                verdict = Verdict.SUSPICIOUS
                score = min(len(vulns) * 15, 100.0)
            elif len(ports) > 20:
                verdict = Verdict.SUSPICIOUS
                score = 40.0
            else:
                verdict = Verdict.CLEAN
                score = 0.0

            tags = tags_raw[:10]
            if vulns:
                tags.extend([f"CVE: {v}" for v in vulns[:5]])

            return EnrichmentResultData(
                ioc_value=ioc_value,
                ioc_type=ioc_type,
                source=self.name,
                verdict=verdict,
                score=score,
                raw_data={
                    "ports": ports[:20],
                    "vulns": vulns[:10],
                    "os": data.get("os"),
                    "hostnames": data.get("hostnames", [])[:5],
                    "domains": data.get("domains", [])[:5],
                    "last_update": data.get("last_update", ""),
                },
                tags=tags[:15],
                country=data.get("country_code", ""),
                asn=data.get("asn", ""),
                org=data.get("org", ""),
                last_seen=data.get("last_update", ""),
                latency_ms=latency_ms,
            )

        except Exception as e:
            # Network / HTTP / JSON failures all become ERROR results.
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=str(e),
                latency_ms=int((time.monotonic() - start) * 1000),
            )
|
||||
|
||||
|
||||
# ── Enrichment Engine (orchestrator) ──────────────────────────────────
|
||||
|
||||
|
||||
class EnrichmentEngine:
    """Orchestrates IOC enrichment across all providers with caching.

    Fans each lookup out to every configured provider that supports the
    IOC type, runs the provider calls concurrently, and caches non-error
    verdicts in the database for CACHE_TTL_HOURS.
    """

    CACHE_TTL_HOURS = 24  # freshness window for cached verdicts

    # Providers limited to IP lookups; VirusTotal handles every IOCType.
    _IP_ONLY_PROVIDERS = frozenset({"abuseipdb", "shodan"})

    def __init__(self):
        self.providers: list[EnrichmentProvider] = [
            VirusTotalProvider(),
            AbuseIPDBProvider(),
            ShodanProvider(),
        ]

    @property
    def configured_providers(self) -> list[EnrichmentProvider]:
        """Providers that have an API key configured."""
        return [p for p in self.providers if p.is_configured]

    async def enrich_ioc(
        self,
        ioc_value: str,
        ioc_type: IOCType,
        db: AsyncSession | None = None,
        skip_cache: bool = False,
    ) -> list[EnrichmentResultData]:
        """Enrich a single IOC across all configured providers.

        Uses cached results from DB when available (unless *skip_cache*).
        Returns one result per provider queried; may be empty when no
        configured provider supports *ioc_type*.
        """
        results: list[EnrichmentResultData] = []

        # Check cache first
        if db and not skip_cache:
            cached = await self._get_cached(db, ioc_value, ioc_type)
            if cached:
                logger.info(f"Cache hit for {ioc_type.value}:{ioc_value} ({len(cached)} results)")
                return cached

        # Query all applicable providers in parallel.
        # (Simplified from a redundant per-type if/elif ladder: AbuseIPDB
        # and Shodan accept IPs only, VirusTotal accepts everything.)
        tasks = [
            provider.enrich(ioc_value, ioc_type)
            for provider in self.configured_providers
            if ioc_type == IOCType.IP or provider.name not in self._IP_ONLY_PROVIDERS
        ]

        if tasks:
            results = list(await asyncio.gather(*tasks, return_exceptions=False))

        # Cache results (ERROR verdicts are skipped inside _cache_results)
        if db and results:
            await self._cache_results(db, results)

        return results

    async def enrich_batch(
        self,
        iocs: list[tuple[str, IOCType]],
        db: AsyncSession | None = None,
        concurrency: int = 3,
    ) -> dict[str, list[EnrichmentResultData]]:
        """Enrich a batch of IOCs with controlled concurrency.

        Returns a map of IOC value -> provider results. A task that raises
        is silently dropped from the map (return_exceptions=True).
        """
        sem = asyncio.Semaphore(concurrency)
        all_results: dict[str, list[EnrichmentResultData]] = {}

        async def _enrich_one(value: str, ioc_type: IOCType):
            # Semaphore bounds how many IOCs are in flight simultaneously.
            async with sem:
                result = await self.enrich_ioc(value, ioc_type, db=db)
                all_results[value] = result

        tasks = [_enrich_one(v, t) for v, t in iocs]
        await asyncio.gather(*tasks, return_exceptions=True)
        return all_results

    async def enrich_dataset_iocs(
        self,
        rows: list[dict],
        ioc_columns: dict,
        db: AsyncSession | None = None,
        max_iocs: int = 50,
    ) -> dict[str, list[EnrichmentResultData]]:
        """Auto-enrich IOCs found in a dataset.

        Extracts up to *max_iocs* unique IOC values from the identified
        columns and enriches them as a batch.
        """
        iocs_to_enrich: list[tuple[str, IOCType]] = []
        seen = set()  # de-duplicates values across columns

        for col_name, col_type in ioc_columns.items():
            ioc_type = self._map_column_type(col_type)
            if not ioc_type:
                continue

            for row in rows:
                value = row.get(col_name, "")
                if value and value not in seen:
                    seen.add(value)
                    iocs_to_enrich.append((str(value), ioc_type))

                if len(iocs_to_enrich) >= max_iocs:
                    break

            if len(iocs_to_enrich) >= max_iocs:
                break

        if iocs_to_enrich:
            return await self.enrich_batch(iocs_to_enrich, db=db)
        return {}

    async def _get_cached(
        self,
        db: AsyncSession,
        ioc_value: str,
        ioc_type: IOCType,
    ) -> list[EnrichmentResultData] | None:
        """Return fresh cached enrichment results for an IOC, or None."""
        cutoff = datetime.now(timezone.utc) - timedelta(hours=self.CACHE_TTL_HOURS)
        stmt = (
            select(EnrichmentDB)
            .where(
                EnrichmentDB.ioc_value == ioc_value,
                EnrichmentDB.ioc_type == ioc_type.value,
                EnrichmentDB.cached_at >= cutoff,
            )
        )
        result = await db.execute(stmt)
        cached = result.scalars().all()

        if not cached:
            return None

        # last_seen / error / latency_ms are not persisted, so rehydrated
        # results carry the dataclass defaults for those fields.
        return [
            EnrichmentResultData(
                ioc_value=c.ioc_value,
                ioc_type=IOCType(c.ioc_type),
                source=c.source,
                verdict=Verdict(c.verdict),
                score=c.score or 0.0,
                raw_data=c.raw_data or {},
                tags=c.tags or [],
                country=c.country or "",
                asn=c.asn or "",
                org=c.org or "",
            )
            for c in cached
        ]

    async def _cache_results(
        self,
        db: AsyncSession,
        results: list[EnrichmentResultData],
    ):
        """Persist non-error enrichment results; a flush failure only logs."""
        for r in results:
            if r.verdict == Verdict.ERROR:
                continue  # Don't cache errors
            entry = EnrichmentDB(
                ioc_value=r.ioc_value,
                ioc_type=r.ioc_type.value,
                source=r.source,
                verdict=r.verdict.value,
                score=r.score,
                raw_data=r.raw_data,
                tags=r.tags,
                country=r.country,
                asn=r.asn,
                org=r.org,
            )
            db.add(entry)
        try:
            await db.flush()
        except Exception as e:
            logger.warning(f"Failed to cache enrichment: {e}")

    @staticmethod
    def _map_column_type(col_type: str) -> IOCType | None:
        """Map column type from the normalizer to an IOCType (None = not an IOC)."""
        mapping = {
            "ip": IOCType.IP,
            "ip_address": IOCType.IP,
            "src_ip": IOCType.IP,
            "dst_ip": IOCType.IP,
            "domain": IOCType.DOMAIN,
            "hash_md5": IOCType.HASH_MD5,
            "hash_sha1": IOCType.HASH_SHA1,
            "hash_sha256": IOCType.HASH_SHA256,
            "url": IOCType.URL,
        }
        return mapping.get(col_type)

    async def cleanup(self):
        """Close every provider's HTTP client."""
        for provider in self.providers:
            await provider.cleanup()

    def status(self) -> dict:
        """Return enrichment engine status (per-provider configuration + TTL)."""
        return {
            "providers": {
                p.name: {"configured": p.is_configured}
                for p in self.providers
            },
            "cache_ttl_hours": self.CACHE_TTL_HOURS,
        }
|
||||
|
||||
|
||||
# Module-level singleton: import `enrichment_engine` instead of constructing
# a new engine, so rate limiters and HTTP clients are shared process-wide.
enrichment_engine = EnrichmentEngine()
|
||||
290
backend/app/services/host_inventory.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""Host Inventory Service - builds a deduplicated host-centric network view.
|
||||
|
||||
Scans all datasets in a hunt to identify unique hosts, their IPs, OS,
|
||||
logged-in users, and network connections between them.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Column-name patterns (Velociraptor + generic forensic tools) ---
|
||||
|
||||
# Each pattern is fully anchored and case-insensitive: a column matches a
# bucket only when its entire name equals one of the listed variants.
_HOST_ID_RE = re.compile(
    r'^(client_?id|clientid|agent_?id|endpoint_?id|host_?id|sensor_?id)$', re.I)
_FQDN_RE = re.compile(
    r'^(fqdn|fully_?qualified|computer_?name|hostname|host_?name|host|'
    r'system_?name|machine_?name|nodename|workstation)$', re.I)
_USERNAME_RE = re.compile(
    r'^(user|username|user_?name|logon_?name|account_?name|owner|'
    r'logged_?in_?user|sam_?account_?name|samaccountname)$', re.I)
# "laddr.ip" / "raddr.ip" style names cover dotted netstat-like columns.
_LOCAL_IP_RE = re.compile(
    r'^(laddr\.?ip|laddr|local_?addr(ess)?|src_?ip|source_?ip)$', re.I)
_REMOTE_IP_RE = re.compile(
    r'^(raddr\.?ip|raddr|remote_?addr(ess)?|dst_?ip|dest_?ip)$', re.I)
_REMOTE_PORT_RE = re.compile(
    r'^(raddr\.?port|rport|remote_?port|dst_?port|dest_?port)$', re.I)
_OS_RE = re.compile(
    r'^(os|operating_?system|os_?version|os_?name|platform|os_?type|os_?build)$', re.I)
# Loose IPv4 shape check: four 1-3 digit groups (does not bound octets to 255).
_IP_VALID_RE = re.compile(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$')

# Placeholder / wildcard / loopback values that never identify a real peer.
_IGNORE_IPS = frozenset({
    '0.0.0.0', '::', '::1', '127.0.0.1', '', '-', '*', 'None', 'null',
})
# Windows logon domains owned by the OS rather than a human user.
_SYSTEM_DOMAINS = frozenset({
    'NT AUTHORITY', 'NT SERVICE', 'FONT DRIVER HOST', 'WINDOW MANAGER',
})
# Built-in Windows service/session accounts.
# NOTE(review): appears unused within this module (filtering goes through
# _SYSTEM_USER_RE below) — verify external importers before removing.
_SYSTEM_USERS = frozenset({
    'SYSTEM', 'LOCAL SERVICE', 'NETWORK SERVICE',
    'UMFD-0', 'UMFD-1', 'DWM-1', 'DWM-2', 'DWM-3',
})
|
||||
|
||||
|
||||
def _is_valid_ip(v: str) -> bool:
    """Return True if *v* is a usable dotted-quad IPv4 address.

    Rejects empty/placeholder/loopback values (see _IGNORE_IPS) and — fixing
    the previous regex-only check — any dotted quad with an octet above 255
    (e.g. "999.1.1.1"), which would otherwise pollute the host inventory.
    """
    if not v or v in _IGNORE_IPS:
        return False
    if not _IP_VALID_RE.match(v):
        return False
    # The regex only bounds digit counts; validate the numeric octet range.
    return all(int(octet) <= 255 for octet in v.split('.'))
|
||||
|
||||
|
||||
def _clean(v: Any) -> str:
|
||||
s = str(v or '').strip()
|
||||
return s if s and s not in ('-', 'None', 'null', '') else ''
|
||||
|
||||
|
||||
# Matches built-in Windows service/session accounts (SYSTEM, DWM-1, UMFD-0,
# ...) that carry no analytic value in a logged-in-user inventory.
_SYSTEM_USER_RE = re.compile(
    r'^(SYSTEM|LOCAL SERVICE|NETWORK SERVICE|DWM-\d+|UMFD-\d+)$', re.I)
|
||||
|
||||
|
||||
def _extract_username(raw: str) -> str:
    """Clean username, stripping domain prefixes and filtering system accounts.

    Returns '' for empty input and for built-in Windows service accounts
    (SYSTEM, LOCAL/NETWORK SERVICE, DWM-n, UMFD-n), regardless of the
    DOMAIN\\user prefix the value arrived with.
    """
    if not raw:
        return ''
    name = raw.strip()
    if '\\' in name:
        # Keep only the account part of DOMAIN\user.
        _, _, name = name.rpartition('\\')
        name = name.strip()
    # Simplification: the old extra _SYSTEM_DOMAINS branch was redundant —
    # an empty or system-account name is already rejected by these checks.
    if _SYSTEM_USER_RE.match(name):
        return ''
    return name or ''
|
||||
|
||||
|
||||
def _infer_os(fqdn: str) -> str:
|
||||
u = fqdn.upper()
|
||||
if 'W10-' in u or 'WIN10' in u:
|
||||
return 'Windows 10'
|
||||
if 'W11-' in u or 'WIN11' in u:
|
||||
return 'Windows 11'
|
||||
if 'W7-' in u or 'WIN7' in u:
|
||||
return 'Windows 7'
|
||||
if 'SRV' in u or 'SERVER' in u or 'DC-' in u:
|
||||
return 'Windows Server'
|
||||
if any(k in u for k in ('LINUX', 'UBUNTU', 'CENTOS', 'RHEL', 'DEBIAN')):
|
||||
return 'Linux'
|
||||
if 'MAC' in u or 'DARWIN' in u:
|
||||
return 'macOS'
|
||||
return 'Windows'
|
||||
|
||||
|
||||
def _identify_columns(ds: Dataset) -> dict:
    """Map a dataset's columns to logical host-inventory fields.

    Returns a dict keyed host_id/fqdn/username/local_ip/remote_ip/
    remote_port/os, each value a list of matching raw column names.
    A column may land in more than one bucket; local_ip and remote_ip
    are mutually exclusive for a given column.
    """
    norm = ds.normalized_columns or {}
    schema = ds.column_schema or {}
    # Prefer the raw schema; fall back to the normalizer's column map.
    raw_cols = list(schema.keys()) if schema else list(norm.keys())

    result = {
        'host_id': [], 'fqdn': [], 'username': [],
        'local_ip': [], 'remote_ip': [], 'remote_port': [], 'os': [],
    }

    for col in raw_cols:
        canonical = (norm.get(col) or '').lower()
        lower = col.lower()

        # Host-ID: explicit agent/client-id columns, plus columns the
        # normalizer mapped to 'hostname' despite a non-hostname raw name
        # (presumably tool-specific identifiers — TODO confirm intent).
        if _HOST_ID_RE.match(lower) or (canonical == 'hostname' and lower not in ('hostname', 'host_name', 'host')):
            result['host_id'].append(col)

        if _FQDN_RE.match(lower) or canonical == 'fqdn':
            result['fqdn'].append(col)

        if _USERNAME_RE.match(lower) or canonical in ('username', 'user'):
            result['username'].append(col)

        # Local takes precedence: a column is never both local and remote IP.
        if _LOCAL_IP_RE.match(lower):
            result['local_ip'].append(col)
        elif _REMOTE_IP_RE.match(lower):
            result['remote_ip'].append(col)

        if _REMOTE_PORT_RE.match(lower):
            result['remote_port'].append(col)

        if _OS_RE.match(lower) or canonical == 'os':
            result['os'].append(col)

    return result
|
||||
|
||||
|
||||
async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
    """Build a deduplicated host inventory from all datasets in a hunt.

    Returns dict with 'hosts', 'connections', and 'stats'.
    Each host has: id, hostname, fqdn, client_id, ips, os, users, datasets, row_count.

    Datasets without any recognizable host column are skipped; rows are
    scanned in batches of 5000 to bound memory. Hosts are keyed by FQDN
    when available, otherwise by client/agent id.
    """
    ds_result = await db.execute(
        select(Dataset).where(Dataset.hunt_id == hunt_id)
    )
    all_datasets = ds_result.scalars().all()

    if not all_datasets:
        return {"hosts": [], "connections": [], "stats": {
            "total_hosts": 0, "total_datasets_scanned": 0,
            "total_rows_scanned": 0,
        }}

    hosts: dict[str, dict] = {}  # fqdn -> host record
    ip_to_host: dict[str, str] = {}  # local-ip -> fqdn
    # (source host key, remote ip, remote port) -> observation count
    connections: dict[tuple, int] = defaultdict(int)
    total_rows = 0
    ds_with_hosts = 0

    for ds in all_datasets:
        cols = _identify_columns(ds)
        # Without a host identity column the rows cannot be attributed.
        if not cols['fqdn'] and not cols['host_id']:
            continue
        ds_with_hosts += 1

        # Page through rows to avoid loading whole datasets into memory.
        batch_size = 5000
        offset = 0
        while True:
            rr = await db.execute(
                select(DatasetRow)
                .where(DatasetRow.dataset_id == ds.id)
                .order_by(DatasetRow.row_index)
                .offset(offset).limit(batch_size)
            )
            rows = rr.scalars().all()
            if not rows:
                break

            for ro in rows:
                data = ro.data or {}
                total_rows += 1

                # First non-empty value wins for fqdn and client id.
                fqdn = ''
                for c in cols['fqdn']:
                    fqdn = _clean(data.get(c))
                    if fqdn:
                        break
                client_id = ''
                for c in cols['host_id']:
                    client_id = _clean(data.get(c))
                    if client_id:
                        break

                if not fqdn and not client_id:
                    continue

                # Prefer FQDN as the stable host key; fall back to agent id.
                host_key = fqdn or client_id

                if host_key not in hosts:
                    # Short hostname = first label of the FQDN.
                    short = fqdn.split('.')[0] if fqdn and '.' in fqdn else fqdn
                    hosts[host_key] = {
                        'id': host_key,
                        'hostname': short or client_id,
                        'fqdn': fqdn,
                        'client_id': client_id,
                        'ips': set(),
                        'os': '',
                        'users': set(),
                        'datasets': set(),
                        'row_count': 0,
                    }

                h = hosts[host_key]
                h['datasets'].add(ds.name)
                h['row_count'] += 1
                # Backfill the client id if a later row supplies one.
                if client_id and not h['client_id']:
                    h['client_id'] = client_id

                for c in cols['username']:
                    u = _extract_username(_clean(data.get(c)))
                    if u:
                        h['users'].add(u)

                # Local IPs both enrich the host and feed IP->host resolution
                # used when wiring up connections below.
                for c in cols['local_ip']:
                    ip = _clean(data.get(c))
                    if _is_valid_ip(ip):
                        h['ips'].add(ip)
                        ip_to_host[ip] = host_key

                # First OS value seen wins; _infer_os fills gaps later.
                for c in cols['os']:
                    ov = _clean(data.get(c))
                    if ov and not h['os']:
                        h['os'] = ov

                # Record outbound connections (remote ip + first port found).
                for c in cols['remote_ip']:
                    rip = _clean(data.get(c))
                    if _is_valid_ip(rip):
                        rport = ''
                        for pc in cols['remote_port']:
                            rport = _clean(data.get(pc))
                            if rport:
                                break
                        connections[(host_key, rip, rport)] += 1

            offset += batch_size
            # Short final page means the dataset is exhausted.
            if len(rows) < batch_size:
                break

    # Post-process hosts: infer missing OS from naming conventions and turn
    # the accumulator sets into sorted, JSON-friendly lists.
    for h in hosts.values():
        if not h['os'] and h['fqdn']:
            h['os'] = _infer_os(h['fqdn'])
        h['ips'] = sorted(h['ips'])
        h['users'] = sorted(h['users'])
        h['datasets'] = sorted(h['datasets'])

    # Build connections, resolving IPs to host keys. Self-connections are
    # dropped and each undirected pair is emitted once (first seen wins).
    conn_list = []
    seen = set()
    for (src, dst_ip, dst_port), cnt in connections.items():
        if dst_ip in _IGNORE_IPS:
            continue
        dst_host = ip_to_host.get(dst_ip, '')
        if dst_host == src:
            continue
        key = tuple(sorted([src, dst_host or dst_ip]))
        if key in seen:
            continue
        seen.add(key)
        conn_list.append({
            'source': src,
            'target': dst_host or dst_ip,
            'target_ip': dst_ip,
            'port': dst_port,
            'count': cnt,
        })

    # Most-active hosts first.
    host_list = sorted(hosts.values(), key=lambda x: x['row_count'], reverse=True)

    return {
        "hosts": host_list,
        "connections": conn_list,
        "stats": {
            "total_hosts": len(host_list),
            "total_datasets_scanned": len(all_datasets),
            "datasets_with_hosts": ds_with_hosts,
            "total_rows_scanned": total_rows,
            "hosts_with_ips": sum(1 for h in host_list if h['ips']),
            "hosts_with_users": sum(1 for h in host_list if h['users']),
        },
    }
|
||||
198
backend/app/services/host_profiler.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Host profiler - per-host deep threat analysis via Wile heavy models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config import settings
|
||||
from app.db.engine import async_session
|
||||
from app.db.models import Dataset, DatasetRow, HostProfile, TriageResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Model + endpoint used for deep host analysis: the configured heavy model,
# posted to the Ollama /api/generate endpoint on the "wile" node.
HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
WILE_URL = f"{settings.wile_url}/api/generate"
|
||||
|
||||
|
||||
async def _get_triage_summary(db, dataset_id: str) -> str:
    """Summarize high-risk triage results for one dataset as "- Rows ..." bullet lines.

    Only results with risk_score >= 3.0 are included, highest risk first,
    capped at 10 entries; findings JSON is truncated to 300 characters.
    """
    stmt = (
        select(TriageResult)
        .where(TriageResult.dataset_id == dataset_id)
        .where(TriageResult.risk_score >= 3.0)
        .order_by(TriageResult.risk_score.desc())
        .limit(10)
    )
    triages = (await db.execute(stmt)).scalars().all()

    if not triages:
        return "No significant triage findings."

    return "\n".join(
        f"- Rows {t.row_start}-{t.row_end}: risk={t.risk_score:.1f} "
        f"verdict={t.verdict} findings={json.dumps(t.findings, default=str)[:300]}"
        for t in triages
    )
|
||||
|
||||
|
||||
async def _collect_host_data(db, hunt_id: str, hostname: str, fqdn: str | None = None) -> dict:
    """Collect rows mentioning *hostname* (or *fqdn*) across all datasets of a hunt.

    Scans at most 500 rows per dataset and keeps at most 50 matching rows per
    artifact type. Returns a dict with:
      - "artifacts": {artifact_type: [row dicts]}
      - "triage_summary": concatenated per-dataset triage summaries (str)
      - "artifact_count": total number of kept rows (int)
    """
    result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
    datasets = result.scalars().all()

    host_data: dict[str, list[dict]] = {}
    triage_parts: list[str] = []

    for ds in datasets:
        artifact_type = getattr(ds, "artifact_type", None) or "Unknown"
        rows_result = await db.execute(
            select(DatasetRow).where(DatasetRow.dataset_id == ds.id).limit(500)
        )
        rows = rows_result.scalars().all()

        matching = []
        for r in rows:
            # Prefer the normalized form of the row when present.
            data = r.normalized_data or r.data
            # Host identity may live under several keys depending on the source
            # tool; first non-empty value wins.
            row_host = (
                data.get("hostname", "") or data.get("Fqdn", "")
                or data.get("ClientId", "") or data.get("client_id", "")
            )
            # Substring match, case-insensitive (so short names match FQDNs).
            if hostname.lower() in str(row_host).lower():
                matching.append(data)
            elif fqdn and fqdn.lower() in str(row_host).lower():
                matching.append(data)

        if matching:
            host_data[artifact_type] = matching[:50]  # cap rows sent onward per artifact
            triage_info = await _get_triage_summary(db, ds.id)
            triage_parts.append(f"\n### {artifact_type} ({len(matching)} rows)\n{triage_info}")

    return {
        "artifacts": host_data,
        "triage_summary": "\n".join(triage_parts) or "No triage data.",
        "artifact_count": sum(len(v) for v in host_data.values()),
    }
|
||||
|
||||
|
||||
async def profile_host(
    hunt_id: str, hostname: str, fqdn: str | None = None, client_id: str | None = None,
) -> None:
    """Run a deep LLM threat analysis of one host and persist a HostProfile row.

    Collects the host's artifact rows and triage history, sends a JSON-only
    analysis prompt to the heavy model on Wile, parses the response, and stores
    a HostProfile. On any failure a placeholder profile with risk_level
    "unknown" and the error text is stored instead, so the host is never left
    unrecorded. Never raises; returns None.
    """
    logger.info("Profiling host %s in hunt %s", hostname, hunt_id)

    async with async_session() as db:
        host_data = await _collect_host_data(db, hunt_id, hostname, fqdn)
        if host_data["artifact_count"] == 0:
            # Nothing to analyze — skip rather than store an empty profile.
            logger.info("No data found for host %s, skipping", hostname)
            return

        system_prompt = (
            "You are a senior threat hunting analyst performing deep host analysis.\n"
            "You receive consolidated forensic artifacts and prior triage results for a single host.\n\n"
            "Provide a comprehensive host threat profile as JSON:\n"
            "- risk_score: 0.0 (clean) to 10.0 (actively compromised)\n"
            "- risk_level: low/medium/high/critical\n"
            "- suspicious_findings: list of specific concerns\n"
            "- mitre_techniques: list of MITRE ATT&CK technique IDs\n"
            "- timeline_summary: brief timeline of suspicious activity\n"
            "- analysis: detailed narrative assessment\n\n"
            "Consider: cross-artifact correlation, attack patterns, LOLBins, anomalies.\n"
            "Respond with valid JSON only."
        )

        # Truncate every field value to 150 chars and keep at most 20 rows per
        # artifact type to bound prompt size.
        artifact_summary = {}
        for art_type, rows in host_data["artifacts"].items():
            artifact_summary[art_type] = [
                {k: str(v)[:150] for k, v in row.items() if v} for row in rows[:20]
            ]

        prompt = (
            f"Host: {hostname}\nFQDN: {fqdn or 'unknown'}\n\n"
            f"## Prior Triage Results\n{host_data['triage_summary']}\n\n"
            f"## Artifact Data ({host_data['artifact_count']} total rows)\n"
            f"{json.dumps(artifact_summary, indent=1, default=str)[:8000]}\n\n"
            "Provide your comprehensive host threat profile as JSON."
        )

        try:
            # 5-minute timeout: heavy-model generation can be slow.
            async with httpx.AsyncClient(timeout=300.0) as client:
                resp = await client.post(
                    WILE_URL,
                    json={
                        "model": HEAVY_MODEL,
                        "prompt": prompt,
                        "system": system_prompt,
                        "stream": False,
                        "options": {"temperature": 0.3, "num_predict": 4096},
                    },
                )
                resp.raise_for_status()
                llm_text = resp.json().get("response", "")

            # Local import avoids a circular import with the triage service.
            from app.services.triage import _parse_llm_response
            parsed = _parse_llm_response(llm_text)

            profile = HostProfile(
                hunt_id=hunt_id,
                hostname=hostname,
                fqdn=fqdn,
                client_id=client_id,
                risk_score=float(parsed.get("risk_score", 0.0)),
                risk_level=parsed.get("risk_level", "low"),
                artifact_summary={a: len(r) for a, r in host_data["artifacts"].items()},
                timeline_summary=parsed.get("timeline_summary", ""),
                suspicious_findings=parsed.get("suspicious_findings", []),
                mitre_techniques=parsed.get("mitre_techniques", []),
                # Fall back to the raw (truncated) LLM text if no "analysis" key.
                llm_analysis=parsed.get("analysis", llm_text[:5000]),
                model_used=HEAVY_MODEL,
                node_used="wile",
            )
            db.add(profile)
            await db.commit()
            logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level)

        except Exception as e:
            # Best-effort: record the failure as an "unknown" profile so the
            # host still appears in results.
            logger.error("Failed to profile host %s: %s", hostname, e)
            profile = HostProfile(
                hunt_id=hunt_id, hostname=hostname, fqdn=fqdn,
                risk_score=0.0, risk_level="unknown",
                llm_analysis=f"Error: {e}",
                model_used=HEAVY_MODEL, node_used="wile",
            )
            db.add(profile)
            await db.commit()
|
||||
|
||||
|
||||
async def profile_all_hosts(hunt_id: str) -> None:
    """Discover unique hostnames across a hunt's datasets and profile each one.

    Samples up to 2000 rows per dataset for host identifiers, then runs
    profile_host() for every unique hostname with concurrency bounded by
    settings.HOST_PROFILE_CONCURRENCY. Individual failures are swallowed
    (gather with return_exceptions=True); profile_host records its own errors.
    """
    logger.info("Starting host profiling for hunt %s", hunt_id)

    async with async_session() as db:
        result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
        datasets = result.scalars().all()

        # hostname -> fqdn; first sighting of a hostname wins.
        hostnames: dict[str, str | None] = {}
        for ds in datasets:
            rows_result = await db.execute(
                select(DatasetRow).where(DatasetRow.dataset_id == ds.id).limit(2000)
            )
            for r in rows_result.scalars().all():
                data = r.normalized_data or r.data
                host = data.get("hostname") or data.get("Fqdn") or data.get("Hostname")
                if host and str(host).strip():
                    h = str(host).strip()
                    if h not in hostnames:
                        hostnames[h] = data.get("fqdn") or data.get("Fqdn")

    logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id)

    # Bound how many hosts are profiled at once (each opens its own session
    # and issues a heavy LLM request).
    semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY)

    async def _bounded(hostname: str, fqdn: str | None):
        async with semaphore:
            await profile_host(hunt_id, hostname, fqdn)

    tasks = [_bounded(h, f) for h, f in hostnames.items()]
    await asyncio.gather(*tasks, return_exceptions=True)
    logger.info("Host profiling complete for hunt %s (%d hosts)", hunt_id, len(hostnames))
|
||||
210
backend/app/services/ioc_extractor.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""IOC extraction service extract indicators of compromise from dataset rows.
|
||||
|
||||
Identifies: IPv4/IPv6 addresses, domain names, MD5/SHA1/SHA256 hashes,
|
||||
email addresses, URLs, and file paths that look suspicious.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Patterns
|
||||
|
||||
# Dotted-quad IPv4 with each octet constrained to 0-255.
_IPV4 = re.compile(
    r'\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b'
)
# Full 8-group IPv6 only; "::"-compressed forms are not matched.
_IPV6 = re.compile(r'\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b')
# Domains limited to a fixed TLD allow-list (common TLDs plus ones frequently
# seen in abuse) to reduce false positives from arbitrary dotted tokens.
_DOMAIN = re.compile(
    r'\b(?:[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)'
    r'+(?:com|net|org|io|info|biz|co|us|uk|de|ru|cn|cc|tk|xyz|top|'
    r'online|site|club|win|work|download|stream|gdn|bid|review|racing|'
    r'loan|date|faith|accountant|cricket|science|trade|party|men)\b',
    re.IGNORECASE,
)
# Hashes by hex length (32/40/64). The \b boundaries stop shorter patterns
# from matching inside a longer hex run (an MD5 won't match inside a SHA256).
_MD5 = re.compile(r'\b[0-9a-fA-F]{32}\b')
_SHA1 = re.compile(r'\b[0-9a-fA-F]{40}\b')
_SHA256 = re.compile(r'\b[0-9a-fA-F]{64}\b')
_EMAIL = re.compile(r'\b[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}\b')
_URL = re.compile(r'https?://[^\s<>"\']+', re.IGNORECASE)

# Private / reserved IPs to skip (anchored prefix match: RFC1918 ranges,
# loopback, 0.* and 255.*).
_PRIVATE_NETS = re.compile(
    r'^(10\.|172\.(1[6-9]|2\d|3[01])\.|192\.168\.|127\.|0\.|255\.)'
)

# ioc_type name -> compiled pattern; iterated by extract_iocs_from_text().
PATTERNS = {
    'ipv4': _IPV4,
    'ipv6': _IPV6,
    'domain': _DOMAIN,
    'md5': _MD5,
    'sha1': _SHA1,
    'sha256': _SHA256,
    'email': _EMAIL,
    'url': _URL,
}
|
||||
|
||||
|
||||
def _is_private_ip(ip: str) -> bool:
    """True when *ip* starts with a private/reserved prefix (see _PRIVATE_NETS)."""
    return _PRIVATE_NETS.match(ip) is not None
|
||||
|
||||
|
||||
def extract_iocs_from_text(text: str, skip_private: bool = True) -> dict[str, set[str]]:
    """Extract all IOC types from a block of text.

    Values are lowercased except URLs (which keep their case). When
    *skip_private* is true, private/reserved IPv4 addresses are dropped.
    Returns a defaultdict(set) keyed by IOC type.
    """
    found: dict[str, set[str]] = defaultdict(set)
    for ioc_type, pattern in PATTERNS.items():
        keep_case = ioc_type == 'url'
        for raw in pattern.findall(text):
            candidate = raw.strip() if keep_case else raw.strip().lower()
            if ioc_type == 'ipv4' and skip_private and _is_private_ip(candidate):
                continue  # internal addresses are not reportable IOCs
            found[ioc_type].add(candidate)
    return found
|
||||
|
||||
|
||||
async def extract_iocs_from_dataset(
    dataset_id: str,
    db: AsyncSession,
    max_rows: int = 5000,
    skip_private: bool = True,
) -> dict[str, list[str]]:
    """Extract IOCs from up to *max_rows* rows of a dataset.

    Rows are read in batches of 500 ordered by row_index; each row's values
    are flattened to a single string and scanned with every PATTERNS regex.

    Args:
        dataset_id: Dataset whose rows are scanned.
        db: Async SQLAlchemy session.
        max_rows: Hard cap on rows read from the database.
        skip_private: Drop private/reserved IPv4 addresses when true.

    Returns:
        {ioc_type: [sorted unique values]} — types with no hits are omitted.
    """
    all_iocs: dict[str, set[str]] = defaultdict(set)
    offset = 0
    batch_size = 500

    while offset < max_rows:
        result = await db.execute(
            select(DatasetRow.data)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .offset(offset)
            # Truncate the final batch so max_rows is honoured exactly; a
            # fixed limit(batch_size) could overshoot by up to batch_size - 1
            # rows when max_rows is not a multiple of batch_size.
            .limit(min(batch_size, max_rows - offset))
        )
        rows = result.scalars().all()
        if not rows:
            break

        for data in rows:
            # Flatten all values to a single string for scanning.
            text = ' '.join(str(v) for v in data.values()) if isinstance(data, dict) else str(data)
            batch_iocs = extract_iocs_from_text(text, skip_private)
            for ioc_type, values in batch_iocs.items():
                all_iocs[ioc_type].update(values)

        offset += batch_size

    # Convert sets to sorted lists.
    return {k: sorted(v) for k, v in all_iocs.items() if v}
|
||||
|
||||
|
||||
async def extract_host_groups(
    hunt_id: str,
    db: AsyncSession,
) -> list[dict]:
    """Group all data by hostname across datasets in a hunt.

    Returns a list of host group dicts with dataset count, total rows,
    artifact types, and time range, sorted by total_rows descending.
    first_seen/last_seen/risk_score are always None for now (see TODOs).
    """
    # Get all datasets for this hunt
    result = await db.execute(
        select(Dataset).where(Dataset.hunt_id == hunt_id)
    )
    ds_list = result.scalars().all()
    if not ds_list:
        return []

    # Known host columns (check normalized data first, then raw)
    HOST_COLS = [
        'hostname', 'host', 'computer_name', 'computername', 'system',
        'machine', 'device_name', 'devicename', 'endpoint',
        'ClientId', 'Fqdn', 'client_id', 'fqdn',
    ]

    hosts: dict[str, dict] = {}

    for ds in ds_list:
        # Sample first few rows to find host column
        sample_result = await db.execute(
            select(DatasetRow.data, DatasetRow.normalized_data)
            .where(DatasetRow.dataset_id == ds.id)
            .limit(5)
        )
        samples = sample_result.all()
        if not samples:
            continue

        # Find which host column exists — first HOST_COLS entry with a
        # non-empty value in any sampled row wins for the whole dataset.
        host_col = None
        for row_data, norm_data in samples:
            check = norm_data if norm_data else row_data
            if not isinstance(check, dict):
                continue
            for col in HOST_COLS:
                if col in check and check[col]:
                    host_col = col
                    break
            if host_col:
                break

        if not host_col:
            continue  # no recognizable host column in this dataset

        # Count rows per host in this dataset.
        # NOTE(review): this loads every row of the dataset into memory at
        # once (no limit/batching) — may be heavy for very large datasets.
        all_rows_result = await db.execute(
            select(DatasetRow.data, DatasetRow.normalized_data)
            .where(DatasetRow.dataset_id == ds.id)
        )
        all_rows = all_rows_result.all()
        for row_data, norm_data in all_rows:
            check = norm_data if norm_data else row_data
            if not isinstance(check, dict):
                continue
            host_val = check.get(host_col, '')
            # Only non-empty string host values are grouped.
            if not host_val or not isinstance(host_val, str):
                continue
            host_val = host_val.strip()
            if not host_val:
                continue

            if host_val not in hosts:
                hosts[host_val] = {
                    'hostname': host_val,
                    'dataset_ids': set(),
                    'total_rows': 0,
                    'artifact_types': set(),
                    'first_seen': None,
                    'last_seen': None,
                }
            hosts[host_val]['dataset_ids'].add(ds.id)
            hosts[host_val]['total_rows'] += 1
            if ds.artifact_type:
                hosts[host_val]['artifact_types'].add(ds.artifact_type)

    # Convert to output format (sets become counts/sorted lists)
    result_list = []
    for h in sorted(hosts.values(), key=lambda x: x['total_rows'], reverse=True):
        result_list.append({
            'hostname': h['hostname'],
            'dataset_count': len(h['dataset_ids']),
            'total_rows': h['total_rows'],
            'artifact_types': sorted(h['artifact_types']),
            'first_seen': None,  # TODO: extract from timestamp columns
            'last_seen': None,
            'risk_score': None,  # TODO: link to host profiles
        })

    return result_list
|
||||
316
backend/app/services/job_queue.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""Async job queue for background AI tasks.
|
||||
|
||||
Manages triage, profiling, report generation, anomaly detection,
|
||||
and data queries as trackable jobs with status, progress, and
|
||||
cancellation support.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Coroutine, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
    """Lifecycle states of a background job (str-valued for JSON responses)."""

    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class JobType(str, Enum):
    """Kinds of background work the queue knows how to dispatch."""

    TRIAGE = "triage"
    HOST_PROFILE = "host_profile"
    REPORT = "report"
    ANOMALY = "anomaly"
    QUERY = "query"
|
||||
|
||||
|
||||
@dataclass
class Job:
    """One tracked background task: identity, state, timing, and result."""

    id: str
    job_type: JobType
    status: JobStatus = JobStatus.QUEUED
    progress: float = 0.0  # 0-100
    message: str = ""
    # Handler return value, stored on success; shape depends on the handler.
    result: Any = None
    error: str | None = None
    created_at: float = field(default_factory=time.time)
    started_at: float | None = None
    completed_at: float | None = None
    # Keyword arguments the handler reads (dataset_id, hunt_id, ...).
    params: dict = field(default_factory=dict)
    # Cooperative-cancellation flag; handlers poll is_cancelled.
    _cancel_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)

    @property
    def elapsed_ms(self) -> int:
        # Still-running jobs measure against "now"; queued jobs against creation.
        end = self.completed_at or time.time()
        start = self.started_at or self.created_at
        return int((end - start) * 1000)

    def to_dict(self) -> dict:
        """JSON-safe view of the job for API responses (result excluded)."""
        return {
            "id": self.id,
            "job_type": self.job_type.value,
            "status": self.status.value,
            "progress": round(self.progress, 1),
            "message": self.message,
            "error": self.error,
            "created_at": self.created_at,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "elapsed_ms": self.elapsed_ms,
            "params": self.params,
        }

    @property
    def is_cancelled(self) -> bool:
        return self._cancel_event.is_set()

    def cancel(self):
        """Mark the job cancelled; running handlers see is_cancelled become true."""
        self._cancel_event.set()
        self.status = JobStatus.CANCELLED
        self.completed_at = time.time()
        self.message = "Cancelled by user"
|
||||
|
||||
|
||||
class JobQueue:
    """In-memory async job queue with concurrency control.

    Jobs are tracked by ID and can be listed, polled, or cancelled.
    A configurable number of workers process jobs from the queue.
    State is process-local: jobs are lost on restart.
    """

    def __init__(self, max_workers: int = 3):
        self._jobs: dict[str, Job] = {}          # id -> Job, all states
        self._queue: asyncio.Queue[str] = asyncio.Queue()  # ids awaiting a worker
        self._max_workers = max_workers
        self._workers: list[asyncio.Task] = []
        self._handlers: dict[JobType, Callable] = {}
        self._started = False

    def register_handler(
        self,
        job_type: JobType,
        handler: Callable[[Job], Coroutine],
    ):
        """Register an async handler for a job type.

        Handler signature: async def handler(job: Job) -> Any
        The handler can update job.progress and job.message during execution.
        It should check job.is_cancelled periodically and return early.
        """
        self._handlers[job_type] = handler
        logger.info(f"Registered handler for {job_type.value}")

    async def start(self):
        """Start worker tasks. Idempotent: a second call is a no-op."""
        if self._started:
            return
        self._started = True
        for i in range(self._max_workers):
            task = asyncio.create_task(self._worker(i))
            self._workers.append(task)
        logger.info(f"Job queue started with {self._max_workers} workers")

    async def stop(self):
        """Stop all workers (cancels their tasks and waits for them to exit)."""
        self._started = False
        for w in self._workers:
            w.cancel()
        await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        logger.info("Job queue stopped")

    def submit(self, job_type: JobType, **params) -> Job:
        """Submit a new job. Returns the Job object immediately."""
        job = Job(
            id=str(uuid.uuid4()),
            job_type=job_type,
            params=params,
        )
        self._jobs[job.id] = job
        self._queue.put_nowait(job.id)
        logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
        return job

    def get_job(self, job_id: str) -> Job | None:
        """Look up a job by ID; None if unknown (e.g. cleaned up)."""
        return self._jobs.get(job_id)

    def cancel_job(self, job_id: str) -> bool:
        """Request cancellation. False if the job is unknown or already finished."""
        job = self._jobs.get(job_id)
        if not job:
            return False
        if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
            return False
        job.cancel()
        return True

    def list_jobs(
        self,
        status: JobStatus | None = None,
        job_type: JobType | None = None,
        limit: int = 50,
    ) -> list[dict]:
        """List jobs, newest first, optionally filtered by status and/or type."""
        jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
        if status:
            jobs = [j for j in jobs if j.status == status]
        if job_type:
            jobs = [j for j in jobs if j.job_type == job_type]
        return [j.to_dict() for j in jobs[:limit]]

    def get_stats(self) -> dict:
        """Get queue statistics (totals, per-status counts, worker info)."""
        by_status = {}
        for j in self._jobs.values():
            by_status[j.status.value] = by_status.get(j.status.value, 0) + 1
        return {
            "total": len(self._jobs),
            "queued": self._queue.qsize(),
            "by_status": by_status,
            "workers": self._max_workers,
            "active_workers": sum(
                1 for j in self._jobs.values() if j.status == JobStatus.RUNNING
            ),
        }

    def cleanup(self, max_age_seconds: float = 3600):
        """Remove old completed/failed/cancelled jobs (age from created_at)."""
        now = time.time()
        to_remove = [
            jid for jid, j in self._jobs.items()
            if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
            and (now - j.created_at) > max_age_seconds
        ]
        for jid in to_remove:
            del self._jobs[jid]
        if to_remove:
            logger.info(f"Cleaned up {len(to_remove)} old jobs")

    async def _worker(self, worker_id: int):
        """Worker loop: pull jobs from queue and execute handlers."""
        logger.info(f"Worker {worker_id} started")
        while self._started:
            try:
                # 5s timeout so the loop re-checks self._started regularly.
                job_id = await asyncio.wait_for(self._queue.get(), timeout=5.0)
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break

            job = self._jobs.get(job_id)
            # Skip jobs removed by cleanup() or cancelled while still queued.
            if not job or job.is_cancelled:
                continue

            handler = self._handlers.get(job.job_type)
            if not handler:
                job.status = JobStatus.FAILED
                job.error = f"No handler for {job.job_type.value}"
                job.completed_at = time.time()
                logger.error(f"No handler for job type {job.job_type.value}")
                continue

            job.status = JobStatus.RUNNING
            job.started_at = time.time()
            job.message = "Running..."
            logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")

            try:
                result = await handler(job)
                # A job cancelled mid-run keeps its CANCELLED state/timestamps.
                if not job.is_cancelled:
                    job.status = JobStatus.COMPLETED
                    job.progress = 100.0
                    job.result = result
                    job.message = "Completed"
                    job.completed_at = time.time()
                    logger.info(
                        f"Worker {worker_id}: completed {job.id} "
                        f"in {job.elapsed_ms}ms"
                    )
            except Exception as e:
                if not job.is_cancelled:
                    job.status = JobStatus.FAILED
                    job.error = str(e)
                    job.message = f"Failed: {e}"
                    job.completed_at = time.time()
                    logger.error(
                        f"Worker {worker_id}: failed {job.id}: {e}",
                        exc_info=True,
                    )
|
||||
|
||||
|
||||
# Singleton + job handlers

# Process-wide queue instance; pair with register_all_handlers() and start().
job_queue = JobQueue(max_workers=3)
|
||||
|
||||
|
||||
async def _handle_triage(job: Job):
    """Triage handler: run triage over one dataset, return the result count."""
    from app.services.triage import triage_dataset

    ds_id = job.params.get("dataset_id")
    job.message = f"Triaging dataset {ds_id}"
    triaged = await triage_dataset(ds_id)
    return {"count": len(triaged) if triaged else 0}
|
||||
|
||||
|
||||
async def _handle_host_profile(job: Job):
    """Host profiling handler: one host when given, else the whole hunt."""
    from app.services.host_profiler import profile_all_hosts, profile_host

    hunt_id = job.params.get("hunt_id")
    target = job.params.get("hostname")

    if not target:
        job.message = f"Profiling all hosts in hunt {hunt_id}"
        await profile_all_hosts(hunt_id)
        return {"hunt_id": hunt_id}

    job.message = f"Profiling host {target}"
    await profile_host(hunt_id, target)
    return {"hostname": target}
|
||||
|
||||
|
||||
async def _handle_report(job: Job):
    """Report generation handler: build a hunt report, return its id."""
    from app.services.report_generator import generate_report

    hunt_id = job.params.get("hunt_id")
    job.message = f"Generating report for hunt {hunt_id}"
    generated = await generate_report(hunt_id)
    return {"report_id": generated.id if generated else None}
|
||||
|
||||
|
||||
async def _handle_anomaly(job: Job):
    """Anomaly detection handler: scan one dataset, return the hit count."""
    from app.services.anomaly_detector import detect_anomalies

    dataset_id = job.params.get("dataset_id")
    job.message = f"Detecting anomalies in dataset {dataset_id}"
    found = await detect_anomalies(
        dataset_id,
        k=job.params.get("k", 3),
        outlier_threshold=job.params.get("threshold", 0.35),
    )
    return {"count": len(found) if found else 0}
|
||||
|
||||
|
||||
async def _handle_query(job: Job):
    """Data query handler (non-streaming)."""
    from app.services.data_query import query_dataset

    dataset_id = job.params.get("dataset_id")
    job.message = f"Querying dataset {dataset_id}"
    answer = await query_dataset(
        dataset_id,
        job.params.get("question", ""),
        job.params.get("mode", "quick"),
    )
    return {"answer": answer}
|
||||
|
||||
|
||||
def register_all_handlers():
    """Register all job handlers on the module-level job_queue singleton."""
    handler_map = {
        JobType.TRIAGE: _handle_triage,
        JobType.HOST_PROFILE: _handle_host_profile,
        JobType.REPORT: _handle_report,
        JobType.ANOMALY: _handle_anomaly,
        JobType.QUERY: _handle_query,
    }
    for kind, fn in handler_map.items():
        job_queue.register_handler(kind, fn)
|
||||
145
backend/app/services/keyword_defaults.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Default AUP keyword themes and their seed keywords.
|
||||
|
||||
Called once on startup — only inserts themes that don't already exist,
|
||||
so user edits are never overwritten.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import KeywordTheme, Keyword
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Default themes + keywords ─────────────────────────────────────────
|
||||
|
||||
# Theme name -> {"color": hex color string, "keywords": seed keyword list}.
# Consumed by seed_defaults(), which only inserts themes whose name is not
# already in the DB — so analyst edits to existing themes are never clobbered.
DEFAULTS: dict[str, dict] = {
    "Gambling": {
        "color": "#f44336",
        "keywords": [
            "poker", "casino", "blackjack", "roulette", "sportsbook",
            "sports betting", "bet365", "draftkings", "fanduel", "bovada",
            "betonline", "mybookie", "slots", "slot machine", "parlay",
            "wager", "bookie", "betway", "888casino", "pokerstars",
            "william hill", "ladbrokes", "betfair", "unibet", "pinnacle",
        ],
    },
    "Gaming": {
        "color": "#9c27b0",
        "keywords": [
            "steam", "steamcommunity", "steampowered", "epic games",
            "epicgames", "origin.com", "battle.net", "blizzard",
            "roblox", "minecraft", "fortnite", "valorant", "league of legends",
            "twitch", "twitch.tv", "discord", "discord.gg", "xbox live",
            "playstation network", "gog.com", "itch.io", "gamepass",
            "riot games", "ubisoft", "ea.com",
        ],
    },
    "Streaming": {
        "color": "#ff9800",
        "keywords": [
            "netflix", "hulu", "disney+", "disneyplus", "hbomax",
            "amazon prime video", "peacock", "paramount+", "crunchyroll",
            "funimation", "spotify", "pandora", "soundcloud", "deezer",
            "tidal", "apple music", "youtube music", "pluto tv",
            "tubi", "vudu", "plex",
        ],
    },
    "Downloads / Piracy": {
        "color": "#ff5722",
        "keywords": [
            "torrent", "bittorrent", "utorrent", "qbittorrent", "piratebay",
            "thepiratebay", "1337x", "rarbg", "yts", "kickass",
            "limewire", "frostwire", "mega.nz", "rapidshare", "mediafire",
            "zippyshare", "uploadhaven", "fitgirl", "repack", "crack",
            "keygen", "warez", "nulled", "pirate", "magnet:",
        ],
    },
    "Adult Content": {
        "color": "#e91e63",
        "keywords": [
            "pornhub", "xvideos", "xhamster", "onlyfans", "chaturbate",
            "livejasmin", "brazzers", "redtube", "youporn", "xnxx",
            "porn", "xxx", "nsfw", "adult content", "cam site",
            "stripchat", "bongacams",
        ],
    },
    "Social Media": {
        "color": "#2196f3",
        "keywords": [
            "facebook", "instagram", "tiktok", "snapchat", "pinterest",
            "reddit", "tumblr", "myspace", "whatsapp web", "telegram web",
            "signal web", "wechat web", "twitter.com", "x.com",
            "threads.net", "mastodon", "bluesky",
        ],
    },
    "Job Search": {
        "color": "#4caf50",
        "keywords": [
            "indeed", "linkedin jobs", "glassdoor", "monster.com",
            "ziprecruiter", "careerbuilder", "dice.com", "hired.com",
            "angel.co", "wellfound", "levels.fyi", "salary.com",
            "payscale", "resume", "cover letter", "job application",
        ],
    },
    "Shopping": {
        "color": "#00bcd4",
        "keywords": [
            "amazon.com", "ebay", "etsy", "walmart.com", "target.com",
            "bestbuy", "aliexpress", "wish.com", "shein", "temu",
            "wayfair", "overstock", "newegg", "zappos", "coupon",
            "promo code", "add to cart",
        ],
    },
}
|
||||
|
||||
|
||||
async def seed_defaults(db: AsyncSession) -> int:
    """Insert default themes + keywords for any theme name not already in DB.

    Also migrates legacy theme names first, so a renamed theme is not
    re-seeded under its new name. Returns the number of themes inserted
    (0 if all already exist).
    """
    # Rename legacy theme names (each rename committed immediately).
    _renames = [("Social Media (Personal)", "Social Media")]
    for old_name, new_name in _renames:
        old = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == old_name))
        if old:
            await db.execute(
                KeywordTheme.__table__.update()
                .where(KeywordTheme.name == old_name)
                .values(name=new_name)
            )
            await db.commit()
            logger.info("Renamed AUP theme '%s' → '%s'", old_name, new_name)

    inserted = 0
    for theme_name, meta in DEFAULTS.items():
        # Existence check by name only — user-modified themes are untouched.
        exists = await db.scalar(
            select(KeywordTheme.id).where(KeywordTheme.name == theme_name)
        )
        if exists:
            continue

        theme = KeywordTheme(
            name=theme_name,
            color=meta["color"],
            enabled=True,
            is_builtin=True,
        )
        db.add(theme)
        await db.flush()  # get theme.id

        for kw in meta["keywords"]:
            db.add(Keyword(theme_id=theme.id, value=kw))

        inserted += 1
        logger.info("Seeded AUP theme '%s' with %d keywords", theme_name, len(meta["keywords"]))

    # Single commit for all newly seeded themes.
    if inserted:
        await db.commit()
        logger.info("Seeded %d AUP keyword themes", inserted)
    else:
        logger.debug("All default AUP themes already present")

    return inserted
|
||||
322
backend/app/services/llm_analysis.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""LLM-powered dataset analysis — replaces manual IOC enrichment.
|
||||
|
||||
Loads dataset rows server-side, builds a concise summary, and sends it
|
||||
to Wile (70B heavy) or Roadrunner (fast) for threat analysis.
|
||||
Supports both single-dataset and hunt-wide analysis.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.config import settings
|
||||
from app.agents.providers_v2 import OllamaProvider
|
||||
from app.agents.router import TaskType, task_router
|
||||
from app.services.sans_rag import sans_rag
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
|
||||
|
||||
class AnalysisRequest(BaseModel):
    """Request for LLM-powered analysis of a dataset.

    Either ``dataset_id`` or ``hunt_id`` presumably scopes the analysis
    (single dataset vs. hunt-wide) -- not enforced here; TODO confirm
    against the API layer that loads the rows.
    """

    # Target dataset for single-dataset analysis.
    dataset_id: Optional[str] = None
    # Target hunt for hunt-wide analysis.
    hunt_id: Optional[str] = None
    # Analyst question; defaults to a broad "analyze everything" request.
    question: str = Field(
        default="Perform a comprehensive threat analysis of this dataset. "
        "Identify anomalies, suspicious patterns, potential IOCs, and recommend "
        "next steps for the analyst.",
        description="Specific question or general analysis request",
    )
    # "deep" routes to TaskType.DEEP_ANALYSIS, anything else to QUICK_CHAT
    # (see run_llm_analysis in this module).
    mode: str = Field(default="deep", description="quick | deep")
    # Optional key into FOCUS_PROMPTS; None or an unknown key adds no focus text.
    focus: Optional[str] = Field(
        None,
        description="Focus area: threats, anomalies, lateral_movement, exfil, persistence, recon",
    )
|
||||
|
||||
|
||||
class AnalysisResult(BaseModel):
    """LLM analysis result.

    Content fields are parsed from the model's JSON reply by
    ``_parse_analysis``; the routing/metrics fields (``model_used`` through
    ``dataset_summary``) are filled in afterwards by ``run_llm_analysis``.
    """

    # Full analysis text (markdown); falls back to the raw LLM reply when
    # the response is not parseable JSON.
    analysis: str = Field(..., description="Full analysis text (markdown)")
    # Model-reported confidence 0-1 (0.5 when falling back to plain text).
    confidence: float = Field(default=0.0, description="0-1 confidence")
    key_findings: list[str] = Field(default_factory=list)
    # Each IOC dict is shaped like {"type": ..., "value": ..., "context": ...}
    # per the JSON schema in ANALYSIS_SYSTEM.
    iocs_identified: list[dict] = Field(default_factory=list)
    recommended_actions: list[str] = Field(default_factory=list)
    # MITRE ATT&CK technique strings, e.g. "T1059.001 - PowerShell".
    mitre_techniques: list[str] = Field(default_factory=list)
    risk_score: int = Field(default=0, description="0-100 risk score")
    # NOTE(review): "model_"-prefixed field names trigger pydantic v2
    # protected-namespace warnings -- confirm this is silenced via
    # model_config elsewhere, or accept the warning.
    model_used: str = ""
    node_used: str = ""
    # Wall-clock latency of the whole analysis call, in milliseconds.
    latency_ms: int = 0
    rows_analyzed: int = 0
    # The textual summary actually sent to the LLM (from summarize_dataset_rows).
    dataset_summary: str = ""
|
||||
|
||||
|
||||
# ── Analysis prompts ──────────────────────────────────────────────────
|
||||
|
||||
# System prompt: instructs the model to emit a single JSON object whose keys
# mirror AnalysisResult's content fields (parsed by _parse_analysis).
ANALYSIS_SYSTEM = """You are an expert threat hunter and incident response analyst.
You are analyzing CSV log data from forensic tools (Velociraptor, Sysmon, etc.).

Your task: Perform deep threat analysis of the data provided and produce actionable findings.

RESPOND WITH VALID JSON ONLY:
{
  "analysis": "Detailed markdown analysis with headers and bullet points",
  "confidence": 0.85,
  "key_findings": ["Finding 1", "Finding 2"],
  "iocs_identified": [{"type": "ip", "value": "1.2.3.4", "context": "C2 traffic"}],
  "recommended_actions": ["Action 1", "Action 2"],
  "mitre_techniques": ["T1059.001 - PowerShell", "T1071 - Application Layer Protocol"],
  "risk_score": 65
}
"""

# Extra instruction appended to the user prompt when AnalysisRequest.focus
# matches one of these keys (see run_llm_analysis).
FOCUS_PROMPTS = {
    "threats": "Focus on identifying active threats, malware indicators, and attack patterns.",
    "anomalies": "Focus on statistical anomalies, outliers, and unusual behavior patterns.",
    "lateral_movement": "Focus on evidence of lateral movement: PsExec, WMI, RDP, SMB, pass-the-hash.",
    "exfil": "Focus on data exfiltration indicators: large transfers, DNS tunneling, unusual destinations.",
    "persistence": "Focus on persistence mechanisms: scheduled tasks, services, registry, startup items.",
    "recon": "Focus on reconnaissance activity: scanning, enumeration, discovery commands.",
}
|
||||
|
||||
|
||||
# ── Data summarizer ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def summarize_dataset_rows(
|
||||
rows: list[dict],
|
||||
columns: list[str] | None = None,
|
||||
max_sample: int = 20,
|
||||
max_chars: int = 6000,
|
||||
) -> str:
|
||||
"""Build a concise text summary of dataset rows for LLM consumption.
|
||||
|
||||
Includes:
|
||||
- Column headers and types
|
||||
- Statistical summary (unique values, top values per column)
|
||||
- Sample rows (first N)
|
||||
- Detected patterns of interest
|
||||
"""
|
||||
if not rows:
|
||||
return "Empty dataset — no rows to analyze."
|
||||
|
||||
cols = columns or list(rows[0].keys())
|
||||
n_rows = len(rows)
|
||||
|
||||
parts: list[str] = []
|
||||
parts.append(f"## Dataset Summary: {n_rows} rows, {len(cols)} columns")
|
||||
parts.append(f"Columns: {', '.join(cols)}")
|
||||
|
||||
# Per-column stats
|
||||
parts.append("\n### Column Statistics:")
|
||||
for col in cols[:30]: # limit to first 30 cols
|
||||
values = [str(r.get(col, "")) for r in rows if r.get(col) not in (None, "", "N/A")]
|
||||
if not values:
|
||||
continue
|
||||
unique = len(set(values))
|
||||
counter = Counter(values)
|
||||
top3 = counter.most_common(3)
|
||||
top_str = ", ".join(f"{v} ({c}x)" for v, c in top3)
|
||||
parts.append(f"- **{col}**: {len(values)} non-null, {unique} unique. Top: {top_str}")
|
||||
|
||||
# Sample rows
|
||||
sample = rows[:max_sample]
|
||||
parts.append(f"\n### Sample Rows (first {len(sample)}):")
|
||||
for i, row in enumerate(sample):
|
||||
row_str = " | ".join(f"{k}={v}" for k, v in row.items() if v not in (None, "", "N/A"))
|
||||
parts.append(f"{i+1}. {row_str}")
|
||||
|
||||
# Detect interesting patterns
|
||||
patterns: list[str] = []
|
||||
all_cmds = [str(r.get("command_line", "")).lower() for r in rows if r.get("command_line")]
|
||||
sus_cmds = [c for c in all_cmds if any(
|
||||
k in c for k in ["powershell -enc", "certutil", "bitsadmin", "mshta",
|
||||
"regsvr32", "invoke-", "mimikatz", "psexec"]
|
||||
)]
|
||||
if sus_cmds:
|
||||
patterns.append(f"⚠️ {len(sus_cmds)} suspicious command lines detected")
|
||||
|
||||
all_ips = [str(r.get("dst_ip", "")) for r in rows if r.get("dst_ip")]
|
||||
ext_ips = [ip for ip in all_ips if ip and not ip.startswith(("10.", "192.168.", "172.", "127."))]
|
||||
if ext_ips:
|
||||
unique_ext = len(set(ext_ips))
|
||||
patterns.append(f"🌐 {unique_ext} unique external destination IPs")
|
||||
|
||||
if patterns:
|
||||
parts.append("\n### Detected Patterns:")
|
||||
for p in patterns:
|
||||
parts.append(f"- {p}")
|
||||
|
||||
text = "\n".join(parts)
|
||||
if len(text) > max_chars:
|
||||
text = text[:max_chars] + "\n... (truncated)"
|
||||
return text
|
||||
|
||||
|
||||
# ── LLM analysis engine ──────────────────────────────────────────────
|
||||
|
||||
|
||||
async def run_llm_analysis(
    rows: list[dict],
    request: AnalysisRequest,
    dataset_name: str = "unknown",
) -> AnalysisResult:
    """Run LLM analysis on dataset rows.

    Summarizes *rows*, routes to a model via ``task_router`` ("deep" mode →
    DEEP_ANALYSIS, otherwise QUICK_CHAT), optionally enriches the prompt with
    SANS RAG context, then calls the provider with a 5-minute hard timeout.

    Args:
        rows: Dataset rows to analyze (dicts, e.g. parsed CSV records).
        request: Analysis parameters (question, mode, optional focus).
        dataset_name: Human-readable dataset label embedded in the prompt.

    Returns:
        AnalysisResult — parsed model output on success, or an error-shaped
        result (never raises) on timeout / provider failure.
    """
    start = time.monotonic()

    # Build a bounded textual summary — the LLM never sees raw rows.
    summary = summarize_dataset_rows(rows)

    # Route to the appropriate model/node
    task_type = TaskType.DEEP_ANALYSIS if request.mode == "deep" else TaskType.QUICK_CHAT
    decision = task_router.route(task_type)

    # Build prompt
    focus_text = FOCUS_PROMPTS.get(request.focus or "", "")
    prompt = f"""Analyze the following forensic dataset from '{dataset_name}'.

{focus_text}

Analyst question: {request.question}

{summary}
"""

    # Enrich with SANS RAG — best-effort: analysis proceeds without it on failure.
    try:
        rag_context = await sans_rag.enrich_prompt(
            request.question,
            investigation_context=f"Analyzing {len(rows)} rows from {dataset_name}",
        )
        if rag_context:
            prompt = f"{prompt}\n\n{rag_context}"
    except Exception as e:
        logger.warning("SANS RAG enrichment failed: %s", e)

    provider = task_router.get_provider(decision)

    def _error_result(message: str) -> AnalysisResult:
        # Shared shape for the timeout and failure return paths.
        return AnalysisResult(
            analysis=message,
            model_used=decision.model,
            node_used=decision.node,
            latency_ms=int((time.monotonic() - start) * 1000),
            rows_analyzed=len(rows),
            dataset_summary=summary,
        )

    try:
        raw = await asyncio.wait_for(
            provider.generate(
                prompt=prompt,
                system=ANALYSIS_SYSTEM,
                max_tokens=settings.AGENT_MAX_TOKENS * 2,  # longer for analysis
                temperature=0.3,
            ),
            timeout=300,  # 5 min hard limit
        )
    except asyncio.TimeoutError:
        logger.error("LLM analysis timed out after 300s")
        return _error_result(
            "Analysis timed out after 5 minutes. Try a smaller dataset or 'quick' mode."
        )
    except Exception as e:
        logger.error("LLM analysis failed: %s", e)
        return _error_result(f"Analysis failed: {str(e)}")

    elapsed = int((time.monotonic() - start) * 1000)

    # Parse JSON response and attach routing/metrics metadata.
    result = _parse_analysis(raw)
    result.model_used = decision.model
    result.node_used = decision.node
    result.latency_ms = elapsed
    result.rows_analyzed = len(rows)
    result.dataset_summary = summary

    return result
|
||||
|
||||
|
||||
def _parse_analysis(raw) -> AnalysisResult:
    """Try to parse LLM output as JSON, fall back to plain text.

    Args:
        raw: Either a dict from OllamaProvider.generate() whose "response"
            key holds the LLM text, or a plain string from other providers.

    Returns:
        AnalysisResult populated from the model's JSON reply, or a plain-text
        result with confidence 0.5 when no JSON candidate parses.
    """
    # Ollama provider returns {"response": "<llm text>", "model": ..., ...}.
    # Payload dumps are logged lazily at DEBUG to keep INFO logs readable.
    if isinstance(raw, dict):
        text = raw.get("response") or raw.get("analysis") or str(raw)
        logger.debug(
            "_parse_analysis: extracted text from dict, len=%d, first 200 chars: %s",
            len(text), text[:200],
        )
    else:
        text = str(raw)
        logger.debug(
            "_parse_analysis: raw is str, len=%d, first 200 chars: %s",
            len(text), text[:200],
        )

    text = text.strip()

    # Strip markdown code fences (``` / ```json) the model may wrap around JSON.
    if text.startswith("```"):
        fence_free = [line for line in text.split("\n") if not line.strip().startswith("```")]
        text = "\n".join(fence_free).strip()

    # Try each extraction candidate until one parses into a dict.
    for candidate in _extract_json_candidates(text):
        try:
            data = json.loads(candidate)
            if isinstance(data, dict):
                logger.debug("_parse_analysis: parsed JSON OK, keys=%s", list(data.keys()))
                return AnalysisResult(
                    analysis=data.get("analysis", text),
                    confidence=float(data.get("confidence", 0.5)),
                    key_findings=data.get("key_findings", []),
                    iocs_identified=data.get("iocs_identified", []),
                    recommended_actions=data.get("recommended_actions", []),
                    mitre_techniques=data.get("mitre_techniques", []),
                    risk_score=int(data.get("risk_score", 0)),
                )
        # ValueError covers bad float()/int() conversions; TypeError covers
        # explicit JSON nulls there (e.g. "risk_score": null → int(None)),
        # so a malformed candidate falls through to the next one.
        except (json.JSONDecodeError, ValueError, TypeError) as e:
            logger.debug(
                "_parse_analysis: JSON parse failed: %s, candidate len=%d, first 100: %s",
                e, len(candidate), candidate[:100],
            )
            continue

    # Fallback: plain text
    logger.warning("_parse_analysis: all JSON parse attempts failed, falling back to plain text")
    return AnalysisResult(
        analysis=text,
        confidence=0.5,
    )
|
||||
|
||||
|
||||
def _extract_json_candidates(text: str):
|
||||
"""Yield JSON candidate strings from text, trying progressively more aggressive extraction."""
|
||||
import re
|
||||
|
||||
# 1. The whole text as-is
|
||||
yield text
|
||||
|
||||
# 2. Find outermost { ... } block
|
||||
start = text.find("{")
|
||||
end = text.rfind("}")
|
||||
if start != -1 and end > start:
|
||||
block = text[start:end + 1]
|
||||
yield block
|
||||
|
||||
# 3. Try to fix common LLM JSON issues:
|
||||
# - trailing commas before ] or }
|
||||
fixed = re.sub(r',\s*([}\]])', r'\1', block)
|
||||
if fixed != block:
|
||||
yield fixed
|
||||