mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 05:50:21 -05:00
feat: interactive network map, IOC highlighting, AUP hunt selector, type filters
- NetworkMap: hunt-scoped force-directed graph with click-to-inspect popover
- NetworkMap: zoom/pan (wheel, drag, buttons), viewport transform
- NetworkMap: clickable IP/Host/Domain/URL legend chips to filter node types
- NetworkMap: brighter colors, 20% smaller nodes
- DatasetViewer: IOC columns highlighted with colored headers + cell tinting
- AUPScanner: hunt dropdown replacing dataset checkboxes, auto-select all
- Rename 'Social Media (Personal)' theme to 'Social Media' with DB migration
- Fix /api/hunts timeout: Dataset.rows lazy='noload' (was selectin cascade)
- Add OS column mapping to normalizer
- Full backend services, DB models, alembic migrations, new routes
- New components: Dashboard, HuntManager, FileUpload, NetworkMap, etc.
- Docker Compose deployment with nginx reverse proxy
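The `lazy='noload'` fix in the bullets above is worth a sketch. A minimal illustration in SQLAlchemy 2.0 declarative style, assuming model names that match the initial-schema migration later in this commit (the actual model file is not shown here, so treat this as a sketch, not the committed code):

```python
from sqlalchemy import ForeignKey, Integer, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class Dataset(Base):
    __tablename__ = "datasets"
    id: Mapped[str] = mapped_column(String(32), primary_key=True)

    # Before: a selectin-loaded relationship pulled every row of every
    # dataset whenever hunts were serialized -- the /api/hunts timeout.
    # After: lazy="noload" never populates .rows implicitly; row access
    # goes through explicit, bounded queries instead.
    rows: Mapped[list["DatasetRow"]] = relationship(
        lazy="noload",
        cascade="all, delete-orphan",
        passive_deletes=True,  # defer row deletion to the FK's ON DELETE CASCADE
    )


class DatasetRow(Base):
    __tablename__ = "dataset_rows"
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    dataset_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False
    )
```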
66 .env.example
@@ -1,27 +1,53 @@
# Docker environment configuration
# Copy this to .env and customize for your deployment
# ── ThreatHunt Configuration ──────────────────────────────────────────
# All backend env vars are prefixed with TH_ and match AppConfig field names.
# Copy this file to .env and adjust values.

# Agent Configuration
# Choose one: local, networked, online, auto
THREAT_HUNT_AGENT_PROVIDER=auto
# ── General ───────────────────────────────────────────────────────────
TH_DEBUG=false

# Local Provider (on-device or on-prem models)
# THREAT_HUNT_LOCAL_MODEL_PATH=/models/model.gguf
# ── Database ──────────────────────────────────────────────────────────
# SQLite for local dev (zero-config):
TH_DATABASE_URL=sqlite+aiosqlite:///./threathunt.db
# PostgreSQL for production:
# TH_DATABASE_URL=postgresql+asyncpg://threathunt:password@localhost:5432/threathunt

# Networked Provider (shared internal inference service)
# THREAT_HUNT_NETWORKED_ENDPOINT=http://inference-service:5000
# THREAT_HUNT_NETWORKED_KEY=api-key-here
# ── CORS ──────────────────────────────────────────────────────────────
TH_ALLOWED_ORIGINS=http://localhost:3000,http://localhost:8000

# Online Provider (external hosted APIs)
# THREAT_HUNT_ONLINE_API_KEY=sk-your-api-key
# THREAT_HUNT_ONLINE_PROVIDER=openai
# THREAT_HUNT_ONLINE_MODEL=gpt-3.5-turbo
# ── File uploads ──────────────────────────────────────────────────────
TH_MAX_UPLOAD_SIZE_MB=500

# Agent Behavior
THREAT_HUNT_AGENT_MAX_TOKENS=1024
THREAT_HUNT_AGENT_REASONING=true
THREAT_HUNT_AGENT_HISTORY_LENGTH=10
THREAT_HUNT_AGENT_FILTER_SENSITIVE=true
# ── LLM Cluster (Wile & Roadrunner) ──────────────────────────────────
TH_OPENWEBUI_URL=https://ai.guapo613.beer
TH_OPENWEBUI_API_KEY=
TH_WILE_HOST=100.110.190.12
TH_WILE_OLLAMA_PORT=11434
TH_ROADRUNNER_HOST=100.110.190.11
TH_ROADRUNNER_OLLAMA_PORT=11434

# Frontend
# ── Default models (auto-selected by TaskRouter) ─────────────────────
TH_DEFAULT_FAST_MODEL=llama3.1:latest
TH_DEFAULT_HEAVY_MODEL=llama3.1:70b-instruct-q4_K_M
TH_DEFAULT_CODE_MODEL=qwen2.5-coder:32b
TH_DEFAULT_VISION_MODEL=llama3.2-vision:11b
TH_DEFAULT_EMBEDDING_MODEL=bge-m3:latest

# ── Agent behaviour ──────────────────────────────────────────────────
TH_AGENT_MAX_TOKENS=2048
TH_AGENT_TEMPERATURE=0.3
TH_AGENT_HISTORY_LENGTH=10
TH_FILTER_SENSITIVE_DATA=true

# ── Enrichment API keys (optional) ───────────────────────────────────
TH_VIRUSTOTAL_API_KEY=
TH_ABUSEIPDB_API_KEY=
TH_SHODAN_API_KEY=

# ── Auth ─────────────────────────────────────────────────────────────
TH_JWT_SECRET=CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET
TH_JWT_ACCESS_TOKEN_MINUTES=60
TH_JWT_REFRESH_TOKEN_DAYS=7

# ── Frontend ─────────────────────────────────────────────────────────
REACT_APP_API_URL=http://localhost:8000
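The header comment above says every backend env var is prefixed with `TH_` and matches an `AppConfig` field name. A minimal sketch of how that mapping is commonly wired with pydantic-settings (the real `AppConfig` is not shown in this diff, so the class below is illustrative, not the project's code):

```python
from pydantic_settings import BaseSettings, SettingsConfigDict


class AppConfig(BaseSettings):
    # TH_DEBUG -> debug, TH_DATABASE_URL -> database_url, and so on.
    model_config = SettingsConfigDict(env_prefix="TH_", env_file=".env")

    debug: bool = False
    database_url: str = "sqlite+aiosqlite:///./threathunt.db"
    allowed_origins: str = "http://localhost:3000,http://localhost:8000"
    max_upload_size_mb: int = 500


settings = AppConfig()  # reads TH_* vars from the environment and .env
```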
56 .gitignore vendored Normal file
@@ -0,0 +1,56 @@
# ── Python ────────────────────────────────────
__pycache__/
*.py[cod]
*$py.class
*.egg-info/
dist/
build/
*.egg
.eggs/

# ── Virtual environments ─────────────────────
venv/
.venv/
env/

# ── IDE / Editor ─────────────────────────────
.vscode/
.idea/
*.swp
*.swo
*~

# ── OS ────────────────────────────────────────
.DS_Store
Thumbs.db

# ── Environment / Secrets ────────────────────
.env
*.env.local

# ── Database ─────────────────────────────────
*.db
*.sqlite3

# ── Uploads ──────────────────────────────────
uploads/

# ── Node / Frontend ──────────────────────────
node_modules/
frontend/build/
frontend/.env.local
npm-debug.log*
yarn-debug.log*
yarn-error.log*

# ── Docker ───────────────────────────────────
docker-compose.override.yml

# ── Test / Coverage ──────────────────────────
.coverage
htmlcov/
.pytest_cache/
.mypy_cache/

# ── Alembic ──────────────────────────────────
alembic/versions/*.pyc
@@ -1,11 +1,11 @@
# ThreatHunt Backend API - Python 3.11
FROM python:3.11-slim
# ThreatHunt Backend API - Python 3.13
FROM python:3.13-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    gcc curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements
@@ -17,16 +17,16 @@ RUN pip install --no-cache-dir -r requirements.txt
# Copy backend code
COPY backend/ .

# Create non-root user
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
# Create non-root user & data directory
RUN useradd -m -u 1000 appuser && mkdir -p /app/data && chown -R appuser:appuser /app
USER appuser

# Expose port
EXPOSE 8000

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import requests; requests.get('http://localhost:8000/api/agent/health')"
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
    CMD curl -f http://localhost:8000/ || exit 1

# Run application
CMD ["python", "run.py"]
# Run Alembic migrations then start Uvicorn
CMD ["sh", "-c", "python -m alembic upgrade head && python run.py"]
@@ -1,5 +1,5 @@
# ThreatHunt Frontend - Node.js React
FROM node:18-alpine AS builder
FROM node:20-alpine AS builder

WORKDIR /app

@@ -17,20 +17,14 @@ COPY frontend/tsconfig.json ./
# Build application
RUN npm run build

# Production stage
FROM node:18-alpine
# Production stage — nginx reverse-proxy + static files
FROM nginx:alpine

WORKDIR /app
# Copy built React app
COPY --from=builder /app/build /usr/share/nginx/html

# Install serve to serve the static files
RUN npm install -g serve

# Copy built application from builder
COPY --from=builder /app/build ./build

# Create non-root user
RUN addgroup -g 1000 appuser && adduser -D -u 1000 -G appuser appuser
USER appuser
# Copy custom nginx config (proxies /api to backend)
COPY frontend/nginx.conf /etc/nginx/conf.d/default.conf

# Expose port
EXPOSE 3000
@@ -39,5 +33,4 @@ EXPOSE 3000
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD wget --quiet --tries=1 --spider http://localhost:3000/ || exit 1

# Serve application
CMD ["serve", "-s", "build", "-l", "3000"]
CMD ["nginx", "-g", "daemon off;"]
21 SKILLS/00-operating-model.md Normal file
@@ -0,0 +1,21 @@
# Operating Model

## Default cadence
- Prefer iterative progress over big bangs.
- Keep diffs small: target ≤ 300 changed lines per PR unless justified.
- Update tests/docs as part of the same change when possible.

## Working agreement
- Start with a PLAN for non-trivial tasks.
- Implement the smallest slice that satisfies acceptance criteria.
- Verify via DoD.
- Write a crisp PR summary: what changed, why, and how verified.

## Stop conditions (plan first)
Stop and produce a PLAN (do not code yet) if:
- scope is unclear
- more than 3 files will change
- the data model changes
- auth/security boundaries are involved
- performance-critical paths are affected
36 SKILLS/05-agent-taxonomy.md Normal file
@@ -0,0 +1,36 @@
# Agent Types & Roles (Practical Taxonomy)

Use this skill to choose the *right* kind of agent workflow for the job.

## Common agent "types" (in practice)

### 1) Chat assistant (no tools)
Best for: explanations, brainstorming, small edits.
Risk: can hallucinate; no grounding in repo state.

### 2) Tool-using single agent
Best for: well-scoped tasks where the agent can read/write files and run commands.
Key control: strict DoD gates + minimal permissions.

### 3) Planner + Executor (2-role pattern)
Best for: medium complexity work (multi-file changes, feature work).
Flow: Planner writes plan + acceptance criteria → Executor implements → Reviewer checks.

### 4) Multi-agent (specialists)
Best for: bigger features with separable workstreams (UI, backend, docs, tests).
Rule: isolate context per role; use separate branches/worktrees.

### 5) Supervisor / orchestrator
Best for: long-running workflows with checkpoints (pipelines, report generation, PAD docs).
Rule: supervisor delegates, enforces gates, and composes final output.

## Decision rules (fast)
- If you can describe it in ≤ 5 steps → single tool-using agent.
- If you need tradeoffs/design → Planner + Executor.
- If UI + backend + docs/tests all move → multi-agent specialists.
- If it's a pipeline that runs repeatedly → orchestrator.

## Guardrails (always)
- DoD is the truth gate.
- Separate branches/worktrees for parallel work.
- Log decisions + commands in AGENT_LOG.md.
24 SKILLS/10-definition-of-done.md Normal file
@@ -0,0 +1,24 @@
# Definition of Done (DoD)

A change is "done" only when:

## Code correctness
- Builds successfully (if applicable)
- Tests pass
- Linting/formatting passes
- Types/checks pass (if applicable)

## Quality
- No new warnings introduced
- Edge cases handled (inputs validated, errors meaningful)
- Hot paths not regressed (if applicable)

## Hygiene
- No secrets committed
- Docs updated if behavior or usage changed
- PR summary includes verification steps

## Commands
- macOS/Linux: `./scripts/dod.sh`
- Windows: `.\scripts\dod.ps1`
16 SKILLS/20-repo-map.md Normal file
@@ -0,0 +1,16 @@
# Repo Mapping Skill

When entering a repo:
1) Read README.md
2) Identify entrypoints (app main / server startup / CLI)
3) Identify config (env vars, .env.example, config files)
4) Identify test/lint scripts (package.json, pyproject.toml, Makefile, etc.)
5) Write a 10-line "repo map" in the PLAN before changing code

Output format:
- Purpose:
- Key modules:
- Data flow:
- Commands:
- Risks:
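A hypothetical filled-in repo map for this repository, as a worked example of the format (details inferred from this commit, not authoritative):

- Purpose: threat hunting platform; upload forensic CSVs, inspect/annotate rows, enrich IOCs, get LLM-assisted guidance
- Key modules: backend/app (routes, services, db, agents), backend/alembic (migrations), frontend (React UI)
- Data flow: CSV upload → normalizer → dataset_rows in DB → viewers/scanners/network map → agent guidance via Ollama/OpenWebUI
- Commands: docker compose up; ./scripts/dod.sh; python -m alembic upgrade head
- Risks: large CSV ingest, LLM latency, secrets in .env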
20 SKILLS/25-algorithms-performance.md Normal file
@@ -0,0 +1,20 @@
# Algorithms & Performance

Use this skill when performance matters (large inputs, hot paths, or repeated calls).

## Checklist
- Identify the **state** you're recomputing.
- Add **memoization / caching** when the same subproblem repeats (see the sketch below).
- Prefer **linear scans** + caches over nested loops when possible.
- If you can write it as a **recurrence**, you can test it.
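A minimal Python sketch of the memoization point (illustrative, not project code):

```python
from functools import lru_cache


@lru_cache(maxsize=1024)  # bounded cache, per the heuristics below
def ways_to_climb(n: int) -> int:
    """Recurrence: ways(n) = ways(n-1) + ways(n-2).

    Memoization turns the naive O(2^n) recursion into O(n),
    and the recurrence makes it trivially testable.
    """
    if n <= 1:
        return 1
    return ways_to_climb(n - 1) + ways_to_climb(n - 2)


assert ways_to_climb(30) == 1346269  # base case + typical case are easy to pin down
```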
## Practical heuristics
- Measure first when possible (timing + input sizes).
- Optimize the biggest wins: avoid repeated I/O, repeated parsing, repeated network calls.
- Keep caches bounded (size/TTL) and invalidate safely.
- Choose data structures intentionally: dict/set for membership, heap for top-k, deque for queues.

## Review notes (for PRs)
- Call out accidental O(n²) patterns.
- Suggest table/DP or memoization when repeated work is obvious.
- Add tests that cover base cases + typical cases + worst-case size.
31 SKILLS/26-vibe-coding-fundamentals.md Normal file
@@ -0,0 +1,31 @@
# Vibe Coding With Fundamentals (Safety Rails)

Use this skill when you're using "vibe coding" (fast, conversational building) but want production-grade outcomes.

## The good
- Rapid scaffolding and iteration
- Fast UI prototypes
- Quick exploration of architectures and options

## The failure mode
- "It works on my machine" code with weak tests
- Security foot-guns (auth, input validation, secrets)
- Performance cliffs (accidental O(n²), repeated I/O)
- Unmaintainable abstractions

## Safety rails (apply every time)
- Always start with acceptance criteria (what "done" means).
- Prefer small PRs; never dump a huge AI diff.
- Require DoD gates (lint/test/build) before merge.
- Write tests for behavior changes.
- For anything security/data related: do a Reviewer pass.

## When to slow down
- Auth/session/token work
- Anything touching payments, PII, secrets
- Data migrations/schema changes
- Performance-critical paths
- "It's flaky" or "it only fails in CI"

## Practical prompt pattern (use in PLAN)
- "State assumptions, list files to touch, propose tests, and include rollback steps."
31 SKILLS/27-performance-profiling.md Normal file
@@ -0,0 +1,31 @@
# Performance Profiling (Bun/Node)

Use this skill when:
- a hot path feels slow
- CPU usage is high
- you suspect accidental O(n²) or repeated work
- you need evidence before optimizing

## Bun CPU profiling
Bun supports CPU profiling via `--cpu-prof` (generates a `.cpuprofile` you can open in Chrome DevTools).

Upcoming: `bun --cpu-prof-md <script>` outputs a CPU profile as **Markdown** so LLMs can read/grep it easily.

### Workflow (Bun)
1) Run the workload with profiling enabled
   - Today: `bun --cpu-prof ./path/to/script.ts`
   - Upcoming: `bun --cpu-prof-md ./path/to/script.ts`
2) Save the output (or `.cpuprofile`) into `./profiles/` with a timestamp.
3) Ask the Reviewer agent to:
   - identify the top 5 hottest functions
   - propose the smallest fix
   - add a regression test or benchmark

## Node CPU profiling (fallback)
- `node --cpu-prof ./script.js` writes a `.cpuprofile` file.
- Open in Chrome DevTools → Performance → Load profile.

## Rules
- Optimize based on measured hotspots, not vibes.
- Prefer algorithmic wins (remove repeated work) over micro-optimizations.
- Keep profiling artifacts out of git unless explicitly needed (use `.gitignore`).
16 SKILLS/30-implementation-rules.md Normal file
@@ -0,0 +1,16 @@
# Implementation Rules

## Change policy
- Prefer edits over rewrites.
- Keep changes localized.
- One change = one purpose.
- Avoid unnecessary abstraction.

## Dependency policy
- Default: do not add dependencies.
- If adding: explain why, alternatives considered, and impact.

## Error handling
- Validate inputs at boundaries.
- Error messages must be actionable: what failed + what to do next.
14 SKILLS/40-testing-quality.md Normal file
@@ -0,0 +1,14 @@
# Testing & Quality

## Strategy
- If behavior changes: add/update tests.
- Unit tests for logic; integration tests for boundaries; E2E only where needed.

## Minimum for every PR
- A test plan in the PR summary (even if "existing tests cover this").
- Run DoD.

## Flaky tests
- Capture repro steps.
- Quarantine only with justification + follow-up issue.
16 SKILLS/50-pr-review.md Normal file
@@ -0,0 +1,16 @@
# PR Review Skill

Reviewer must check:
- Correctness: does it do what it claims?
- Safety: secrets, injection, auth boundaries
- Maintainability: readability, naming, duplication
- Tests: added/updated appropriately
- DoD: did it pass?

Reviewer output format:
1) Summary
2) Must-fix
3) Nice-to-have
4) Risks
5) Verification suggestions
41 SKILLS/56-ui-material-ui.md Normal file
@@ -0,0 +1,41 @@
# Material UI (MUI) Design System

Use this skill for any React/Next "portal/admin/dashboard" UI so you stay consistent and avoid random component soup.

## Standard choice
- Preferred UI library: **MUI (Material UI)**.
- Prefer MUI components over ad-hoc HTML/CSS unless there's a good reason.
- One design system per repo (do not mix Chakra/Ant/Bootstrap/etc.).

## Setup (Next.js/React)
- Install: `@mui/material @emotion/react @emotion/styled`
- If using icons: `@mui/icons-material`
- If using data grid: `@mui/x-data-grid` (or pro if licensed)

## Theming rules
- Define a single theme (typography, spacing, palette) and reuse everywhere.
- Use semantic colors (primary/secondary/error/warning/success/info), not hard-coded hex everywhere.
- Prefer MUI's `sx` for small styling; use `styled()` for reusable components.

## "Portal" patterns (modals, popovers, menus)
- Use MUI Dialog/Modal/Popover/Menu components instead of DIY portals.
- Accessibility requirements:
  - Focus is trapped in Dialog/Modal.
  - Escape closes modal unless explicitly prevented.
  - All inputs have labels; buttons have clear text/aria-labels.
  - Keyboard navigation works end-to-end.

## Layout conventions (for portals)
- Use: AppBar + Drawer (or NavigationRail equivalent) + main content.
- Keep pages as composition of small components: Page → Sections → Widgets.
- Keep forms consistent: FormControl + helper text + validation messages.

## Performance hygiene
- Avoid re-render storms: memoize heavy lists; use virtualization for large tables (DataGrid).
- Prefer server pagination for huge datasets.

## PR review checklist
- Theme is used (no random styling).
- Components are MUI where reasonable.
- Modal/popover accessibility is correct.
- No mixed UI libraries.
15 SKILLS/60-security-safety.md Normal file
@@ -0,0 +1,15 @@
# Security & Safety

## Secrets
- Never output secrets or tokens.
- Never log sensitive inputs.
- Never commit credentials.

## Inputs
- Validate external inputs at boundaries.
- Fail closed for auth/security decisions.

## Tooling
- No destructive commands unless requested and scoped.
- Prefer read-only operations first.
13 SKILLS/70-docs-artifacts.md Normal file
@@ -0,0 +1,13 @@
# Docs & Artifacts

Update documentation when:
- setup steps change
- env vars change
- endpoints/CLI behavior changes
- data formats change

Docs standards:
- Provide copy/paste commands
- Provide expected outputs where helpful
- Keep it short and accurate
11 SKILLS/80-mcp-tools.md Normal file
@@ -0,0 +1,11 @@
# MCP Tools Skill (Optional)

If this repo defines MCP servers/tools:

Rules:
- Tool calls must be explicit and logged.
- Maintain an allowlist of tools; deny by default (see the sketch below).
- Every tool must have: purpose, inputs/outputs schema, examples, and tests.
- Prefer idempotent tool operations.
- Never add tools that can exfiltrate secrets without strict guards.
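A minimal Python sketch of the deny-by-default allowlist (the tool names and registry are illustrative, not project code):

```python
from typing import Any, Callable

TOOL_REGISTRY: dict[str, Callable[..., Any]] = {
    "hunt_list_datasets": lambda limit=20: {"datasets": [], "has_more": False},
}
ALLOWED_TOOLS = {"hunt_list_datasets"}  # explicit allowlist; everything else is denied


def dispatch(tool_name: str, **args: Any) -> Any:
    """Explicit, logged tool dispatch. Unlisted tools are refused outright."""
    if tool_name not in ALLOWED_TOOLS or tool_name not in TOOL_REGISTRY:
        raise PermissionError(f"tool '{tool_name}' is not allowlisted")
    print(f"tool call: {tool_name} args={args}")  # stand-in for real audit logging
    return TOOL_REGISTRY[tool_name](**args)
```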
51 SKILLS/82-mcp-server-design.md Normal file
@@ -0,0 +1,51 @@
# MCP Server Design (Agent-First)

Build MCP servers like you're designing a UI for a non-human user.

This skill distills Phil Schmid's MCP server best practices into concrete repo rules.
Source: "MCP is Not the Problem, It's your Server" (Jan 21, 2026).

## 1) Outcomes, not operations
- Do **not** wrap REST endpoints 1:1 as tools.
- Expose high-level, outcome-oriented tools.
  - Bad: `get_user`, `list_orders`, `get_order_status`
  - Good: `track_latest_order(email)` (server orchestrates internally)

## 2) Flatten arguments
- Prefer top-level primitives + constrained enums.
- Avoid nested `dict`/config objects (agents hallucinate keys).
- Defaults reduce decision load.

## 3) Instructions are context
- Tool docstrings are *instructions*:
  - when to use the tool
  - argument formatting rules
  - what the return means
- Error strings are also context:
  - return actionable, self-correcting messages (not raw stack traces)

## 4) Curate ruthlessly
- Aim for **5–15 tools** per server.
- One server, one job. Split by persona if needed.
- Delete unused tools. Don't dump raw data into context.

## 5) Name tools for discovery
- Avoid generic names (`create_issue`).
- Prefer `{service}_{action}_{resource}`:
  - `velociraptor_run_hunt`
  - `github_list_prs`
  - `slack_send_message`

## 6) Paginate large results
- Always support `limit` (default ~20–50).
- Return metadata: `has_more`, `next_offset`, `total_count`.
- Never return hundreds of rows unbounded (see the sketch below).
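A minimal Python sketch pulling rules 1, 2, 3, and 6 together (the tool name follows the section 5 convention; the helper and data are hypothetical stand-ins, not project code):

```python
from typing import Literal


def _query_findings(hunt_id: str, severity: str) -> list[dict]:
    # Hypothetical stand-in for a real datastore query.
    return [{"hunt": hunt_id, "severity": severity, "note": f"finding {i}"} for i in range(45)]


def hunt_list_findings(
    hunt_id: str,
    severity: Literal["low", "medium", "high"] = "medium",  # flat enum, no nested config dict
    limit: int = 20,  # bounded by default
    offset: int = 0,
) -> dict:
    """List findings for one hunt, most recent first.

    Use when the analyst asks what a hunt has turned up so far.
    hunt_id is the hunt identifier; returns at most `limit` findings
    plus pagination metadata the agent can act on.
    """
    findings = _query_findings(hunt_id, severity)
    page = findings[offset : offset + limit]
    return {
        "findings": page,
        "total_count": len(findings),
        "has_more": offset + limit < len(findings),
        "next_offset": offset + limit,
    }
```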
## Repo conventions
- Put MCP tool specs in `mcp/` (schemas, examples, fixtures).
- Provide at least 1 "golden path" example call per tool.
- Add an eval that checks:
  - tool names follow discovery convention (see the sketch below)
  - args are flat + typed
  - responses are concise + stable
  - pagination works
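One way the naming-convention check in that eval might look, as a sketch (the regex and tool list are illustrative):

```python
import re

# {service}_{action}_{resource}: at least three snake_case segments
NAME_RE = re.compile(r"^[a-z]+(_[a-z0-9]+){2,}$")


def naming_violations(tool_names: list[str]) -> list[str]:
    """Return tool names that do not follow the discovery convention."""
    return [name for name in tool_names if not NAME_RE.match(name)]


assert naming_violations(["velociraptor_run_hunt", "create_issue"]) == ["create_issue"]
```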
40 SKILLS/83-fastmcp-3-patterns.md Normal file
@@ -0,0 +1,40 @@
# FastMCP 3 Patterns (Providers + Transforms)

Use this skill when you are building MCP servers in Python and want:
- composable tool sets
- per-user/per-session behavior
- auth, versioning, observability, and long-running tasks

## Mental model (FastMCP 3)
FastMCP 3 treats everything as three composable primitives:
- **Components**: what you expose (tools, resources, prompts)
- **Providers**: where components come from (decorators, files, OpenAPI, remote MCP, etc.)
- **Transforms**: how you reshape what clients see (namespace, filters, auth, versioning, visibility)

## Recommended architecture for Marc's platform
Build a **single "Cyber MCP Gateway"** that composes providers:
- LocalProvider: core cyber tools (run hunt, parse triage, generate report)
- OpenAPIProvider: wrap stable internal APIs (ticketing, asset DB) without 1:1 endpoint exposure
- ProxyProvider/FastMCPProvider: mount sub-servers (e.g., Velociraptor tools, Intel feeds)

Then apply transforms:
- Namespace per domain: `hunt.*`, `intel.*`, `pad.*`
- Visibility per session: hide dangerous tools unless user/role allows
- VersionFilter: keep old clients working while you evolve tools

## Production must-haves
- **Tool timeouts**: never let a tool hang forever (see the sketch below)
- **Pagination**: all list tools must be bounded
- **Background tasks**: use for long hunts / ingest jobs
- **Tracing**: emit OpenTelemetry traces so you can debug agent/tool behavior
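For the timeout rule, a minimal library-agnostic sketch using asyncio (FastMCP 3 may expose its own timeout options; this wrapper is an assumption, not its API):

```python
import asyncio
from typing import Any, Awaitable


async def call_with_timeout(tool_coro: Awaitable[Any], seconds: float = 30.0) -> Any:
    """Wrap a tool invocation so it can never hang forever."""
    try:
        return await asyncio.wait_for(tool_coro, timeout=seconds)
    except asyncio.TimeoutError:
        # Error strings are context: tell the agent what to try next.
        return {"error": f"tool timed out after {seconds}s; retry with a narrower query"}
```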
## Auth rules
- Prefer component-level auth for "dangerous" tools.
- Default stance: read-only tools visible; write/execute tools gated.

## Versioning rules
- Version your components when you change schemas or semantics.
- Keep 1 previous version callable during migrations.

## Upgrade guidance
FastMCP 3 is in beta; pin to v2 for stability in production until you've tested.
149 backend/alembic.ini Normal file
@@ -0,0 +1,149 @@
# A generic, single database configuration.

[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the tzdata library which can be installed by adding
# `alembic[tz]` to the pip requirements.
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =

# max length of characters to apply to the "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions

# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
#    "version_path_separator" key, which if absent then falls back to the legacy
#    behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
#    behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url = sqlite+aiosqlite:///./threathunt.db


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARNING
handlers = console
qualname =

[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
1 backend/alembic/README Normal file
@@ -0,0 +1 @@
Generic single-database configuration.
67 backend/alembic/env.py Normal file
@@ -0,0 +1,67 @@
"""Alembic async env — autogenerate from app.db.models."""

import asyncio
from logging.config import fileConfig

from sqlalchemy import pool
from sqlalchemy.ext.asyncio import async_engine_from_config

from alembic import context

# Alembic Config
config = context.config

if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Import all models so autogenerate sees them
from app.db.engine import Base  # noqa: E402
from app.db import models as _models  # noqa: E402, F401

target_metadata = Base.metadata


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode."""
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        render_as_batch=True,  # required for SQLite ALTER TABLE
    )
    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection):
    context.configure(
        connection=connection,
        target_metadata=target_metadata,
        render_as_batch=True,
    )
    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
    """Run migrations in 'online' mode with an async engine."""
    connectable = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    async with connectable.connect() as connection:
        await connection.run_sync(do_run_migrations)
    await connectable.dispose()


def run_migrations_online() -> None:
    asyncio.run(run_async_migrations())


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
28 backend/alembic/script.py.mako Normal file
@@ -0,0 +1,28 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    """Upgrade schema."""
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    """Downgrade schema."""
    ${downgrades if downgrades else "pass"}
210 backend/alembic/versions/9790f482da06_initial_schema.py Normal file
@@ -0,0 +1,210 @@
"""initial schema

Revision ID: 9790f482da06
Revises:
Create Date: 2026-02-19 11:40:02.108830

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '9790f482da06'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('users',
        sa.Column('id', sa.String(length=32), nullable=False),
        sa.Column('username', sa.String(length=64), nullable=False),
        sa.Column('email', sa.String(length=256), nullable=False),
        sa.Column('hashed_password', sa.String(length=256), nullable=False),
        sa.Column('role', sa.String(length=16), nullable=False),
        sa.Column('is_active', sa.Boolean(), nullable=False),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('email')
    )
    with op.batch_alter_table('users', schema=None) as batch_op:
        batch_op.create_index(batch_op.f('ix_users_username'), ['username'], unique=True)

    op.create_table('hunts',
        sa.Column('id', sa.String(length=32), nullable=False),
        sa.Column('name', sa.String(length=256), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('status', sa.String(length=32), nullable=False),
        sa.Column('owner_id', sa.String(length=32), nullable=True),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
        sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
        sa.ForeignKeyConstraint(['owner_id'], ['users.id'], ),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_table('datasets',
        sa.Column('id', sa.String(length=32), nullable=False),
        sa.Column('name', sa.String(length=256), nullable=False),
        sa.Column('filename', sa.String(length=512), nullable=False),
        sa.Column('source_tool', sa.String(length=64), nullable=True),
        sa.Column('row_count', sa.Integer(), nullable=False),
        sa.Column('column_schema', sa.JSON(), nullable=True),
        sa.Column('normalized_columns', sa.JSON(), nullable=True),
        sa.Column('ioc_columns', sa.JSON(), nullable=True),
        sa.Column('file_size_bytes', sa.Integer(), nullable=False),
        sa.Column('encoding', sa.String(length=32), nullable=True),
        sa.Column('delimiter', sa.String(length=4), nullable=True),
        sa.Column('time_range_start', sa.DateTime(timezone=True), nullable=True),
        sa.Column('time_range_end', sa.DateTime(timezone=True), nullable=True),
        sa.Column('hunt_id', sa.String(length=32), nullable=True),
        sa.Column('uploaded_by', sa.String(length=32), nullable=True),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
        sa.ForeignKeyConstraint(['hunt_id'], ['hunts.id'], ),
        sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.create_index('ix_datasets_hunt', ['hunt_id'], unique=False)
        batch_op.create_index(batch_op.f('ix_datasets_name'), ['name'], unique=False)

    op.create_table('hypotheses',
        sa.Column('id', sa.String(length=32), nullable=False),
        sa.Column('hunt_id', sa.String(length=32), nullable=True),
        sa.Column('title', sa.String(length=256), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('mitre_technique', sa.String(length=32), nullable=True),
        sa.Column('status', sa.String(length=16), nullable=False),
        sa.Column('evidence_row_ids', sa.JSON(), nullable=True),
        sa.Column('evidence_notes', sa.Text(), nullable=True),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
        sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
        sa.ForeignKeyConstraint(['hunt_id'], ['hunts.id'], ),
        sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('hypotheses', schema=None) as batch_op:
        batch_op.create_index('ix_hypotheses_hunt', ['hunt_id'], unique=False)

    op.create_table('conversations',
        sa.Column('id', sa.String(length=32), nullable=False),
        sa.Column('title', sa.String(length=256), nullable=True),
        sa.Column('hunt_id', sa.String(length=32), nullable=True),
        sa.Column('dataset_id', sa.String(length=32), nullable=True),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
        sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
        sa.ForeignKeyConstraint(['dataset_id'], ['datasets.id'], ),
        sa.ForeignKeyConstraint(['hunt_id'], ['hunts.id'], ),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_table('dataset_rows',
        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
        sa.Column('dataset_id', sa.String(length=32), nullable=False),
        sa.Column('row_index', sa.Integer(), nullable=False),
        sa.Column('data', sa.JSON(), nullable=False),
        sa.Column('normalized_data', sa.JSON(), nullable=True),
        sa.ForeignKeyConstraint(['dataset_id'], ['datasets.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('dataset_rows', schema=None) as batch_op:
        batch_op.create_index('ix_dataset_rows_dataset', ['dataset_id'], unique=False)
        batch_op.create_index('ix_dataset_rows_dataset_idx', ['dataset_id', 'row_index'], unique=False)

    op.create_table('enrichment_results',
        sa.Column('id', sa.String(length=32), nullable=False),
        sa.Column('ioc_value', sa.String(length=512), nullable=False),
        sa.Column('ioc_type', sa.String(length=32), nullable=False),
        sa.Column('source', sa.String(length=32), nullable=False),
        sa.Column('verdict', sa.String(length=16), nullable=True),
        sa.Column('confidence', sa.Float(), nullable=True),
        sa.Column('raw_result', sa.JSON(), nullable=True),
        sa.Column('summary', sa.Text(), nullable=True),
        sa.Column('dataset_id', sa.String(length=32), nullable=True),
        sa.Column('cached_at', sa.DateTime(timezone=True), nullable=False),
        sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True),
        sa.ForeignKeyConstraint(['dataset_id'], ['datasets.id'], ),
        sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('enrichment_results', schema=None) as batch_op:
        batch_op.create_index('ix_enrichment_ioc_source', ['ioc_value', 'source'], unique=False)
        batch_op.create_index(batch_op.f('ix_enrichment_results_ioc_value'), ['ioc_value'], unique=False)

    op.create_table('annotations',
        sa.Column('id', sa.String(length=32), nullable=False),
        sa.Column('row_id', sa.Integer(), nullable=True),
        sa.Column('dataset_id', sa.String(length=32), nullable=True),
        sa.Column('author_id', sa.String(length=32), nullable=True),
        sa.Column('text', sa.Text(), nullable=False),
        sa.Column('severity', sa.String(length=16), nullable=False),
        sa.Column('tag', sa.String(length=32), nullable=True),
        sa.Column('highlight_color', sa.String(length=16), nullable=True),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
        sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
        sa.ForeignKeyConstraint(['author_id'], ['users.id'], ),
        sa.ForeignKeyConstraint(['dataset_id'], ['datasets.id'], ),
        sa.ForeignKeyConstraint(['row_id'], ['dataset_rows.id'], ondelete='SET NULL'),
        sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('annotations', schema=None) as batch_op:
        batch_op.create_index('ix_annotations_dataset', ['dataset_id'], unique=False)
        batch_op.create_index('ix_annotations_row', ['row_id'], unique=False)

    op.create_table('messages',
        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
        sa.Column('conversation_id', sa.String(length=32), nullable=False),
        sa.Column('role', sa.String(length=16), nullable=False),
        sa.Column('content', sa.Text(), nullable=False),
        sa.Column('model_used', sa.String(length=128), nullable=True),
        sa.Column('node_used', sa.String(length=64), nullable=True),
        sa.Column('token_count', sa.Integer(), nullable=True),
        sa.Column('latency_ms', sa.Integer(), nullable=True),
        sa.Column('response_meta', sa.JSON(), nullable=True),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
        sa.ForeignKeyConstraint(['conversation_id'], ['conversations.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('messages', schema=None) as batch_op:
        batch_op.create_index('ix_messages_conversation', ['conversation_id'], unique=False)

    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('messages', schema=None) as batch_op:
        batch_op.drop_index('ix_messages_conversation')

    op.drop_table('messages')
    with op.batch_alter_table('annotations', schema=None) as batch_op:
        batch_op.drop_index('ix_annotations_row')
        batch_op.drop_index('ix_annotations_dataset')

    op.drop_table('annotations')
    with op.batch_alter_table('enrichment_results', schema=None) as batch_op:
        batch_op.drop_index(batch_op.f('ix_enrichment_results_ioc_value'))
        batch_op.drop_index('ix_enrichment_ioc_source')

    op.drop_table('enrichment_results')
    with op.batch_alter_table('dataset_rows', schema=None) as batch_op:
        batch_op.drop_index('ix_dataset_rows_dataset_idx')
        batch_op.drop_index('ix_dataset_rows_dataset')

    op.drop_table('dataset_rows')
    op.drop_table('conversations')
    with op.batch_alter_table('hypotheses', schema=None) as batch_op:
        batch_op.drop_index('ix_hypotheses_hunt')

    op.drop_table('hypotheses')
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.drop_index(batch_op.f('ix_datasets_name'))
        batch_op.drop_index('ix_datasets_hunt')

    op.drop_table('datasets')
    op.drop_table('hunts')
    with op.batch_alter_table('users', schema=None) as batch_op:
        batch_op.drop_index(batch_op.f('ix_users_username'))

    op.drop_table('users')
    # ### end Alembic commands ###
64 backend/alembic/versions/98ab619418bc_add_keyword_themes_and_keywords_tables.py Normal file
@@ -0,0 +1,64 @@
"""add_keyword_themes_and_keywords_tables

Revision ID: 98ab619418bc
Revises: 9790f482da06
Create Date: 2026-02-19 12:01:38.174653

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '98ab619418bc'
down_revision: Union[str, Sequence[str], None] = '9790f482da06'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('keyword_themes',
        sa.Column('id', sa.String(length=32), nullable=False),
        sa.Column('name', sa.String(length=128), nullable=False),
        sa.Column('color', sa.String(length=16), nullable=False),
        sa.Column('enabled', sa.Boolean(), nullable=False),
        sa.Column('is_builtin', sa.Boolean(), nullable=False),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
        sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('keyword_themes', schema=None) as batch_op:
        batch_op.create_index(batch_op.f('ix_keyword_themes_name'), ['name'], unique=True)

    op.create_table('keywords',
        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
        sa.Column('theme_id', sa.String(length=32), nullable=False),
        sa.Column('value', sa.String(length=256), nullable=False),
        sa.Column('is_regex', sa.Boolean(), nullable=False),
        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
        sa.ForeignKeyConstraint(['theme_id'], ['keyword_themes.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id')
    )
    with op.batch_alter_table('keywords', schema=None) as batch_op:
        batch_op.create_index('ix_keywords_theme', ['theme_id'], unique=False)
        batch_op.create_index('ix_keywords_value', ['value'], unique=False)

    # ### end Alembic commands ###


def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('keywords', schema=None) as batch_op:
        batch_op.drop_index('ix_keywords_value')
        batch_op.drop_index('ix_keywords_theme')

    op.drop_table('keywords')
    with op.batch_alter_table('keyword_themes', schema=None) as batch_op:
        batch_op.drop_index(batch_op.f('ix_keyword_themes_name'))

    op.drop_table('keyword_themes')
    # ### end Alembic commands ###
408 backend/app/agents/core_v2.py Normal file
@@ -0,0 +1,408 @@
|
||||
"""Core ThreatHunt analyst-assist agent — v2.
|
||||
|
||||
Uses TaskRouter to select the right model/node for each query,
|
||||
real LLM providers (Ollama/OpenWebUI), and structured response parsing.
|
||||
Integrates SANS RAG context from Open WebUI.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.config import settings
|
||||
from app.services.sans_rag import sans_rag
|
||||
from .router import TaskRouter, TaskType, RoutingDecision, task_router
|
||||
from .providers_v2 import OllamaProvider, OpenWebUIProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class AgentContext(BaseModel):
|
||||
"""Context for agent guidance requests."""
|
||||
|
||||
query: str = Field(..., description="Analyst question or request for guidance")
|
||||
dataset_name: Optional[str] = Field(None, description="Name of CSV dataset")
|
||||
artifact_type: Optional[str] = Field(None, description="Artifact type")
|
||||
host_identifier: Optional[str] = Field(None, description="Host name, IP, or identifier")
|
||||
data_summary: Optional[str] = Field(None, description="Brief description of data")
|
||||
conversation_history: Optional[list[dict]] = Field(
|
||||
default_factory=list, description="Previous messages"
|
||||
)
|
||||
active_hypotheses: Optional[list[str]] = Field(
|
||||
default_factory=list, description="Active investigation hypotheses"
|
||||
)
|
||||
annotations_summary: Optional[str] = Field(
|
||||
None, description="Summary of analyst annotations"
|
||||
)
|
||||
enrichment_summary: Optional[str] = Field(
|
||||
None, description="Summary of enrichment results"
|
||||
)
|
||||
mode: str = Field(default="quick", description="quick | deep | debate")
|
||||
model_override: Optional[str] = Field(None, description="Force a specific model")
|
||||
|
||||
|
||||
class Perspective(BaseModel):
|
||||
"""A single perspective from the debate agent."""
|
||||
role: str
|
||||
content: str
|
||||
model_used: str
|
||||
node_used: str
|
||||
latency_ms: int
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
|
||||
"""Response from analyst-assist agent."""
|
||||
|
||||
guidance: str = Field(..., description="Advisory guidance for analyst")
|
||||
confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence (0-1)")
|
||||
suggested_pivots: list[str] = Field(default_factory=list)
|
||||
suggested_filters: list[str] = Field(default_factory=list)
|
||||
caveats: Optional[str] = None
|
||||
reasoning: Optional[str] = None
|
||||
sans_references: list[str] = Field(
|
||||
default_factory=list, description="SANS course references"
|
||||
)
|
||||
model_used: str = Field(default="", description="Model that generated the response")
|
||||
node_used: str = Field(default="", description="Node that processed the request")
|
||||
latency_ms: int = Field(default=0, description="Total latency in ms")
|
||||
perspectives: Optional[list[Perspective]] = Field(
|
||||
None, description="Debate perspectives (only in debate mode)"
|
||||
)
|
||||
|
||||
|
||||
# ── System prompt ─────────────────────────────────────────────────────
|
||||
|
||||
SYSTEM_PROMPT = """You are an analyst-assist agent for ThreatHunt, a threat hunting platform.
|
||||
You have access to 300GB of SANS cybersecurity course material for reference.
|
||||
|
||||
Your role:
|
||||
- Interpret and explain CSV artifact data from Velociraptor and other forensic tools
|
||||
- Suggest analytical pivots, filters, and hypotheses
|
||||
- Highlight anomalies, patterns, or points of interest
|
||||
- Reference relevant SANS methodologies and techniques when applicable
|
||||
- Guide analysts without replacing their judgment
|
||||
|
||||
Your constraints:
|
||||
- You ONLY provide guidance and suggestions
|
||||
- You do NOT execute actions or tools
|
||||
- You do NOT modify data or escalate alerts
|
||||
- You explain your reasoning transparently
|
||||
|
||||
RESPONSE FORMAT — you MUST respond with valid JSON:
|
||||
{
|
||||
"guidance": "Your main guidance text here",
|
||||
"confidence": 0.85,
|
||||
"suggested_pivots": ["Pivot 1", "Pivot 2"],
|
||||
"suggested_filters": ["filter expression 1", "filter expression 2"],
|
||||
"caveats": "Any assumptions or limitations",
|
||||
"reasoning": "How you arrived at this guidance",
|
||||
"sans_references": ["SANS SEC504: ...", "SANS FOR508: ..."]
|
||||
}
|
||||
|
||||
Respond ONLY with the JSON object. No markdown, no code fences, no extra text."""
|
||||
|
||||
|
||||
# ── Agent ─────────────────────────────────────────────────────────────


class ThreatHuntAgent:
    """Analyst-assist agent backed by Wile + Roadrunner LLM cluster."""

    def __init__(self, router: TaskRouter | None = None):
        self.router = router or task_router
        self.system_prompt = SYSTEM_PROMPT

    async def assist(self, context: AgentContext) -> AgentResponse:
        """Provide guidance on artifact data and analysis."""
        start = time.monotonic()

        if context.mode == "debate":
            return await self._debate_assist(context)

        # Classify task and route
        task_type = self.router.classify_task(context.query)
        if context.mode == "deep":
            task_type = TaskType.DEEP_ANALYSIS

        decision = self.router.route(task_type, model_override=context.model_override)
        logger.info(f"Routing: {decision.reason}")

        # Enrich prompt with SANS RAG context
        prompt = self._build_prompt(context)
        try:
            rag_context = await sans_rag.enrich_prompt(
                context.query,
                investigation_context=context.data_summary or "",
            )
            if rag_context:
                prompt = f"{prompt}\n\n{rag_context}"
        except Exception as e:
            logger.warning(f"SANS RAG enrichment failed: {e}")

        # Call LLM
        provider = self.router.get_provider(decision)
        if isinstance(provider, OpenWebUIProvider):
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ]
            result = await provider.chat(
                messages,
                max_tokens=settings.AGENT_MAX_TOKENS,
                temperature=settings.AGENT_TEMPERATURE,
            )
        else:
            result = await provider.generate(
                prompt,
                system=self.system_prompt,
                max_tokens=settings.AGENT_MAX_TOKENS,
                temperature=settings.AGENT_TEMPERATURE,
            )

        raw_text = result.get("response", "")
        latency_ms = result.get("_latency_ms", 0)

        # Parse structured response
        response = self._parse_response(raw_text, context)
        response.model_used = decision.model
        response.node_used = decision.node.value
        response.latency_ms = latency_ms

        total_ms = int((time.monotonic() - start) * 1000)
        logger.info(
            f"Agent assist: {context.query[:60]}... → "
            f"{decision.model} on {decision.node.value} "
            f"({total_ms}ms total, {latency_ms}ms LLM)"
        )

        return response

    async def assist_stream(
        self,
        context: AgentContext,
    ) -> AsyncIterator[str]:
        """Stream agent response tokens."""
        task_type = self.router.classify_task(context.query)
        decision = self.router.route(task_type, model_override=context.model_override)
        prompt = self._build_prompt(context)

        provider = self.router.get_provider(decision)
        if isinstance(provider, OllamaProvider):
            async for token in provider.generate_stream(
                prompt,
                system=self.system_prompt,
                max_tokens=settings.AGENT_MAX_TOKENS,
                temperature=settings.AGENT_TEMPERATURE,
            ):
                yield token
        elif isinstance(provider, OpenWebUIProvider):
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ]
            async for token in provider.chat_stream(
                messages,
                max_tokens=settings.AGENT_MAX_TOKENS,
                temperature=settings.AGENT_TEMPERATURE,
            ):
                yield token

    async def _debate_assist(self, context: AgentContext) -> AgentResponse:
        """Multi-perspective analysis using diverse models on Wile."""
        import asyncio

        start = time.monotonic()
        prompt = self._build_prompt(context)

        # Route each perspective to a different heavy model
        roles = {
            TaskType.DEBATE_PLANNER: (
                "Planner",
                "You are the Planner for a threat hunting investigation.\n"
                "Provide a structured investigation strategy. Reference SANS methodologies.\n"
                "Focus on: investigation steps, data sources to examine, MITRE ATT&CK mapping.\n"
                "Be specific to the data context provided.\n\n",
            ),
            TaskType.DEBATE_CRITIC: (
                "Critic",
                "You are the Critic for a threat hunting investigation.\n"
                "Identify risks, false positive scenarios, missing evidence, and assumptions.\n"
                "Reference SANS training on common analyst mistakes.\n"
                "Challenge the obvious interpretation.\n\n",
            ),
            TaskType.DEBATE_PRAGMATIST: (
                "Pragmatist",
                "You are the Pragmatist for a threat hunting investigation.\n"
                "Suggest the most actionable, efficient next steps.\n"
                "Reference SANS incident response playbooks.\n"
                "Focus on: quick wins, triage priorities, what to escalate.\n\n",
            ),
        }

        async def _call_perspective(task_type: TaskType, role_name: str, prefix: str):
            decision = self.router.route(task_type)
            provider = self.router.get_provider(decision)
            full_prompt = prefix + prompt

            # Both provider types expose the same generate() signature.
            result = await provider.generate(
                full_prompt,
                system=f"You are the {role_name}. Provide analysis only. No execution.",
                max_tokens=settings.AGENT_MAX_TOKENS,
                temperature=0.4,
            )

            return Perspective(
                role=role_name,
                content=result.get("response", ""),
                model_used=decision.model,
                node_used=decision.node.value,
                latency_ms=result.get("_latency_ms", 0),
            )

        # Run perspectives in parallel
        perspective_tasks = [
            _call_perspective(tt, name, prefix)
            for tt, (name, prefix) in roles.items()
        ]
        perspectives = await asyncio.gather(*perspective_tasks)

        # Judge merges the perspectives
        judge_prompt = (
            "You are the Judge. Merge these three threat hunting perspectives into "
            "ONE final advisory answer.\n\n"
            "Rules:\n"
            "- Advisory only — no execution\n"
            "- Clearly list risks and assumptions\n"
            "- Highlight where perspectives agree and disagree\n"
            "- Provide a unified recommendation\n"
            "- Reference SANS methodologies where relevant\n\n"
        )
        for p in perspectives:
            judge_prompt += f"=== {p.role} (via {p.model_used}) ===\n{p.content}\n\n"

        judge_prompt += (
            f"\nOriginal analyst query:\n{context.query}\n\n"
            "Respond with the merged analysis in this JSON format:\n"
            '{"guidance": "...", "confidence": 0.85, "suggested_pivots": [...], '
            '"suggested_filters": [...], "caveats": "...", "reasoning": "...", '
            '"sans_references": [...]}'
        )

        judge_decision = self.router.route(TaskType.DEBATE_JUDGE)
        judge_provider = self.router.get_provider(judge_decision)
        judge_result = await judge_provider.generate(
            judge_prompt,
            system="You are the Judge. Merge perspectives into a final advisory answer. Respond with JSON only.",
            max_tokens=settings.AGENT_MAX_TOKENS,
            temperature=0.2,
        )

        raw_text = judge_result.get("response", "")
        response = self._parse_response(raw_text, context)
        response.model_used = judge_decision.model
        response.node_used = judge_decision.node.value
        response.latency_ms = int((time.monotonic() - start) * 1000)
        response.perspectives = list(perspectives)

        return response

    def _build_prompt(self, context: AgentContext) -> str:
        """Build the prompt with all available context."""
        parts = [f"Analyst query: {context.query}"]

        if context.dataset_name:
            parts.append(f"Dataset: {context.dataset_name}")
        if context.artifact_type:
            parts.append(f"Artifact type: {context.artifact_type}")
        if context.host_identifier:
            parts.append(f"Host: {context.host_identifier}")
        if context.data_summary:
            parts.append(f"Data summary: {context.data_summary}")
        if context.active_hypotheses:
            parts.append(f"Active hypotheses: {'; '.join(context.active_hypotheses)}")
        if context.annotations_summary:
            parts.append(f"Analyst annotations: {context.annotations_summary}")
        if context.enrichment_summary:
            parts.append(f"Enrichment data: {context.enrichment_summary}")
        if context.conversation_history:
            parts.append("\nRecent conversation:")
            for msg in context.conversation_history[-settings.AGENT_HISTORY_LENGTH:]:
                parts.append(f" {msg.get('role', 'unknown')}: {msg.get('content', '')[:500]}")

        return "\n".join(parts)

    def _parse_response(self, raw: str, context: AgentContext) -> AgentResponse:
        """Parse LLM output into structured AgentResponse.

        Tries JSON extraction first, falls back to raw text with defaults.
        """
        parsed = self._try_parse_json(raw)
        if parsed:
            return AgentResponse(
                guidance=parsed.get("guidance", raw),
                confidence=min(max(float(parsed.get("confidence", 0.7)), 0.0), 1.0),
                suggested_pivots=parsed.get("suggested_pivots", [])[:6],
                suggested_filters=parsed.get("suggested_filters", [])[:6],
                caveats=parsed.get("caveats"),
                reasoning=parsed.get("reasoning"),
                sans_references=parsed.get("sans_references", []),
            )

        # Fallback: use raw text as guidance
        return AgentResponse(
            guidance=raw.strip() or "No guidance generated. Please try rephrasing your question.",
            confidence=0.5,
            suggested_pivots=[],
            suggested_filters=[],
            caveats="Response was not in structured format. Pivots and filters may be embedded in the guidance text.",
            reasoning=None,
            sans_references=[],
        )

    def _try_parse_json(self, text: str) -> dict | None:
        """Try to extract JSON from LLM output."""
        # Direct parse
        try:
            return json.loads(text.strip())
        except json.JSONDecodeError:
            pass

        # Extract from code fences
        patterns = [
            r"```json\s*(.*?)\s*```",
            r"```\s*(.*?)\s*```",
            r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}",
        ]
        for pattern in patterns:
            match = re.search(pattern, text, re.DOTALL)
            if match:
                try:
                    return json.loads(match.group(1) if match.lastindex else match.group(0))
                except (json.JSONDecodeError, IndexError):
                    continue

        return None
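A minimal usage sketch for the agent above. Assumptions: AgentContext accepts these keyword fields (they are the ones referenced in _build_prompt and assist, with defaults, since agent_assist_stream constructs it with a subset), and the LLM nodes are reachable.

import asyncio

from app.agents.core_v2 import ThreatHuntAgent, AgentContext


async def main() -> None:
    agent = ThreatHuntAgent()
    ctx = AgentContext(
        query="Explain the spike in outbound DNS from HOST-07",
        dataset_name="velociraptor_dns.csv",
        mode="quick",
    )
    resp = await agent.assist(ctx)
    print(resp.model_used, resp.node_used, resp.guidance)


asyncio.run(main())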
backend/app/agents/providers_v2.py (new file, 362 lines)
@@ -0,0 +1,362 @@
"""LLM providers — real implementations for Ollama nodes and Open WebUI cluster.

Three providers:
- OllamaProvider: Direct calls to Ollama on Wile/Roadrunner via Tailscale
- OpenWebUIProvider: Calls to the Open WebUI cluster (OpenAI-compatible)
- EmbeddingProvider: Embedding generation via Ollama /api/embeddings
"""

import asyncio
import json
import logging
import time
from typing import AsyncIterator

import httpx

from app.config import settings
from .registry import ModelEntry, Node

logger = logging.getLogger(__name__)

# Shared HTTP client with reasonable timeouts
_client: httpx.AsyncClient | None = None


def _get_client() -> httpx.AsyncClient:
    global _client
    if _client is None or _client.is_closed:
        _client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=10, read=300, write=30, pool=10),
            limits=httpx.Limits(max_connections=20, max_keepalive_connections=10),
        )
    return _client


async def cleanup_client():
    global _client
    if _client and not _client.is_closed:
        await _client.aclose()
        _client = None


def _ollama_url(node: Node) -> str:
    """Get the Ollama base URL for a node."""
    if node == Node.WILE:
        return settings.wile_url
    elif node == Node.ROADRUNNER:
        return settings.roadrunner_url
    else:
        raise ValueError(f"No direct Ollama URL for node: {node}")


# ── Ollama Provider ──────────────────────────────────────────────────


class OllamaProvider:
    """Direct Ollama API calls to Wile or Roadrunner."""

    def __init__(self, model: str, node: Node):
        self.model = model
        self.node = node
        self.base_url = _ollama_url(node)

    async def generate(
        self,
        prompt: str,
        system: str = "",
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Generate a completion. Returns dict with 'response', 'model', 'total_duration', etc."""
        client = _get_client()
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        }
        if system:
            payload["system"] = system

        start = time.monotonic()
        try:
            resp = await client.post(
                f"{self.base_url}/api/generate",
                json=payload,
            )
            resp.raise_for_status()
            data = resp.json()
            latency_ms = int((time.monotonic() - start) * 1000)
            data["_latency_ms"] = latency_ms
            data["_node"] = self.node.value
            logger.info(
                f"Ollama [{self.node.value}] {self.model}: "
                f"{latency_ms}ms, {data.get('eval_count', '?')} tokens"
            )
            return data
        except httpx.HTTPStatusError as e:
            logger.error(f"Ollama HTTP error [{self.node.value}]: {e.response.status_code} {e.response.text[:200]}")
            raise
        except httpx.ConnectError as e:
            logger.error(f"Cannot reach Ollama on {self.node.value} ({self.base_url}): {e}")
            raise

    async def chat(
        self,
        messages: list[dict],
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Chat completion via Ollama /api/chat."""
        client = _get_client()
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": False,
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        }

        start = time.monotonic()
        resp = await client.post(f"{self.base_url}/api/chat", json=payload)
        resp.raise_for_status()
        data = resp.json()
        data["_latency_ms"] = int((time.monotonic() - start) * 1000)
        data["_node"] = self.node.value
        return data

    async def generate_stream(
        self,
        prompt: str,
        system: str = "",
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> AsyncIterator[str]:
        """Stream tokens from Ollama."""
        client = _get_client()
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": True,
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        }
        if system:
            payload["system"] = system

        async with client.stream(
            "POST", f"{self.base_url}/api/generate", json=payload
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if line.strip():
                    try:
                        chunk = json.loads(line)
                        token = chunk.get("response", "")
                        if token:
                            yield token
                        if chunk.get("done"):
                            break
                    except json.JSONDecodeError:
                        continue

    async def is_available(self) -> bool:
        """Ping the Ollama node."""
        try:
            client = _get_client()
            resp = await client.get(f"{self.base_url}/api/tags", timeout=5)
            return resp.status_code == 200
        except Exception:
            return False

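# Illustrative sketch (hypothetical helper, not used elsewhere): drain one
# streamed completion from a node. Assumes llama3.1:latest is pulled on Wile,
# as listed in WILE_MODELS in registry.py.
async def example_stream_once() -> str:
    provider = OllamaProvider("llama3.1:latest", Node.WILE)
    chunks: list[str] = []
    async for token in provider.generate_stream("Summarize MITRE T1059 in one line."):
        chunks.append(token)
    return "".join(chunks)
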
# ── Open WebUI Provider (OpenAI-compatible) ───────────────────────────


class OpenWebUIProvider:
    """Calls to Open WebUI cluster at ai.guapo613.beer.

    Uses the OpenAI-compatible /v1/chat/completions endpoint.
    """

    def __init__(self, model: str = ""):
        self.model = model or settings.DEFAULT_FAST_MODEL
        self.base_url = settings.OPENWEBUI_URL.rstrip("/")
        self.api_key = settings.OPENWEBUI_API_KEY

    def _headers(self) -> dict:
        h = {"Content-Type": "application/json"}
        if self.api_key:
            h["Authorization"] = f"Bearer {self.api_key}"
        return h

    async def chat(
        self,
        messages: list[dict],
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Chat completion via OpenAI-compatible endpoint."""
        client = _get_client()
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": False,
        }

        start = time.monotonic()
        resp = await client.post(
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            headers=self._headers(),
        )
        resp.raise_for_status()
        data = resp.json()
        latency_ms = int((time.monotonic() - start) * 1000)

        # Normalize to our format
        content = ""
        if data.get("choices"):
            content = data["choices"][0].get("message", {}).get("content", "")

        result = {
            "response": content,
            "model": data.get("model", self.model),
            "_latency_ms": latency_ms,
            "_node": "cluster",
            "_usage": data.get("usage", {}),
        }
        logger.info(
            f"OpenWebUI cluster {self.model}: {latency_ms}ms"
        )
        return result

    async def generate(
        self,
        prompt: str,
        system: str = "",
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Convert prompt-style call to chat format."""
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        return await self.chat(messages, max_tokens, temperature)

    async def chat_stream(
        self,
        messages: list[dict],
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> AsyncIterator[str]:
        """Stream tokens from OpenWebUI."""
        client = _get_client()
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": True,
        }

        async with client.stream(
            "POST",
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            headers=self._headers(),
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    data_str = line[6:].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data_str)
                        delta = chunk.get("choices", [{}])[0].get("delta", {})
                        token = delta.get("content", "")
                        if token:
                            yield token
                    except json.JSONDecodeError:
                        continue

    async def is_available(self) -> bool:
        """Check if Open WebUI is reachable."""
        try:
            client = _get_client()
            resp = await client.get(
                f"{self.base_url}/v1/models",
                headers=self._headers(),
                timeout=5,
            )
            return resp.status_code == 200
        except Exception:
            return False

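# Illustrative sketch (hypothetical helper): one-shot call through the cluster
# with the default fast model from settings.
async def example_cluster_ping() -> str:
    cluster = OpenWebUIProvider()
    out = await cluster.chat([{"role": "user", "content": "ping"}], max_tokens=8)
    return out["response"]
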
# ── Embedding Provider ────────────────────────────────────────────────


class EmbeddingProvider:
    """Generate embeddings via Ollama /api/embeddings."""

    def __init__(self, model: str = "", node: Node = Node.ROADRUNNER):
        self.model = model or settings.DEFAULT_EMBEDDING_MODEL
        self.node = node
        self.base_url = _ollama_url(node)

    async def embed(self, text: str) -> list[float]:
        """Get embedding vector for a single text."""
        client = _get_client()
        resp = await client.post(
            f"{self.base_url}/api/embeddings",
            json={"model": self.model, "prompt": text},
        )
        resp.raise_for_status()
        data = resp.json()
        return data.get("embedding", [])

    async def embed_batch(self, texts: list[str], concurrency: int = 5) -> list[list[float]]:
        """Embed multiple texts with controlled concurrency."""
        sem = asyncio.Semaphore(concurrency)

        async def _embed_one(t: str) -> list[float]:
            async with sem:
                return await self.embed(t)

        return await asyncio.gather(*[_embed_one(t) for t in texts])

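# Illustrative sketch (hypothetical helper): embed a couple of IOC strings with
# the defaults above (bge-m3 on Roadrunner) and report the vector width.
async def example_embed_iocs() -> int:
    ep = EmbeddingProvider()
    vecs = await ep.embed_batch(["evil.example.com", "Invoke-Mimikatz"])
    return len(vecs[0])
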
# ── Health check for all nodes ────────────────────────────────────────


async def check_all_nodes() -> dict:
    """Check availability of all LLM nodes."""
    wile = OllamaProvider("", Node.WILE)
    roadrunner = OllamaProvider("", Node.ROADRUNNER)
    cluster = OpenWebUIProvider()

    wile_ok, rr_ok, cl_ok = await asyncio.gather(
        wile.is_available(),
        roadrunner.is_available(),
        cluster.is_available(),
        return_exceptions=True,
    )

    return {
        "wile": {"available": wile_ok is True, "url": settings.wile_url},
        "roadrunner": {"available": rr_ok is True, "url": settings.roadrunner_url},
        "cluster": {"available": cl_ok is True, "url": settings.OPENWEBUI_URL},
    }
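A quick way to exercise the health check above from a REPL (a sketch; assumes the backend package is importable and the Tailscale routes are up):

import asyncio

from app.agents.providers_v2 import check_all_nodes

status = asyncio.run(check_all_nodes())
for name, info in status.items():
    print(f"{name}: {'up' if info['available'] else 'down'} ({info['url']})")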
backend/app/agents/registry.py (new file, 161 lines)
@@ -0,0 +1,161 @@
"""Model registry — inventory of all Ollama models across Wile and Roadrunner.

Each model is tagged with capabilities (chat, code, vision, embedding) and
performance tier (fast, medium, heavy) for the TaskRouter.
"""

from dataclasses import dataclass, field
from enum import Enum


class Capability(str, Enum):
    CHAT = "chat"
    CODE = "code"
    VISION = "vision"
    EMBEDDING = "embedding"


class Tier(str, Enum):
    FAST = "fast"      # < 15B params — quick responses
    MEDIUM = "medium"  # 15–40B params — balanced
    HEAVY = "heavy"    # 40B+ params — deep analysis


class Node(str, Enum):
    WILE = "wile"
    ROADRUNNER = "roadrunner"
    CLUSTER = "cluster"  # Open WebUI balances across both


@dataclass
class ModelEntry:
    name: str
    node: Node
    capabilities: list[Capability]
    tier: Tier
    param_size: str = ""  # e.g. "7b", "70b"
    notes: str = ""


# ── Roadrunner (100.110.190.11) ──────────────────────────────────────

ROADRUNNER_MODELS: list[ModelEntry] = [
    # General / chat
    ModelEntry("llama3.1:latest", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "8b"),
    ModelEntry("qwen2.5:14b-instruct", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "14b"),
    ModelEntry("mistral:7b-instruct", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "7b"),
    ModelEntry("mistral:7b", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "7b"),
    ModelEntry("qwen2.5:7b", Node.ROADRUNNER, [Capability.CHAT], Tier.FAST, "7b"),
    ModelEntry("phi3:medium", Node.ROADRUNNER, [Capability.CHAT], Tier.MEDIUM, "14b"),
    # Code
    ModelEntry("qwen2.5-coder:7b", Node.ROADRUNNER, [Capability.CODE], Tier.FAST, "7b"),
    ModelEntry("qwen2.5-coder:latest", Node.ROADRUNNER, [Capability.CODE], Tier.FAST, "7b"),
    ModelEntry("codestral:latest", Node.ROADRUNNER, [Capability.CODE], Tier.MEDIUM, "22b"),
    ModelEntry("codellama:13b", Node.ROADRUNNER, [Capability.CODE], Tier.FAST, "13b"),
    # Vision
    ModelEntry("llama3.2-vision:11b", Node.ROADRUNNER, [Capability.VISION], Tier.FAST, "11b"),
    ModelEntry("minicpm-v:latest", Node.ROADRUNNER, [Capability.VISION], Tier.FAST, "8b"),
    ModelEntry("llava:13b", Node.ROADRUNNER, [Capability.VISION], Tier.FAST, "13b"),
    # Embeddings
    ModelEntry("bge-m3:latest", Node.ROADRUNNER, [Capability.EMBEDDING], Tier.FAST, "0.6b"),
    ModelEntry("nomic-embed-text:latest", Node.ROADRUNNER, [Capability.EMBEDDING], Tier.FAST, "0.1b"),
    # Heavy
    ModelEntry("llama3.1:70b-instruct-q4_K_M", Node.ROADRUNNER, [Capability.CHAT], Tier.HEAVY, "70b"),
]

# ── Wile (100.110.190.12) ────────────────────────────────────────────

WILE_MODELS: list[ModelEntry] = [
    # General / chat
    ModelEntry("llama3.1:latest", Node.WILE, [Capability.CHAT], Tier.FAST, "8b"),
    ModelEntry("llama3:latest", Node.WILE, [Capability.CHAT], Tier.FAST, "8b"),
    ModelEntry("gemma2:27b", Node.WILE, [Capability.CHAT], Tier.MEDIUM, "27b"),
    # Code
    ModelEntry("qwen2.5-coder:7b", Node.WILE, [Capability.CODE], Tier.FAST, "7b"),
    ModelEntry("qwen2.5-coder:latest", Node.WILE, [Capability.CODE], Tier.FAST, "7b"),
    ModelEntry("qwen2.5-coder:32b", Node.WILE, [Capability.CODE], Tier.MEDIUM, "32b"),
    ModelEntry("deepseek-coder:33b", Node.WILE, [Capability.CODE], Tier.MEDIUM, "33b"),
    ModelEntry("codestral:latest", Node.WILE, [Capability.CODE], Tier.MEDIUM, "22b"),
    # Vision
    ModelEntry("llava:13b", Node.WILE, [Capability.VISION], Tier.FAST, "13b"),
    # Embeddings
    ModelEntry("bge-m3:latest", Node.WILE, [Capability.EMBEDDING], Tier.FAST, "0.6b"),
    # Heavy
    ModelEntry("llama3.1:70b", Node.WILE, [Capability.CHAT], Tier.HEAVY, "70b"),
    ModelEntry("llama3.1:70b-instruct-q4_K_M", Node.WILE, [Capability.CHAT], Tier.HEAVY, "70b"),
    ModelEntry("llama3.1:70b-instruct-q5_K_M", Node.WILE, [Capability.CHAT], Tier.HEAVY, "70b"),
    ModelEntry("mixtral:8x22b-instruct", Node.WILE, [Capability.CHAT], Tier.HEAVY, "141b"),
    ModelEntry("qwen2:72b-instruct", Node.WILE, [Capability.CHAT], Tier.HEAVY, "72b"),
]

ALL_MODELS = ROADRUNNER_MODELS + WILE_MODELS

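# Illustrative sketch (hypothetical helper): per-node inventory counts, handy
# when sanity-checking this table against `ollama list` on each box.
def example_inventory() -> dict[str, int]:
    counts: dict[str, int] = {}
    for m in ALL_MODELS:
        counts[m.node.value] = counts.get(m.node.value, 0) + 1
    return counts  # {"roadrunner": 16, "wile": 15} for the tables above
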
class ModelRegistry:
    """Registry of all available models and their capabilities."""

    def __init__(self, models: list[ModelEntry] | None = None):
        self.models = models or ALL_MODELS
        self._by_name: dict[str, list[ModelEntry]] = {}
        self._by_capability: dict[Capability, list[ModelEntry]] = {}
        self._by_node: dict[Node, list[ModelEntry]] = {}
        self._index()

    def _index(self):
        for m in self.models:
            self._by_name.setdefault(m.name, []).append(m)
            for cap in m.capabilities:
                self._by_capability.setdefault(cap, []).append(m)
            self._by_node.setdefault(m.node, []).append(m)

    def find(
        self,
        capability: Capability | None = None,
        tier: Tier | None = None,
        node: Node | None = None,
    ) -> list[ModelEntry]:
        """Find models matching all given criteria."""
        results = list(self.models)
        if capability:
            results = [m for m in results if capability in m.capabilities]
        if tier:
            results = [m for m in results if m.tier == tier]
        if node:
            results = [m for m in results if m.node == node]
        return results

    def get_best(
        self,
        capability: Capability,
        prefer_tier: Tier | None = None,
        prefer_node: Node | None = None,
    ) -> ModelEntry | None:
        """Get the best model for a capability, with optional preference."""
        candidates = self.find(capability=capability, tier=prefer_tier, node=prefer_node)
        if not candidates:
            candidates = self.find(capability=capability, tier=prefer_tier)
        if not candidates:
            candidates = self.find(capability=capability)
        return candidates[0] if candidates else None

    def list_nodes(self) -> list[Node]:
        return list(self._by_node.keys())

    def list_models_on_node(self, node: Node) -> list[ModelEntry]:
        return self._by_node.get(node, [])

    def to_dict(self) -> list[dict]:
        return [
            {
                "name": m.name,
                "node": m.node.value,
                "capabilities": [c.value for c in m.capabilities],
                "tier": m.tier.value,
                "param_size": m.param_size,
            }
            for m in self.models
        ]


# Singleton
registry = ModelRegistry()
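A lookup sketch against the singleton above; the expected result follows from list order in the model tables:

from app.agents.registry import Capability, Node, Tier, registry

best = registry.get_best(Capability.CODE, prefer_tier=Tier.MEDIUM, prefer_node=Node.WILE)
print(best.name if best else "no match")  # → qwen2.5-coder:32b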
backend/app/agents/router.py (new file, 183 lines)
@@ -0,0 +1,183 @@
"""Task router — auto-selects the right model + node for each task type.

Routes based on task characteristics:
- Quick chat → fast models via cluster
- Deep analysis → 70B+ models on Wile
- Code/script analysis → code models (32b on Wile, 7b for quick)
- Vision/image → vision models on Roadrunner
- Embedding → embedding models on either node
"""

import logging
from dataclasses import dataclass
from enum import Enum

from app.config import settings
from .registry import Capability, Tier, Node, ModelEntry, registry
from .providers_v2 import OllamaProvider, OpenWebUIProvider, EmbeddingProvider

logger = logging.getLogger(__name__)


class TaskType(str, Enum):
    QUICK_CHAT = "quick_chat"
    DEEP_ANALYSIS = "deep_analysis"
    CODE_ANALYSIS = "code_analysis"
    VISION = "vision"
    EMBEDDING = "embedding"
    DEBATE_PLANNER = "debate_planner"
    DEBATE_CRITIC = "debate_critic"
    DEBATE_PRAGMATIST = "debate_pragmatist"
    DEBATE_JUDGE = "debate_judge"


@dataclass
class RoutingDecision:
    """Result of the routing decision."""

    model: str
    node: Node
    task_type: TaskType
    provider_type: str  # "ollama" or "openwebui"
    reason: str


class TaskRouter:
    """Routes tasks to the appropriate model and node."""

    # Default routing rules: task_type → (capability, preferred_tier, preferred_node)
    ROUTING_RULES: dict[TaskType, tuple[Capability, Tier | None, Node | None]] = {
        TaskType.QUICK_CHAT: (Capability.CHAT, Tier.FAST, None),
        TaskType.DEEP_ANALYSIS: (Capability.CHAT, Tier.HEAVY, Node.WILE),
        TaskType.CODE_ANALYSIS: (Capability.CODE, Tier.MEDIUM, Node.WILE),
        TaskType.VISION: (Capability.VISION, None, Node.ROADRUNNER),
        TaskType.EMBEDDING: (Capability.EMBEDDING, Tier.FAST, None),
        TaskType.DEBATE_PLANNER: (Capability.CHAT, Tier.HEAVY, Node.WILE),
        TaskType.DEBATE_CRITIC: (Capability.CHAT, Tier.HEAVY, Node.WILE),
        TaskType.DEBATE_PRAGMATIST: (Capability.CHAT, Tier.HEAVY, Node.WILE),
        TaskType.DEBATE_JUDGE: (Capability.CHAT, Tier.MEDIUM, Node.WILE),
    }

    # Specific model overrides for debate roles (use diverse models for diversity of thought)
    DEBATE_MODEL_OVERRIDES: dict[TaskType, str] = {
        TaskType.DEBATE_PLANNER: "llama3.1:70b-instruct-q4_K_M",
        TaskType.DEBATE_CRITIC: "qwen2:72b-instruct",
        TaskType.DEBATE_PRAGMATIST: "mixtral:8x22b-instruct",
        TaskType.DEBATE_JUDGE: "gemma2:27b",
    }

    def __init__(self):
        self.registry = registry

    def route(self, task_type: TaskType, model_override: str | None = None) -> RoutingDecision:
        """Decide which model and node to use for a task."""

        # Explicit model override
        if model_override:
            entries = self.registry.find()
            for entry in entries:
                if entry.name == model_override:
                    return RoutingDecision(
                        model=model_override,
                        node=entry.node,
                        task_type=task_type,
                        provider_type="ollama",
                        reason=f"Explicit model override: {model_override}",
                    )
            # Model not in registry — try via cluster
            return RoutingDecision(
                model=model_override,
                node=Node.CLUSTER,
                task_type=task_type,
                provider_type="openwebui",
                reason=f"Override model {model_override} not in registry, routing to cluster",
            )

        # Debate model overrides
        if task_type in self.DEBATE_MODEL_OVERRIDES:
            model_name = self.DEBATE_MODEL_OVERRIDES[task_type]
            entries = self.registry.find()
            for entry in entries:
                if entry.name == model_name:
                    return RoutingDecision(
                        model=model_name,
                        node=entry.node,
                        task_type=task_type,
                        provider_type="ollama",
                        reason=f"Debate role {task_type.value} → {model_name} on {entry.node.value}",
                    )

        # Standard routing
        cap, tier, node = self.ROUTING_RULES.get(
            task_type,
            (Capability.CHAT, Tier.FAST, None),
        )

        entry = self.registry.get_best(cap, prefer_tier=tier, prefer_node=node)
        if entry:
            return RoutingDecision(
                model=entry.name,
                node=entry.node,
                task_type=task_type,
                provider_type="ollama",
                reason=f"Auto-routed {task_type.value}: {cap.value}/{tier.value if tier else 'any'} → {entry.name} on {entry.node.value}",
            )

        # Fallback to cluster
        default_model = settings.DEFAULT_FAST_MODEL
        return RoutingDecision(
            model=default_model,
            node=Node.CLUSTER,
            task_type=task_type,
            provider_type="openwebui",
            reason=f"No registry match, falling back to cluster with {default_model}",
        )

    def get_provider(self, decision: RoutingDecision):
        """Create the appropriate provider for a routing decision."""
        if decision.provider_type == "openwebui":
            return OpenWebUIProvider(model=decision.model)
        else:
            return OllamaProvider(model=decision.model, node=decision.node)

    def get_embedding_provider(self, model: str | None = None, node: Node | None = None) -> EmbeddingProvider:
        """Get an embedding provider."""
        return EmbeddingProvider(
            model=model or settings.DEFAULT_EMBEDDING_MODEL,
            node=node or Node.ROADRUNNER,
        )

    def classify_task(self, query: str, has_image: bool = False) -> TaskType:
        """Heuristic classification of query into task type.

        In practice this could be enhanced by a classifier model, but
        keyword heuristics work well for routing.
        """
        if has_image:
            return TaskType.VISION

        q = query.lower()

        # Code/script indicators
        code_indicators = [
            "deobfuscate", "decode", "powershell", "script", "base64",
            "command line", "cmdline", "commandline", "obfuscated",
            "malware", "shellcode", "vbs", "vbscript", "batch",
            "python script", "code review", "reverse engineer",
        ]
        if any(ind in q for ind in code_indicators):
            return TaskType.CODE_ANALYSIS

        # Deep analysis indicators
        deep_indicators = [
            "deep analysis", "detailed", "comprehensive", "thorough",
            "investigate", "root cause", "advanced", "explain in detail",
            "full analysis", "forensic",
        ]
        if any(ind in q for ind in deep_indicators):
            return TaskType.DEEP_ANALYSIS

        return TaskType.QUICK_CHAT


# Singleton
task_router = TaskRouter()
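An end-to-end routing sketch; the expected output assumes the registry shipped above:

from app.agents.router import task_router

tt = task_router.classify_task("deobfuscate this powershell command line")
decision = task_router.route(tt)
print(tt.value, "→", decision.model, "on", decision.node.value)
# code_analysis → qwen2.5-coder:32b on wile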
backend/app/api/routes/agent_v2.py (new file, 265 lines)
@@ -0,0 +1,265 @@
"""API routes for analyst-assist agent — v2.

Supports quick, deep, and debate modes with streaming.
Conversations are persisted to the database.
"""

import json
import logging

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.db import get_db
from app.db.models import Conversation, Message
from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
from app.agents.providers_v2 import check_all_nodes
from app.agents.registry import registry
from app.services.sans_rag import sans_rag

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/agent", tags=["agent"])

# Global agent instance
_agent: ThreatHuntAgent | None = None


def get_agent() -> ThreatHuntAgent:
    global _agent
    if _agent is None:
        _agent = ThreatHuntAgent()
    return _agent


# ── Request / Response models ─────────────────────────────────────────


class AssistRequest(BaseModel):
    query: str = Field(..., max_length=4000, description="Analyst question")
    dataset_name: str | None = None
    artifact_type: str | None = None
    host_identifier: str | None = None
    data_summary: str | None = None
    conversation_history: list[dict] | None = None
    active_hypotheses: list[str] | None = None
    annotations_summary: str | None = None
    enrichment_summary: str | None = None
    mode: str = Field(default="quick", description="quick | deep | debate")
    model_override: str | None = None
    conversation_id: str | None = Field(None, description="Persist messages to this conversation")
    hunt_id: str | None = None


class AssistResponseModel(BaseModel):
    guidance: str
    confidence: float
    suggested_pivots: list[str]
    suggested_filters: list[str]
    caveats: str | None = None
    reasoning: str | None = None
    sans_references: list[str] = []
    model_used: str = ""
    node_used: str = ""
    latency_ms: int = 0
    perspectives: list[dict] | None = None
    conversation_id: str | None = None


# ── Routes ────────────────────────────────────────────────────────────


@router.post(
    "/assist",
    response_model=AssistResponseModel,
    summary="Get analyst-assist guidance",
    description="Request guidance with auto-routed model selection. "
    "Supports quick (fast), deep (70B), and debate (multi-model) modes.",
)
async def agent_assist(
    request: AssistRequest,
    db: AsyncSession = Depends(get_db),
) -> AssistResponseModel:
    try:
        agent = get_agent()
        context = AgentContext(
            query=request.query,
            dataset_name=request.dataset_name,
            artifact_type=request.artifact_type,
            host_identifier=request.host_identifier,
            data_summary=request.data_summary,
            conversation_history=request.conversation_history or [],
            active_hypotheses=request.active_hypotheses or [],
            annotations_summary=request.annotations_summary,
            enrichment_summary=request.enrichment_summary,
            mode=request.mode,
            model_override=request.model_override,
        )

        response = await agent.assist(context)

        # Persist conversation
        conv_id = request.conversation_id
        if conv_id or request.hunt_id:
            conv_id = await _persist_conversation(
                db, conv_id, request, response
            )

        return AssistResponseModel(
            guidance=response.guidance,
            confidence=response.confidence,
            suggested_pivots=response.suggested_pivots,
            suggested_filters=response.suggested_filters,
            caveats=response.caveats,
            reasoning=response.reasoning,
            sans_references=response.sans_references,
            model_used=response.model_used,
            node_used=response.node_used,
            latency_ms=response.latency_ms,
            perspectives=[
                {
                    "role": p.role,
                    "content": p.content,
                    "model_used": p.model_used,
                    "node_used": p.node_used,
                    "latency_ms": p.latency_ms,
                }
                for p in response.perspectives
            ] if response.perspectives else None,
            conversation_id=conv_id,
        )

    except Exception as e:
        logger.exception(f"Agent error: {e}")
        raise HTTPException(status_code=500, detail=f"Agent error: {str(e)}")


@router.post(
    "/assist/stream",
    summary="Stream agent response",
    description="Stream tokens via SSE for real-time display.",
)
async def agent_assist_stream(request: AssistRequest):
    agent = get_agent()
    context = AgentContext(
        query=request.query,
        dataset_name=request.dataset_name,
        artifact_type=request.artifact_type,
        host_identifier=request.host_identifier,
        data_summary=request.data_summary,
        conversation_history=request.conversation_history or [],
        mode="quick",  # streaming only supports quick mode
    )

    async def _stream():
        async for token in agent.assist_stream(context):
            yield f"data: {json.dumps({'token': token})}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        _stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@router.get(
    "/health",
    summary="Check agent and node health",
    description="Returns availability of all LLM nodes and the cluster.",
)
async def agent_health() -> dict:
    nodes = await check_all_nodes()
    rag_health = await sans_rag.health_check()
    return {
        "status": "healthy",
        "nodes": nodes,
        "rag": rag_health,
        "default_models": {
            "fast": settings.DEFAULT_FAST_MODEL,
            "heavy": settings.DEFAULT_HEAVY_MODEL,
            "code": settings.DEFAULT_CODE_MODEL,
            "vision": settings.DEFAULT_VISION_MODEL,
            "embedding": settings.DEFAULT_EMBEDDING_MODEL,
        },
        "config": {
            "max_tokens": settings.AGENT_MAX_TOKENS,
            "temperature": settings.AGENT_TEMPERATURE,
        },
    }


@router.get(
    "/models",
    summary="List all available models",
    description="Returns the full model registry with capabilities and node assignments.",
)
async def list_models():
    return {
        "models": registry.to_dict(),
        "total": len(registry.models),
    }


# ── Conversation persistence ──────────────────────────────────────────


async def _persist_conversation(
    db: AsyncSession,
    conversation_id: str | None,
    request: AssistRequest,
    response: AgentResponse,
) -> str:
    """Save user message and agent response to the database."""
    if conversation_id:
        # Find existing conversation
        from sqlalchemy import select
        result = await db.execute(
            select(Conversation).where(Conversation.id == conversation_id)
        )
        conv = result.scalar_one_or_none()
        if not conv:
            conv = Conversation(id=conversation_id, hunt_id=request.hunt_id)
            db.add(conv)
    else:
        conv = Conversation(
            title=request.query[:100],
            hunt_id=request.hunt_id,
        )
        db.add(conv)
    await db.flush()

    # User message
    user_msg = Message(
        conversation_id=conv.id,
        role="user",
        content=request.query,
    )
    db.add(user_msg)

    # Agent message
    agent_msg = Message(
        conversation_id=conv.id,
        role="agent",
        content=response.guidance,
        model_used=response.model_used,
        node_used=response.node_used,
        latency_ms=response.latency_ms,
        response_meta={
            "confidence": response.confidence,
            "pivots": response.suggested_pivots,
            "filters": response.suggested_filters,
            "sans_refs": response.sans_references,
        },
    )
    db.add(agent_msg)
    await db.flush()

    return conv.id
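A client-side sketch for the SSE endpoint above. The base URL is a deployment assumption; the wire format mirrors _stream above:

import asyncio
import json

import httpx


async def stream_assist(query: str) -> str:
    tokens: list[str] = []
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            "http://localhost:8000/api/agent/assist/stream",
            json={"query": query},
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    payload = line[6:]
                    if payload == "[DONE]":
                        break
                    tokens.append(json.loads(payload)["token"])
    return "".join(tokens)


print(asyncio.run(stream_assist("What does this prefetch entry tell me?")))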
backend/app/api/routes/annotations.py (new file, 311 lines)
@@ -0,0 +1,311 @@
"""API routes for annotations and hypotheses."""

import logging

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import get_db
from app.db.models import Annotation, Hypothesis

logger = logging.getLogger(__name__)

router = APIRouter(tags=["annotations"])


# ── Annotation models ─────────────────────────────────────────────────


class AnnotationCreate(BaseModel):
    row_id: int | None = None
    dataset_id: str | None = None
    text: str = Field(..., max_length=2000)
    severity: str = Field(default="info")  # info|low|medium|high|critical
    tag: str | None = None  # suspicious|benign|needs-review
    highlight_color: str | None = None


class AnnotationUpdate(BaseModel):
    text: str | None = None
    severity: str | None = None
    tag: str | None = None
    highlight_color: str | None = None


class AnnotationResponse(BaseModel):
    id: str
    row_id: int | None
    dataset_id: str | None
    author_id: str | None
    text: str
    severity: str
    tag: str | None
    highlight_color: str | None
    created_at: str
    updated_at: str


class AnnotationListResponse(BaseModel):
    annotations: list[AnnotationResponse]
    total: int


# ── Hypothesis models ─────────────────────────────────────────────────


class HypothesisCreate(BaseModel):
    hunt_id: str | None = None
    title: str = Field(..., max_length=256)
    description: str | None = None
    mitre_technique: str | None = None
    status: str = Field(default="draft")


class HypothesisUpdate(BaseModel):
    title: str | None = None
    description: str | None = None
    mitre_technique: str | None = None
    status: str | None = None  # draft|active|confirmed|rejected
    evidence_row_ids: list[int] | None = None
    evidence_notes: str | None = None


class HypothesisResponse(BaseModel):
    id: str
    hunt_id: str | None
    title: str
    description: str | None
    mitre_technique: str | None
    status: str
    evidence_row_ids: list | None
    evidence_notes: str | None
    created_at: str
    updated_at: str


class HypothesisListResponse(BaseModel):
    hypotheses: list[HypothesisResponse]
    total: int


# ── Annotation routes ─────────────────────────────────────────────────


ann_router = APIRouter(prefix="/api/annotations")


@ann_router.post("", response_model=AnnotationResponse, summary="Create annotation")
async def create_annotation(body: AnnotationCreate, db: AsyncSession = Depends(get_db)):
    ann = Annotation(
        row_id=body.row_id,
        dataset_id=body.dataset_id,
        text=body.text,
        severity=body.severity,
        tag=body.tag,
        highlight_color=body.highlight_color,
    )
    db.add(ann)
    await db.flush()
    return AnnotationResponse(
        id=ann.id, row_id=ann.row_id, dataset_id=ann.dataset_id,
        author_id=ann.author_id, text=ann.text, severity=ann.severity,
        tag=ann.tag, highlight_color=ann.highlight_color,
        created_at=ann.created_at.isoformat(), updated_at=ann.updated_at.isoformat(),
    )


@ann_router.get("", response_model=AnnotationListResponse, summary="List annotations")
async def list_annotations(
    dataset_id: str | None = Query(None),
    row_id: int | None = Query(None),
    tag: str | None = Query(None),
    severity: str | None = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    stmt = select(Annotation).order_by(Annotation.created_at.desc())
    if dataset_id:
        stmt = stmt.where(Annotation.dataset_id == dataset_id)
    if row_id:
        stmt = stmt.where(Annotation.row_id == row_id)
    if tag:
        stmt = stmt.where(Annotation.tag == tag)
    if severity:
        stmt = stmt.where(Annotation.severity == severity)
    stmt = stmt.limit(limit).offset(offset)
    result = await db.execute(stmt)
    annotations = result.scalars().all()

    # Apply the same filters to the count so `total` matches the listing
    count_stmt = select(func.count(Annotation.id))
    if dataset_id:
        count_stmt = count_stmt.where(Annotation.dataset_id == dataset_id)
    if row_id:
        count_stmt = count_stmt.where(Annotation.row_id == row_id)
    if tag:
        count_stmt = count_stmt.where(Annotation.tag == tag)
    if severity:
        count_stmt = count_stmt.where(Annotation.severity == severity)
    total = (await db.execute(count_stmt)).scalar_one()

    return AnnotationListResponse(
        annotations=[
            AnnotationResponse(
                id=a.id, row_id=a.row_id, dataset_id=a.dataset_id,
                author_id=a.author_id, text=a.text, severity=a.severity,
                tag=a.tag, highlight_color=a.highlight_color,
                created_at=a.created_at.isoformat(), updated_at=a.updated_at.isoformat(),
            )
            for a in annotations
        ],
        total=total,
    )


@ann_router.put("/{annotation_id}", response_model=AnnotationResponse, summary="Update annotation")
async def update_annotation(
    annotation_id: str, body: AnnotationUpdate, db: AsyncSession = Depends(get_db)
):
    result = await db.execute(select(Annotation).where(Annotation.id == annotation_id))
    ann = result.scalar_one_or_none()
    if not ann:
        raise HTTPException(status_code=404, detail="Annotation not found")
    if body.text is not None:
        ann.text = body.text
    if body.severity is not None:
        ann.severity = body.severity
    if body.tag is not None:
        ann.tag = body.tag
    if body.highlight_color is not None:
        ann.highlight_color = body.highlight_color
    await db.flush()
    return AnnotationResponse(
        id=ann.id, row_id=ann.row_id, dataset_id=ann.dataset_id,
        author_id=ann.author_id, text=ann.text, severity=ann.severity,
        tag=ann.tag, highlight_color=ann.highlight_color,
        created_at=ann.created_at.isoformat(), updated_at=ann.updated_at.isoformat(),
    )


@ann_router.delete("/{annotation_id}", summary="Delete annotation")
async def delete_annotation(annotation_id: str, db: AsyncSession = Depends(get_db)):
    result = await db.execute(select(Annotation).where(Annotation.id == annotation_id))
    ann = result.scalar_one_or_none()
    if not ann:
        raise HTTPException(status_code=404, detail="Annotation not found")
    await db.delete(ann)
    return {"message": "Annotation deleted", "id": annotation_id}


# ── Hypothesis routes ─────────────────────────────────────────────────


hyp_router = APIRouter(prefix="/api/hypotheses")


@hyp_router.post("", response_model=HypothesisResponse, summary="Create hypothesis")
async def create_hypothesis(body: HypothesisCreate, db: AsyncSession = Depends(get_db)):
    hyp = Hypothesis(
        hunt_id=body.hunt_id,
        title=body.title,
        description=body.description,
        mitre_technique=body.mitre_technique,
        status=body.status,
    )
    db.add(hyp)
    await db.flush()
    return HypothesisResponse(
        id=hyp.id, hunt_id=hyp.hunt_id, title=hyp.title,
        description=hyp.description, mitre_technique=hyp.mitre_technique,
        status=hyp.status, evidence_row_ids=hyp.evidence_row_ids,
        evidence_notes=hyp.evidence_notes,
        created_at=hyp.created_at.isoformat(), updated_at=hyp.updated_at.isoformat(),
    )


@hyp_router.get("", response_model=HypothesisListResponse, summary="List hypotheses")
async def list_hypotheses(
    hunt_id: str | None = Query(None),
    status: str | None = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    stmt = select(Hypothesis).order_by(Hypothesis.updated_at.desc())
    if hunt_id:
        stmt = stmt.where(Hypothesis.hunt_id == hunt_id)
    if status:
        stmt = stmt.where(Hypothesis.status == status)
    stmt = stmt.limit(limit).offset(offset)
    result = await db.execute(stmt)
    hyps = result.scalars().all()

    # Apply the same filters to the count so `total` matches the listing
    count_stmt = select(func.count(Hypothesis.id))
    if hunt_id:
        count_stmt = count_stmt.where(Hypothesis.hunt_id == hunt_id)
    if status:
        count_stmt = count_stmt.where(Hypothesis.status == status)
    total = (await db.execute(count_stmt)).scalar_one()

    return HypothesisListResponse(
        hypotheses=[
            HypothesisResponse(
                id=h.id, hunt_id=h.hunt_id, title=h.title,
                description=h.description, mitre_technique=h.mitre_technique,
                status=h.status, evidence_row_ids=h.evidence_row_ids,
                evidence_notes=h.evidence_notes,
                created_at=h.created_at.isoformat(), updated_at=h.updated_at.isoformat(),
            )
            for h in hyps
        ],
        total=total,
    )


@hyp_router.get("/{hypothesis_id}", response_model=HypothesisResponse, summary="Get hypothesis")
async def get_hypothesis(hypothesis_id: str, db: AsyncSession = Depends(get_db)):
    result = await db.execute(select(Hypothesis).where(Hypothesis.id == hypothesis_id))
    hyp = result.scalar_one_or_none()
    if not hyp:
        raise HTTPException(status_code=404, detail="Hypothesis not found")
    return HypothesisResponse(
        id=hyp.id, hunt_id=hyp.hunt_id, title=hyp.title,
        description=hyp.description, mitre_technique=hyp.mitre_technique,
        status=hyp.status, evidence_row_ids=hyp.evidence_row_ids,
        evidence_notes=hyp.evidence_notes,
        created_at=hyp.created_at.isoformat(), updated_at=hyp.updated_at.isoformat(),
    )


@hyp_router.put("/{hypothesis_id}", response_model=HypothesisResponse, summary="Update hypothesis")
async def update_hypothesis(
    hypothesis_id: str, body: HypothesisUpdate, db: AsyncSession = Depends(get_db)
):
    result = await db.execute(select(Hypothesis).where(Hypothesis.id == hypothesis_id))
    hyp = result.scalar_one_or_none()
    if not hyp:
        raise HTTPException(status_code=404, detail="Hypothesis not found")
    if body.title is not None:
        hyp.title = body.title
    if body.description is not None:
        hyp.description = body.description
    if body.mitre_technique is not None:
        hyp.mitre_technique = body.mitre_technique
    if body.status is not None:
        hyp.status = body.status
    if body.evidence_row_ids is not None:
        hyp.evidence_row_ids = body.evidence_row_ids
    if body.evidence_notes is not None:
        hyp.evidence_notes = body.evidence_notes
    await db.flush()
    return HypothesisResponse(
        id=hyp.id, hunt_id=hyp.hunt_id, title=hyp.title,
        description=hyp.description, mitre_technique=hyp.mitre_technique,
        status=hyp.status, evidence_row_ids=hyp.evidence_row_ids,
        evidence_notes=hyp.evidence_notes,
        created_at=hyp.created_at.isoformat(), updated_at=hyp.updated_at.isoformat(),
    )


@hyp_router.delete("/{hypothesis_id}", summary="Delete hypothesis")
async def delete_hypothesis(hypothesis_id: str, db: AsyncSession = Depends(get_db)):
    result = await db.execute(select(Hypothesis).where(Hypothesis.id == hypothesis_id))
    hyp = result.scalar_one_or_none()
    if not hyp:
        raise HTTPException(status_code=404, detail="Hypothesis not found")
    await db.delete(hyp)
    return {"message": "Hypothesis deleted", "id": hypothesis_id}
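A request sketch against the annotation routes above; the base URL and IDs are placeholders:

import httpx

resp = httpx.post(
    "http://localhost:8000/api/annotations",
    json={
        "dataset_id": "ds-velociraptor-001",
        "row_id": 42,
        "text": "Beaconing interval ~60s to rare external host",
        "severity": "high",
        "tag": "suspicious",
    },
)
resp.raise_for_status()
print(resp.json()["id"])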
backend/app/api/routes/auth.py (new file, 197 lines)
@@ -0,0 +1,197 @@
"""API routes for authentication — register, login, refresh, profile."""

import logging

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import get_db
from app.db.models import User
from app.services.auth import (
    hash_password,
    verify_password,
    create_token_pair,
    decode_token,
    get_current_user,
    TokenPair,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/auth", tags=["auth"])


# ── Request / Response models ─────────────────────────────────────────


class RegisterRequest(BaseModel):
    username: str = Field(..., min_length=3, max_length=64)
    email: str = Field(..., max_length=256)
    password: str = Field(..., min_length=8, max_length=128)
    display_name: str | None = None


class LoginRequest(BaseModel):
    username: str
    password: str


class RefreshRequest(BaseModel):
    refresh_token: str


class UserResponse(BaseModel):
    id: str
    username: str
    email: str
    display_name: str | None
    role: str
    is_active: bool
    created_at: str


class AuthResponse(BaseModel):
    user: UserResponse
    tokens: TokenPair


# ── Routes ────────────────────────────────────────────────────────────


@router.post(
    "/register",
    response_model=AuthResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Register a new user",
)
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
    # Check for existing username
    result = await db.execute(select(User).where(User.username == body.username))
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Username already taken",
        )

    # Check for existing email
    result = await db.execute(select(User).where(User.email == body.email))
    if result.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Email already registered",
        )

    user = User(
        username=body.username,
        email=body.email,
        password_hash=hash_password(body.password),
        display_name=body.display_name or body.username,
        role="analyst",  # Default role
    )
    db.add(user)
    await db.flush()

    tokens = create_token_pair(user.id, user.role)

    logger.info(f"New user registered: {user.username} ({user.id})")

    return AuthResponse(
        user=UserResponse(
            id=user.id,
            username=user.username,
            email=user.email,
            display_name=user.display_name,
            role=user.role,
            is_active=user.is_active,
            created_at=user.created_at.isoformat(),
        ),
        tokens=tokens,
    )


@router.post(
    "/login",
    response_model=AuthResponse,
    summary="Login with username and password",
)
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
    result = await db.execute(select(User).where(User.username == body.username))
    user = result.scalar_one_or_none()

    if not user or not user.password_hash:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid username or password",
        )

    if not verify_password(body.password, user.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid username or password",
        )

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Account is disabled",
        )

    tokens = create_token_pair(user.id, user.role)

    return AuthResponse(
        user=UserResponse(
            id=user.id,
            username=user.username,
            email=user.email,
            display_name=user.display_name,
            role=user.role,
            is_active=user.is_active,
            created_at=user.created_at.isoformat(),
        ),
        tokens=tokens,
    )


@router.post(
    "/refresh",
    response_model=TokenPair,
    summary="Refresh access token",
)
async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
    token_data = decode_token(body.refresh_token)

    if token_data.type != "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type — use refresh token",
        )

    result = await db.execute(select(User).where(User.id == token_data.sub))
    user = result.scalar_one_or_none()

    if not user or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid user",
        )

    return create_token_pair(user.id, user.role)


@router.get(
    "/me",
    response_model=UserResponse,
    summary="Get current user profile",
)
async def get_profile(user: User = Depends(get_current_user)):
    return UserResponse(
        id=user.id,
        username=user.username,
        email=user.email,
        display_name=user.display_name,
        role=user.role,
        is_active=user.is_active,
        created_at=user.created_at.isoformat(),
    )
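For reference, a minimal client sketch of the token flow these routes implement. Assumptions not shown in this diff: a local deployment on http://localhost:8000, the `httpx` package, a `Bearer` scheme in `get_current_user`, and that `TokenPair` serializes to `access_token`/`refresh_token` fields.

```python
import httpx

BASE = "http://localhost:8000/api/auth"  # assumed local deployment

with httpx.Client() as client:
    # Register returns the user plus a token pair
    r = client.post(f"{BASE}/register", json={
        "username": "alice",
        "email": "alice@example.com",
        "password": "a-long-password",
    })
    tokens = r.json()["tokens"]  # field names assumed from TokenPair

    # The access token authenticates /me (Bearer scheme assumed)
    me = client.get(f"{BASE}/me",
                    headers={"Authorization": f"Bearer {tokens['access_token']}"})

    # The refresh token swaps for a fresh pair
    new_tokens = client.post(f"{BASE}/refresh",
                             json={"refresh_token": tokens["refresh_token"]}).json()
```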
83 backend/app/api/routes/correlation.py Normal file
@@ -0,0 +1,83 @@
"""API routes for cross-hunt correlation analysis."""

import logging
from dataclasses import asdict

from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import get_db
from app.services.correlation import correlation_engine

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/correlation", tags=["correlation"])


class CorrelateRequest(BaseModel):
    hunt_ids: list[str] = Field(
        ...,
        min_length=2,
        max_length=20,
        description="List of hunt IDs to correlate",
    )


@router.post(
    "/analyze",
    summary="Run correlation analysis across hunts",
    description="Find shared IOCs, overlapping time windows, common MITRE techniques, "
    "and host patterns across the specified hunts.",
)
async def correlate_hunts(
    body: CorrelateRequest,
    db: AsyncSession = Depends(get_db),
):
    result = await correlation_engine.correlate_hunts(body.hunt_ids, db)

    return {
        "hunt_ids": result.hunt_ids,
        "summary": result.summary,
        "total_correlations": result.total_correlations,
        "ioc_overlaps": [asdict(o) for o in result.ioc_overlaps],
        "time_overlaps": [asdict(o) for o in result.time_overlaps],
        "technique_overlaps": [asdict(o) for o in result.technique_overlaps],
        "host_overlaps": result.host_overlaps,
    }


@router.get(
    "/all",
    summary="Correlate all hunts",
    description="Run correlation across all hunts in the system.",
)
async def correlate_all(db: AsyncSession = Depends(get_db)):
    result = await correlation_engine.correlate_all(db)
    return {
        "hunt_ids": result.hunt_ids,
        "summary": result.summary,
        "total_correlations": result.total_correlations,
        "ioc_overlaps": [asdict(o) for o in result.ioc_overlaps[:20]],
        "time_overlaps": [asdict(o) for o in result.time_overlaps[:10]],
        "technique_overlaps": [asdict(o) for o in result.technique_overlaps[:10]],
        "host_overlaps": result.host_overlaps[:10],
    }


@router.get(
    "/ioc/{ioc_value}",
    summary="Find IOC across all hunts",
    description="Search for a specific IOC value across all datasets and hunts.",
)
async def find_ioc(
    ioc_value: str,
    db: AsyncSession = Depends(get_db),
):
    occurrences = await correlation_engine.find_ioc_across_hunts(ioc_value, db)
    return {
        "ioc_value": ioc_value,
        "occurrences": occurrences,
        "total": len(occurrences),
        "unique_hunts": len({o["hunt_id"] for o in occurrences if o.get("hunt_id")}),
    }
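A usage sketch for the analyze endpoint (hunt IDs are placeholders for values returned by GET /api/hunts; `httpx` assumed):

```python
import httpx

# Hypothetical hunt IDs obtained from a prior GET /api/hunts call
r = httpx.post(
    "http://localhost:8000/api/correlation/analyze",
    json={"hunt_ids": ["<hunt-a>", "<hunt-b>"]},
)
report = r.json()
print(report["total_correlations"], report["summary"])
for overlap in report["ioc_overlaps"]:
    print(overlap)  # shared IOCs between the two hunts
```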
295 backend/app/api/routes/datasets.py Normal file
@@ -0,0 +1,295 @@
"""API routes for dataset upload, listing, and management."""

import logging
from pathlib import Path

from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.db import get_db
from app.db.repositories.datasets import DatasetRepository
from app.services.csv_parser import parse_csv_bytes
from app.services.normalizer import (
    normalize_columns,
    normalize_rows,
    detect_ioc_columns,
    detect_time_range,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/datasets", tags=["datasets"])

ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"}


# ── Response models ───────────────────────────────────────────────────


class DatasetSummary(BaseModel):
    id: str
    name: str
    filename: str
    source_tool: str | None = None
    row_count: int
    column_schema: dict | None = None
    normalized_columns: dict | None = None
    ioc_columns: dict | None = None
    file_size_bytes: int
    encoding: str | None = None
    delimiter: str | None = None
    time_range_start: str | None = None
    time_range_end: str | None = None
    hunt_id: str | None = None
    created_at: str


class DatasetListResponse(BaseModel):
    datasets: list[DatasetSummary]
    total: int


class RowsResponse(BaseModel):
    rows: list[dict]
    total: int
    offset: int
    limit: int


class UploadResponse(BaseModel):
    id: str
    name: str
    row_count: int
    columns: list[str]
    column_types: dict
    normalized_columns: dict
    ioc_columns: dict
    message: str


def _to_summary(ds) -> DatasetSummary:
    """Map a Dataset ORM object to its API summary (shared by list/get)."""
    return DatasetSummary(
        id=ds.id,
        name=ds.name,
        filename=ds.filename,
        source_tool=ds.source_tool,
        row_count=ds.row_count,
        column_schema=ds.column_schema,
        normalized_columns=ds.normalized_columns,
        ioc_columns=ds.ioc_columns,
        file_size_bytes=ds.file_size_bytes,
        encoding=ds.encoding,
        delimiter=ds.delimiter,
        time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
        time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
        hunt_id=ds.hunt_id,
        created_at=ds.created_at.isoformat(),
    )


# ── Routes ────────────────────────────────────────────────────────────


@router.post(
    "/upload",
    response_model=UploadResponse,
    summary="Upload a CSV dataset",
    description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, "
    "IOCs auto-detected, and rows stored in the database.",
)
async def upload_dataset(
    file: UploadFile = File(...),
    name: str | None = Query(None, description="Display name for the dataset"),
    source_tool: str | None = Query(None, description="Source tool (e.g., velociraptor)"),
    hunt_id: str | None = Query(None, description="Hunt ID to associate with"),
    db: AsyncSession = Depends(get_db),
):
    """Upload and parse a CSV dataset."""
    # Validate file
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    ext = Path(file.filename).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"File type '{ext}' not allowed. Accepted: {', '.join(ALLOWED_EXTENSIONS)}",
        )

    # Read file bytes
    raw_bytes = await file.read()
    if len(raw_bytes) == 0:
        raise HTTPException(status_code=400, detail="File is empty")

    if len(raw_bytes) > settings.max_upload_bytes:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Max size: {settings.MAX_UPLOAD_SIZE_MB} MB",
        )

    # Parse CSV
    try:
        rows, metadata = parse_csv_bytes(raw_bytes)
    except Exception as e:
        logger.error(f"CSV parse error: {e}")
        raise HTTPException(
            status_code=422,
            detail=f"Failed to parse CSV: {str(e)}. Check encoding and format.",
        )

    if not rows:
        raise HTTPException(status_code=422, detail="CSV file contains no data rows")

    columns: list[str] = metadata["columns"]
    column_types: dict = metadata["column_types"]

    # Normalize columns
    column_mapping = normalize_columns(columns)
    normalized = normalize_rows(rows, column_mapping)

    # Detect IOCs
    ioc_columns = detect_ioc_columns(columns, column_types, column_mapping)

    # Detect time range
    time_start, time_end = detect_time_range(rows, column_mapping)

    # Store in DB
    repo = DatasetRepository(db)
    dataset = await repo.create_dataset(
        name=name or Path(file.filename).stem,
        filename=file.filename,
        source_tool=source_tool,
        row_count=len(rows),
        column_schema=column_types,
        normalized_columns=column_mapping,
        ioc_columns=ioc_columns,
        file_size_bytes=len(raw_bytes),
        encoding=metadata["encoding"],
        delimiter=metadata["delimiter"],
        time_range_start=time_start,
        time_range_end=time_end,
        hunt_id=hunt_id,
    )

    await repo.bulk_insert_rows(
        dataset_id=dataset.id,
        rows=rows,
        normalized_rows=normalized,
    )

    logger.info(
        f"Uploaded dataset '{dataset.name}': {len(rows)} rows, "
        f"{len(columns)} columns, {len(ioc_columns)} IOC columns detected"
    )

    return UploadResponse(
        id=dataset.id,
        name=dataset.name,
        row_count=len(rows),
        columns=columns,
        column_types=column_types,
        normalized_columns=column_mapping,
        ioc_columns=ioc_columns,
        message=f"Successfully uploaded {len(rows)} rows with {len(ioc_columns)} IOC columns detected",
    )


@router.get(
    "",
    response_model=DatasetListResponse,
    summary="List datasets",
)
async def list_datasets(
    hunt_id: str | None = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    repo = DatasetRepository(db)
    datasets = await repo.list_datasets(hunt_id=hunt_id, limit=limit, offset=offset)
    total = await repo.count_datasets(hunt_id=hunt_id)

    return DatasetListResponse(
        datasets=[_to_summary(ds) for ds in datasets],
        total=total,
    )


@router.get(
    "/{dataset_id}",
    response_model=DatasetSummary,
    summary="Get dataset details",
)
async def get_dataset(
    dataset_id: str,
    db: AsyncSession = Depends(get_db),
):
    repo = DatasetRepository(db)
    ds = await repo.get_dataset(dataset_id)
    if not ds:
        raise HTTPException(status_code=404, detail="Dataset not found")
    return _to_summary(ds)


@router.get(
    "/{dataset_id}/rows",
    response_model=RowsResponse,
    summary="Get dataset rows",
)
async def get_dataset_rows(
    dataset_id: str,
    limit: int = Query(1000, ge=1, le=10000),
    offset: int = Query(0, ge=0),
    normalized: bool = Query(False, description="Return normalized column names"),
    db: AsyncSession = Depends(get_db),
):
    repo = DatasetRepository(db)
    ds = await repo.get_dataset(dataset_id)
    if not ds:
        raise HTTPException(status_code=404, detail="Dataset not found")

    rows = await repo.get_rows(dataset_id, limit=limit, offset=offset)
    total = await repo.count_rows(dataset_id)

    return RowsResponse(
        rows=[
            (r.normalized_data if normalized and r.normalized_data else r.data)
            for r in rows
        ],
        total=total,
        offset=offset,
        limit=limit,
    )


@router.delete(
    "/{dataset_id}",
    summary="Delete a dataset",
)
async def delete_dataset(
    dataset_id: str,
    db: AsyncSession = Depends(get_db),
):
    repo = DatasetRepository(db)
    deleted = await repo.delete_dataset(dataset_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Dataset not found")
    return {"message": "Dataset deleted", "id": dataset_id}
220 backend/app/api/routes/enrichment.py Normal file
@@ -0,0 +1,220 @@
"""API routes for IOC enrichment."""

import logging

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import get_db
from app.services.enrichment import (
    enrichment_engine,
    IOCType,
    Verdict,
    EnrichmentResultData,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/enrichment", tags=["enrichment"])


# ── Models ────────────────────────────────────────────────────────────


class EnrichIOCRequest(BaseModel):
    ioc_value: str = Field(..., max_length=2048, description="IOC value to enrich")
    ioc_type: str = Field(..., description="IOC type: ip, domain, hash_md5, hash_sha1, hash_sha256, url")
    skip_cache: bool = False


class EnrichBatchRequest(BaseModel):
    iocs: list[dict] = Field(
        ...,
        description="List of {value, type} pairs",
        max_length=50,
    )


class EnrichmentResultResponse(BaseModel):
    ioc_value: str
    ioc_type: str
    source: str
    verdict: str
    score: float
    tags: list[str] = []
    country: str = ""
    asn: str = ""
    org: str = ""
    last_seen: str = ""
    raw_data: dict = {}
    error: str = ""
    latency_ms: int = 0


class EnrichIOCResponse(BaseModel):
    ioc_value: str
    ioc_type: str
    results: list[EnrichmentResultResponse]
    overall_verdict: str
    overall_score: float


class EnrichBatchResponse(BaseModel):
    results: dict[str, list[EnrichmentResultResponse]]
    total_enriched: int


def _to_response(r: EnrichmentResultData) -> EnrichmentResultResponse:
    return EnrichmentResultResponse(
        ioc_value=r.ioc_value,
        ioc_type=r.ioc_type.value,
        source=r.source,
        verdict=r.verdict.value,
        score=r.score,
        tags=r.tags,
        country=r.country,
        asn=r.asn,
        org=r.org,
        last_seen=r.last_seen,
        raw_data=r.raw_data,
        error=r.error,
        latency_ms=r.latency_ms,
    )


def _compute_overall(results: list[EnrichmentResultData]) -> tuple[str, float]:
    """Compute overall verdict from multiple provider results."""
    if not results:
        return Verdict.UNKNOWN.value, 0.0

    verdicts = [r.verdict for r in results if r.verdict != Verdict.ERROR]
    if not verdicts:
        return Verdict.ERROR.value, 0.0

    if Verdict.MALICIOUS in verdicts:
        return Verdict.MALICIOUS.value, max(r.score for r in results)
    elif Verdict.SUSPICIOUS in verdicts:
        return Verdict.SUSPICIOUS.value, max(r.score for r in results)
    elif Verdict.CLEAN in verdicts:
        return Verdict.CLEAN.value, 0.0
    return Verdict.UNKNOWN.value, 0.0


# ── Routes ────────────────────────────────────────────────────────────


@router.post(
    "/ioc",
    response_model=EnrichIOCResponse,
    summary="Enrich a single IOC",
    description="Query all configured providers for an IOC (IP, hash, domain, URL).",
)
async def enrich_ioc(
    body: EnrichIOCRequest,
    db: AsyncSession = Depends(get_db),
):
    try:
        ioc_type = IOCType(body.ioc_type)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid IOC type: {body.ioc_type}. Valid: {[t.value for t in IOCType]}",
        )

    results = await enrichment_engine.enrich_ioc(
        body.ioc_value, ioc_type, db=db, skip_cache=body.skip_cache,
    )

    overall_verdict, overall_score = _compute_overall(results)

    return EnrichIOCResponse(
        ioc_value=body.ioc_value,
        ioc_type=body.ioc_type,
        results=[_to_response(r) for r in results],
        overall_verdict=overall_verdict,
        overall_score=overall_score,
    )


@router.post(
    "/batch",
    response_model=EnrichBatchResponse,
    summary="Enrich a batch of IOCs",
    description="Enrich up to 50 IOCs at once across all providers.",
)
async def enrich_batch(
    body: EnrichBatchRequest,
    db: AsyncSession = Depends(get_db),
):
    iocs = []
    for item in body.iocs:
        try:
            ioc_type = IOCType(item["type"])
            iocs.append((item["value"], ioc_type))
        except (KeyError, ValueError):
            continue

    if not iocs:
        raise HTTPException(status_code=400, detail="No valid IOCs provided")

    all_results = await enrichment_engine.enrich_batch(iocs, db=db)

    return EnrichBatchResponse(
        results={
            k: [_to_response(r) for r in v]
            for k, v in all_results.items()
        },
        total_enriched=len(all_results),
    )


@router.post(
    "/dataset/{dataset_id}",
    summary="Auto-enrich IOCs in a dataset",
    description="Automatically extract and enrich IOCs from a dataset's IOC columns.",
)
async def enrich_dataset(
    dataset_id: str,
    max_iocs: int = Query(50, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
):
    from app.db.repositories.datasets import DatasetRepository

    repo = DatasetRepository(db)
    dataset = await repo.get_dataset(dataset_id)
    if not dataset:
        raise HTTPException(status_code=404, detail="Dataset not found")

    if not dataset.ioc_columns:
        return {"message": "No IOC columns detected in this dataset", "results": {}}

    rows = await repo.get_rows(dataset_id, limit=1000)
    row_data = [r.data for r in rows]

    all_results = await enrichment_engine.enrich_dataset_iocs(
        rows=row_data,
        ioc_columns=dataset.ioc_columns,
        db=db,
        max_iocs=max_iocs,
    )

    return {
        "dataset_id": dataset_id,
        "dataset_name": dataset.name,
        "ioc_columns": dataset.ioc_columns,
        "results": {
            k: [_to_response(r) for r in v]
            for k, v in all_results.items()
        },
        "total_enriched": len(all_results),
    }


@router.get(
    "/status",
    summary="Enrichment engine status",
    description="Check which providers are configured and available.",
)
async def enrichment_status():
    return enrichment_engine.status()
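For illustration, the batch endpoint expects `{value, type}` pairs; a minimal request sketch (the IOCs are placeholders from documentation ranges, `httpx` assumed):

```python
import httpx

payload = {
    "iocs": [
        {"value": "198.51.100.7", "type": "ip"},          # documentation-range IP
        {"value": "evil.example.com", "type": "domain"},  # hypothetical domain
    ]
}
r = httpx.post("http://localhost:8000/api/enrichment/batch", json=payload)
for ioc, results in r.json()["results"].items():
    print(ioc, [x["verdict"] for x in results])  # one verdict per provider
```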
158 backend/app/api/routes/hunts.py Normal file
@@ -0,0 +1,158 @@
"""API routes for hunt management."""

import logging

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import get_db
from app.db.models import Hunt

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/hunts", tags=["hunts"])


# ── Models ────────────────────────────────────────────────────────────


class HuntCreate(BaseModel):
    name: str = Field(..., max_length=256)
    description: str | None = None


class HuntUpdate(BaseModel):
    name: str | None = None
    description: str | None = None
    status: str | None = None  # active | closed | archived


class HuntResponse(BaseModel):
    id: str
    name: str
    description: str | None
    status: str
    owner_id: str | None
    created_at: str
    updated_at: str
    dataset_count: int = 0
    hypothesis_count: int = 0


class HuntListResponse(BaseModel):
    hunts: list[HuntResponse]
    total: int


def _to_response(h: Hunt, with_counts: bool = False) -> HuntResponse:
    """Map a Hunt ORM object to its API response (shared by all routes)."""
    return HuntResponse(
        id=h.id,
        name=h.name,
        description=h.description,
        status=h.status,
        owner_id=h.owner_id,
        created_at=h.created_at.isoformat(),
        updated_at=h.updated_at.isoformat(),
        dataset_count=len(h.datasets) if with_counts and h.datasets else 0,
        hypothesis_count=len(h.hypotheses) if with_counts and h.hypotheses else 0,
    )


# ── Routes ────────────────────────────────────────────────────────────


@router.post("", response_model=HuntResponse, summary="Create a new hunt")
async def create_hunt(body: HuntCreate, db: AsyncSession = Depends(get_db)):
    hunt = Hunt(name=body.name, description=body.description)
    db.add(hunt)
    await db.flush()
    return _to_response(hunt)


@router.get("", response_model=HuntListResponse, summary="List hunts")
async def list_hunts(
    status: str | None = Query(None),
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
    db: AsyncSession = Depends(get_db),
):
    stmt = select(Hunt).order_by(Hunt.updated_at.desc())
    if status:
        stmt = stmt.where(Hunt.status == status)
    stmt = stmt.limit(limit).offset(offset)
    result = await db.execute(stmt)
    hunts = result.scalars().all()

    count_stmt = select(func.count(Hunt.id))
    if status:
        count_stmt = count_stmt.where(Hunt.status == status)
    total = (await db.execute(count_stmt)).scalar_one()

    return HuntListResponse(
        hunts=[_to_response(h, with_counts=True) for h in hunts],
        total=total,
    )


@router.get("/{hunt_id}", response_model=HuntResponse, summary="Get hunt details")
async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
    result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    hunt = result.scalar_one_or_none()
    if not hunt:
        raise HTTPException(status_code=404, detail="Hunt not found")
    return _to_response(hunt, with_counts=True)


@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
async def update_hunt(
    hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)
):
    result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    hunt = result.scalar_one_or_none()
    if not hunt:
        raise HTTPException(status_code=404, detail="Hunt not found")
    if body.name is not None:
        hunt.name = body.name
    if body.description is not None:
        hunt.description = body.description
    if body.status is not None:
        hunt.status = body.status
    await db.flush()
    return _to_response(hunt)


@router.delete("/{hunt_id}", summary="Delete a hunt")
async def delete_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
    result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
    hunt = result.scalar_one_or_none()
    if not hunt:
        raise HTTPException(status_code=404, detail="Hunt not found")
    await db.delete(hunt)
    return {"message": "Hunt deleted", "id": hunt_id}
257 backend/app/api/routes/keywords.py Normal file
@@ -0,0 +1,257 @@
"""API routes for AUP keyword themes, keyword CRUD, and scanning."""

import logging

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import get_db
from app.db.models import KeywordTheme, Keyword
from app.services.scanner import KeywordScanner

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/keywords", tags=["keywords"])


# ── Pydantic schemas ──────────────────────────────────────────────────


class ThemeCreate(BaseModel):
    name: str = Field(..., min_length=1, max_length=128)
    color: str = Field(default="#9e9e9e", max_length=16)
    enabled: bool = True


class ThemeUpdate(BaseModel):
    name: str | None = None
    color: str | None = None
    enabled: bool | None = None


class KeywordOut(BaseModel):
    id: int
    theme_id: str
    value: str
    is_regex: bool
    created_at: str


class ThemeOut(BaseModel):
    id: str
    name: str
    color: str
    enabled: bool
    is_builtin: bool
    created_at: str
    keyword_count: int
    keywords: list[KeywordOut]


class ThemeListResponse(BaseModel):
    themes: list[ThemeOut]
    total: int


class KeywordCreate(BaseModel):
    value: str = Field(..., min_length=1, max_length=256)
    is_regex: bool = False


class KeywordBulkCreate(BaseModel):
    values: list[str] = Field(..., min_length=1)
    is_regex: bool = False


class ScanRequest(BaseModel):
    dataset_ids: list[str] | None = None  # None → all datasets
    theme_ids: list[str] | None = None    # None → all enabled themes
    scan_hunts: bool = True
    scan_annotations: bool = True
    scan_messages: bool = True


class ScanHit(BaseModel):
    theme_name: str
    theme_color: str
    keyword: str
    source_type: str  # dataset_row | hunt | annotation | message
    source_id: str | int
    field: str
    matched_value: str
    row_index: int | None = None
    dataset_name: str | None = None


class ScanResponse(BaseModel):
    total_hits: int
    hits: list[ScanHit]
    themes_scanned: int
    keywords_scanned: int
    rows_scanned: int


# ── Helpers ───────────────────────────────────────────────────────────


def _theme_to_out(t: KeywordTheme) -> ThemeOut:
    return ThemeOut(
        id=t.id,
        name=t.name,
        color=t.color,
        enabled=t.enabled,
        is_builtin=t.is_builtin,
        created_at=t.created_at.isoformat(),
        keyword_count=len(t.keywords),
        keywords=[
            KeywordOut(
                id=k.id,
                theme_id=k.theme_id,
                value=k.value,
                is_regex=k.is_regex,
                created_at=k.created_at.isoformat(),
            )
            for k in t.keywords
        ],
    )


# ── Theme CRUD ────────────────────────────────────────────────────────


@router.get("/themes", response_model=ThemeListResponse)
async def list_themes(db: AsyncSession = Depends(get_db)):
    """List all keyword themes with their keywords."""
    result = await db.execute(
        select(KeywordTheme).order_by(KeywordTheme.name)
    )
    themes = result.scalars().all()
    return ThemeListResponse(
        themes=[_theme_to_out(t) for t in themes],
        total=len(themes),
    )


@router.post("/themes", response_model=ThemeOut, status_code=201)
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
    """Create a new keyword theme."""
    exists = await db.scalar(
        select(KeywordTheme.id).where(KeywordTheme.name == body.name)
    )
    if exists:
        raise HTTPException(409, f"Theme '{body.name}' already exists")
    theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
    db.add(theme)
    await db.flush()
    await db.refresh(theme)
    return _theme_to_out(theme)


@router.put("/themes/{theme_id}", response_model=ThemeOut)
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
    """Update theme name, color, or enabled status."""
    theme = await db.get(KeywordTheme, theme_id)
    if not theme:
        raise HTTPException(404, "Theme not found")
    if body.name is not None:
        # check uniqueness
        dup = await db.scalar(
            select(KeywordTheme.id).where(
                KeywordTheme.name == body.name, KeywordTheme.id != theme_id
            )
        )
        if dup:
            raise HTTPException(409, f"Theme '{body.name}' already exists")
        theme.name = body.name
    if body.color is not None:
        theme.color = body.color
    if body.enabled is not None:
        theme.enabled = body.enabled
    await db.flush()
    await db.refresh(theme)
    return _theme_to_out(theme)


@router.delete("/themes/{theme_id}", status_code=204)
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a theme and all its keywords."""
    theme = await db.get(KeywordTheme, theme_id)
    if not theme:
        raise HTTPException(404, "Theme not found")
    await db.delete(theme)


# ── Keyword CRUD ──────────────────────────────────────────────────────


@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
    """Add a single keyword to a theme."""
    theme = await db.get(KeywordTheme, theme_id)
    if not theme:
        raise HTTPException(404, "Theme not found")
    kw = Keyword(theme_id=theme_id, value=body.value, is_regex=body.is_regex)
    db.add(kw)
    await db.flush()
    await db.refresh(kw)
    return KeywordOut(
        id=kw.id, theme_id=kw.theme_id, value=kw.value,
        is_regex=kw.is_regex, created_at=kw.created_at.isoformat(),
    )


@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
    """Add multiple keywords to a theme at once."""
    theme = await db.get(KeywordTheme, theme_id)
    if not theme:
        raise HTTPException(404, "Theme not found")
    added = 0
    for val in body.values:
        val = val.strip()
        if not val:
            continue
        db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex))
        added += 1
    await db.flush()
    return {"added": added, "theme_id": theme_id}


@router.delete("/keywords/{keyword_id}", status_code=204)
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
    """Delete a single keyword."""
    kw = await db.get(Keyword, keyword_id)
    if not kw:
        raise HTTPException(404, "Keyword not found")
    await db.delete(kw)


# ── Scan endpoints ────────────────────────────────────────────────────


@router.post("/scan", response_model=ScanResponse)
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
    """Run AUP keyword scan across selected data sources."""
    scanner = KeywordScanner(db)
    result = await scanner.scan(
        dataset_ids=body.dataset_ids,
        theme_ids=body.theme_ids,
        scan_hunts=body.scan_hunts,
        scan_annotations=body.scan_annotations,
        scan_messages=body.scan_messages,
    )
    return result


@router.get("/scan/quick", response_model=ScanResponse)
async def quick_scan(
    dataset_id: str = Query(..., description="Dataset to scan"),
    db: AsyncSession = Depends(get_db),
):
    """Quick scan a single dataset with all enabled themes."""
    scanner = KeywordScanner(db)
    result = await scanner.scan(dataset_ids=[dataset_id])
    return result
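A scan request sketch (the dataset ID is a placeholder from a prior upload; omitted booleans keep their defaults of True):

```python
import httpx

r = httpx.post("http://localhost:8000/api/keywords/scan", json={
    "dataset_ids": ["<dataset-id>"],  # hypothetical ID from /api/datasets/upload
    "theme_ids": None,                # None scans all enabled themes
})
body = r.json()
print(body["total_hits"], "hits across", body["themes_scanned"], "themes")
```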
67 backend/app/api/routes/reports.py Normal file
@@ -0,0 +1,67 @@
"""API routes for report generation and export."""

import logging

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import HTMLResponse, PlainTextResponse
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import get_db
from app.services.reports import report_generator

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/reports", tags=["reports"])


@router.get(
    "/hunt/{hunt_id}",
    summary="Generate hunt investigation report",
    description="Generate a comprehensive report for a hunt. Supports JSON, HTML, and CSV formats.",
)
async def generate_hunt_report(
    hunt_id: str,
    format: str = Query("json", description="Report format: json, html, csv"),
    include_rows: bool = Query(False, description="Include raw data rows"),
    max_rows: int = Query(500, ge=0, le=5000, description="Max rows to include"),
    db: AsyncSession = Depends(get_db),
):
    result = await report_generator.generate_hunt_report(
        hunt_id, db, format=format,
        include_rows=include_rows, max_rows=max_rows,
    )

    if isinstance(result, dict) and result.get("error"):
        raise HTTPException(status_code=404, detail=result["error"])

    if format == "html":
        return HTMLResponse(content=result, headers={
            "Content-Disposition": f"inline; filename=threathunt_report_{hunt_id}.html",
        })
    elif format == "csv":
        return PlainTextResponse(content=result, media_type="text/csv", headers={
            "Content-Disposition": f"attachment; filename=threathunt_report_{hunt_id}.csv",
        })
    else:
        return result


@router.get(
    "/hunt/{hunt_id}/summary",
    summary="Quick hunt summary",
    description="Get a lightweight summary of the hunt for dashboard display.",
)
async def hunt_summary(
    hunt_id: str,
    db: AsyncSession = Depends(get_db),
):
    result = await report_generator.generate_hunt_report(
        hunt_id, db, format="json", include_rows=False,
    )
    if isinstance(result, dict) and result.get("error"):
        raise HTTPException(status_code=404, detail=result["error"])

    return {
        "hunt": result.get("hunt"),
        "summary": result.get("summary"),
    }
121 backend/app/config.py Normal file
@@ -0,0 +1,121 @@
"""Application configuration — single source of truth for all settings.

Loads from environment variables with sensible defaults for local dev.
"""

from pydantic_settings import BaseSettings
from pydantic import Field


class AppConfig(BaseSettings):
    """Central configuration for the entire ThreatHunt application."""

    # ── General ────────────────────────────────────────────────────────
    APP_NAME: str = "ThreatHunt"
    APP_VERSION: str = "0.3.0"
    DEBUG: bool = Field(default=False, description="Enable debug mode")

    # ── Database ───────────────────────────────────────────────────────
    DATABASE_URL: str = Field(
        default="sqlite+aiosqlite:///./threathunt.db",
        description="Async SQLAlchemy database URL. "
        "Use sqlite+aiosqlite:///./threathunt.db for local dev, "
        "postgresql+asyncpg://user:pass@host/db for production.",
    )

    # ── CORS ───────────────────────────────────────────────────────────
    ALLOWED_ORIGINS: str = Field(
        default="http://localhost:3000,http://localhost:8000",
        description="Comma-separated list of allowed CORS origins",
    )

    # ── File uploads ───────────────────────────────────────────────────
    MAX_UPLOAD_SIZE_MB: int = Field(default=500, description="Max CSV upload in MB")
    UPLOAD_DIR: str = Field(default="./uploads", description="Directory for uploaded files")

    # ── LLM Cluster — Wile & Roadrunner ────────────────────────────────
    OPENWEBUI_URL: str = Field(
        default="https://ai.guapo613.beer",
        description="Open WebUI cluster endpoint (OpenAI-compatible API)",
    )
    OPENWEBUI_API_KEY: str = Field(
        default="",
        description="API key for Open WebUI (if required)",
    )
    WILE_HOST: str = Field(
        default="100.110.190.12",
        description="Tailscale IP for Wile (heavy models)",
    )
    WILE_OLLAMA_PORT: int = Field(default=11434, description="Ollama port on Wile")
    ROADRUNNER_HOST: str = Field(
        default="100.110.190.11",
        description="Tailscale IP for Roadrunner (fast models + vision)",
    )
    ROADRUNNER_OLLAMA_PORT: int = Field(
        default=11434, description="Ollama port on Roadrunner"
    )

    # ── LLM Routing defaults ──────────────────────────────────────────
    DEFAULT_FAST_MODEL: str = Field(
        default="llama3.1:latest",
        description="Default model for quick chat / simple queries",
    )
    DEFAULT_HEAVY_MODEL: str = Field(
        default="llama3.1:70b-instruct-q4_K_M",
        description="Default model for deep analysis / debate",
    )
    DEFAULT_CODE_MODEL: str = Field(
        default="qwen2.5-coder:32b",
        description="Default model for code / script analysis",
    )
    DEFAULT_VISION_MODEL: str = Field(
        default="llama3.2-vision:11b",
        description="Default model for image / screenshot analysis",
    )
    DEFAULT_EMBEDDING_MODEL: str = Field(
        default="bge-m3:latest",
        description="Default embedding model",
    )

    # ── Agent behaviour ───────────────────────────────────────────────
    AGENT_MAX_TOKENS: int = Field(default=2048, description="Max tokens per agent response")
    AGENT_TEMPERATURE: float = Field(default=0.3, description="LLM temperature for guidance")
    AGENT_HISTORY_LENGTH: int = Field(default=10, description="Messages to keep in context")
    FILTER_SENSITIVE_DATA: bool = Field(default=True, description="Redact sensitive patterns")

    # ── Enrichment API keys ───────────────────────────────────────────
    VIRUSTOTAL_API_KEY: str = Field(default="", description="VirusTotal API key")
    ABUSEIPDB_API_KEY: str = Field(default="", description="AbuseIPDB API key")
    SHODAN_API_KEY: str = Field(default="", description="Shodan API key")

    # ── Auth ──────────────────────────────────────────────────────────
    JWT_SECRET: str = Field(
        default="CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET",
        description="Secret for JWT signing",
    )
    JWT_ACCESS_TOKEN_MINUTES: int = Field(default=60, description="Access token lifetime")
    JWT_REFRESH_TOKEN_DAYS: int = Field(default=7, description="Refresh token lifetime")

    model_config = {"env_prefix": "TH_", "env_file": ".env", "extra": "ignore"}

    @property
    def cors_origins(self) -> list[str]:
        return [o.strip() for o in self.ALLOWED_ORIGINS.split(",") if o.strip()]

    @property
    def wile_url(self) -> str:
        return f"http://{self.WILE_HOST}:{self.WILE_OLLAMA_PORT}"

    @property
    def roadrunner_url(self) -> str:
        return f"http://{self.ROADRUNNER_HOST}:{self.ROADRUNNER_OLLAMA_PORT}"

    @property
    def max_upload_bytes(self) -> int:
        return self.MAX_UPLOAD_SIZE_MB * 1024 * 1024


settings = AppConfig()
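A minimal sketch of how the `TH_` env prefix maps onto these fields and the derived properties (values are illustrative; the override must be set before the module is imported):

```python
import os

os.environ["TH_MAX_UPLOAD_SIZE_MB"] = "100"  # overrides MAX_UPLOAD_SIZE_MB

from app.config import AppConfig

cfg = AppConfig()
assert cfg.max_upload_bytes == 100 * 1024 * 1024
print(cfg.wile_url)  # http://100.110.190.12:11434 with the defaults above
```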
12 backend/app/db/__init__.py Normal file
@@ -0,0 +1,12 @@
"""Database package."""

from .engine import Base, get_db, init_db, dispose_db, engine, async_session_factory

__all__ = [
    "Base",
    "get_db",
    "init_db",
    "dispose_db",
    "engine",
    "async_session_factory",
]
54 backend/app/db/engine.py Normal file
@@ -0,0 +1,54 @@
"""Database engine, session factory, and base model.

Uses async SQLAlchemy with aiosqlite for local dev and asyncpg for production PostgreSQL.
"""

from collections.abc import AsyncIterator

from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase

from app.config import settings

engine = create_async_engine(
    settings.DATABASE_URL,
    echo=settings.DEBUG,
    future=True,
)

async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


class Base(DeclarativeBase):
    """Base class for all ORM models."""


async def get_db() -> AsyncIterator[AsyncSession]:
    """FastAPI dependency that yields an async DB session and commits on success.

    The context manager closes the session on exit; failures roll back first.
    """
    async with async_session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise


async def init_db() -> None:
    """Create all tables (for dev / first-run). In production use Alembic."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)


async def dispose_db() -> None:
    """Dispose of the engine connection pool."""
    await engine.dispose()
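One way to wire these hooks into the app, as a sketch: the diff does not show app/main.py, so the FastAPI lifespan below is an assumption about how init_db/dispose_db get called.

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.db import init_db, dispose_db


@asynccontextmanager
async def lifespan(app: FastAPI):
    await init_db()   # dev convenience; Alembic owns the schema in production
    yield
    await dispose_db()


app = FastAPI(lifespan=lifespan)
```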
328 backend/app/db/models.py Normal file
@@ -0,0 +1,328 @@
|
||||
"""SQLAlchemy ORM models for ThreatHunt.
|
||||
|
||||
All persistent entities: datasets, hunts, conversations, annotations,
|
||||
hypotheses, enrichment results, and users.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
JSON,
|
||||
Index,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from .engine import Base
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _new_id() -> str:
|
||||
return uuid.uuid4().hex
|
||||
|
||||
|
||||
# ── Users ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
username: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
email: Mapped[str] = mapped_column(String(256), unique=True, nullable=False)
|
||||
hashed_password: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
role: Mapped[str] = mapped_column(String(16), default="analyst") # analyst | admin | viewer
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
# relationships
|
||||
hunts: Mapped[list["Hunt"]] = relationship(back_populates="owner", lazy="selectin")
|
||||
annotations: Mapped[list["Annotation"]] = relationship(back_populates="author", lazy="selectin")
|
||||
|
||||
|
||||
# ── Hunts ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Hunt(Base):
|
||||
__tablename__ = "hunts"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
name: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(32), default="active") # active | closed | archived
|
||||
owner_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("users.id"), nullable=True
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||
)
|
||||
|
||||
# relationships
|
||||
owner: Mapped[Optional["User"]] = relationship(back_populates="hunts", lazy="selectin")
|
||||
datasets: Mapped[list["Dataset"]] = relationship(back_populates="hunt", lazy="selectin")
|
||||
conversations: Mapped[list["Conversation"]] = relationship(back_populates="hunt", lazy="selectin")
|
||||
hypotheses: Mapped[list["Hypothesis"]] = relationship(back_populates="hunt", lazy="selectin")
|
||||
|
||||
|
||||
# ── Datasets ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Dataset(Base):
|
||||
__tablename__ = "datasets"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
|
||||
filename: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
source_tool: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # velociraptor, etc.
|
||||
row_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
column_schema: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
normalized_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
ioc_columns: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) # auto-detected IOC columns
|
||||
file_size_bytes: Mapped[int] = mapped_column(Integer, default=0)
|
||||
encoding: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
|
||||
delimiter: Mapped[Optional[str]] = mapped_column(String(4), nullable=True)
|
||||
time_range_start: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
time_range_end: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
hunt_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("hunts.id"), nullable=True
|
||||
)
|
||||
uploaded_by: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
# relationships
|
||||
hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="datasets", lazy="selectin")
|
||||
rows: Mapped[list["DatasetRow"]] = relationship(
|
||||
back_populates="dataset", lazy="noload", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_datasets_hunt", "hunt_id"),
|
||||
)
|
||||
|
||||
|
||||
class DatasetRow(Base):
|
||||
"""Individual row from a CSV dataset, stored as JSON blob."""
|
||||
__tablename__ = "dataset_rows"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
dataset_id: Mapped[str] = mapped_column(
|
||||
String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
row_index: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
data: Mapped[dict] = mapped_column(JSON, nullable=False)
|
||||
normalized_data: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
|
||||
# relationships
|
||||
dataset: Mapped["Dataset"] = relationship(back_populates="rows")
|
||||
annotations: Mapped[list["Annotation"]] = relationship(
|
||||
back_populates="row", lazy="noload"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_dataset_rows_dataset", "dataset_id"),
|
||||
Index("ix_dataset_rows_dataset_idx", "dataset_id", "row_index"),
|
||||
)
|
||||
|
||||
|
||||
# ── Conversations ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Conversation(Base):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
title: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
|
||||
hunt_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("hunts.id"), nullable=True
|
||||
)
|
||||
dataset_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("datasets.id"), nullable=True
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||
)
|
||||
|
||||
# relationships
|
||||
hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="conversations", lazy="selectin")
|
||||
messages: Mapped[list["Message"]] = relationship(
|
||||
back_populates="conversation", lazy="selectin", cascade="all, delete-orphan",
|
||||
order_by="Message.created_at",
|
||||
)
|
||||
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "messages"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
conversation_id: Mapped[str] = mapped_column(
|
||||
String(32), ForeignKey("conversations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
role: Mapped[str] = mapped_column(String(16), nullable=False) # user | agent | system
|
||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
model_used: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
node_used: Mapped[Optional[str]] = mapped_column(String(64), nullable=True) # wile | roadrunner | cluster
|
||||
token_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
latency_ms: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
response_meta: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
# relationships
|
||||
conversation: Mapped["Conversation"] = relationship(back_populates="messages")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_messages_conversation", "conversation_id"),
|
||||
)
|
||||
|
||||
|
||||
# ── Annotations ───────────────────────────────────────────────────────


class Annotation(Base):
    __tablename__ = "annotations"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    row_id: Mapped[Optional[int]] = mapped_column(
        Integer, ForeignKey("dataset_rows.id", ondelete="SET NULL"), nullable=True
    )
    dataset_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("datasets.id"), nullable=True
    )
    author_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("users.id"), nullable=True
    )
    text: Mapped[str] = mapped_column(Text, nullable=False)
    severity: Mapped[str] = mapped_column(
        String(16), default="info"
    )  # info | low | medium | high | critical
    tag: Mapped[Optional[str]] = mapped_column(
        String(32), nullable=True
    )  # suspicious | benign | needs-review
    highlight_color: Mapped[Optional[str]] = mapped_column(String(16), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    row: Mapped[Optional["DatasetRow"]] = relationship(back_populates="annotations")
    author: Mapped[Optional["User"]] = relationship(back_populates="annotations")

    __table_args__ = (
        Index("ix_annotations_dataset", "dataset_id"),
        Index("ix_annotations_row", "row_id"),
    )

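# Illustrative sketch (not part of the module): attaching an analyst note to a
# dataset row. `session` is an assumed AsyncSession; the ids are placeholders.
async def _example_flag_row(session, dataset_id: str, row_pk: int):
    note = Annotation(
        dataset_id=dataset_id,
        row_id=row_pk,
        text="Beaconing interval matches known C2 profile",
        severity="high",
        tag="suspicious",
    )
    session.add(note)
    await session.flush()  # id is assigned by the default factory on construction
    return note.id
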
# ── Hypotheses ────────────────────────────────────────────────────────


class Hypothesis(Base):
    __tablename__ = "hypotheses"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    hunt_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("hunts.id"), nullable=True
    )
    title: Mapped[str] = mapped_column(String(256), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    mitre_technique: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    status: Mapped[str] = mapped_column(
        String(16), default="draft"
    )  # draft | active | confirmed | rejected
    evidence_row_ids: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
    evidence_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
    )

    # relationships
    hunt: Mapped[Optional["Hunt"]] = relationship(back_populates="hypotheses", lazy="selectin")

    __table_args__ = (
        Index("ix_hypotheses_hunt", "hunt_id"),
    )

# ── Enrichment Results ────────────────────────────────────────────────


class EnrichmentResult(Base):
    __tablename__ = "enrichment_results"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    ioc_value: Mapped[str] = mapped_column(String(512), nullable=False, index=True)
    ioc_type: Mapped[str] = mapped_column(
        String(32), nullable=False
    )  # ip | hash_md5 | hash_sha1 | hash_sha256 | domain | url
    source: Mapped[str] = mapped_column(String(32), nullable=False)  # virustotal | abuseipdb | shodan | ai
    verdict: Mapped[Optional[str]] = mapped_column(
        String(16), nullable=True
    )  # clean | suspicious | malicious | unknown
    confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    raw_result: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
    summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    dataset_id: Mapped[Optional[str]] = mapped_column(
        String(32), ForeignKey("datasets.id"), nullable=True
    )
    cached_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
    expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)

    __table_args__ = (
        Index("ix_enrichment_ioc_source", "ioc_value", "source"),
    )

# ── AUP Keyword Themes & Keywords ────────────────────────────────────


class KeywordTheme(Base):
    """A named category of keywords for AUP scanning (e.g. gambling, gaming)."""
    __tablename__ = "keyword_themes"

    id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
    name: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
    color: Mapped[str] = mapped_column(String(16), default="#9e9e9e")  # hex chip color
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    is_builtin: Mapped[bool] = mapped_column(Boolean, default=False)  # seed-provided
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    # relationships
    keywords: Mapped[list["Keyword"]] = relationship(
        back_populates="theme", lazy="selectin", cascade="all, delete-orphan"
    )


class Keyword(Base):
    """Individual keyword / pattern belonging to a theme."""
    __tablename__ = "keywords"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    theme_id: Mapped[str] = mapped_column(
        String(32), ForeignKey("keyword_themes.id", ondelete="CASCADE"), nullable=False
    )
    value: Mapped[str] = mapped_column(String(256), nullable=False)
    is_regex: Mapped[bool] = mapped_column(Boolean, default=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)

    # relationships
    theme: Mapped["KeywordTheme"] = relationship(back_populates="keywords")

    __table_args__ = (
        Index("ix_keywords_theme", "theme_id"),
        Index("ix_keywords_value", "value"),
    )
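# Illustrative sketch (not part of the module): creating a theme together with
# its keywords in one unit of work. Because KeywordTheme.keywords uses
# cascade="all, delete-orphan", adding the parent also persists the children,
# and deleting the theme removes its keywords. `session` is an assumed
# AsyncSession.
async def _example_seed_theme(session):
    theme = KeywordTheme(name="gambling", color="#e53935")
    theme.keywords = [
        Keyword(value="poker"),
        Keyword(value=r"\bbet(ting)?\b", is_regex=True),
    ]
    session.add(theme)  # children are cascaded on flush
    await session.flush()
    return theme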
1
backend/app/db/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
"""Repositories package — typed CRUD operations for each model."""
127
backend/app/db/repositories/datasets.py
Normal file
@@ -0,0 +1,127 @@
"""Dataset repository — CRUD operations for datasets and their rows."""

import logging
from typing import Sequence

from sqlalchemy import select, func, delete
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import Dataset, DatasetRow

logger = logging.getLogger(__name__)


class DatasetRepository:
    """Typed CRUD for Dataset and DatasetRow models."""

    def __init__(self, session: AsyncSession):
        self.session = session

    # ── Dataset CRUD ──────────────────────────────────────────────────

    async def create_dataset(self, **kwargs) -> Dataset:
        ds = Dataset(**kwargs)
        self.session.add(ds)
        await self.session.flush()
        return ds

    async def get_dataset(self, dataset_id: str) -> Dataset | None:
        result = await self.session.execute(
            select(Dataset).where(Dataset.id == dataset_id)
        )
        return result.scalar_one_or_none()

    async def list_datasets(
        self,
        hunt_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> Sequence[Dataset]:
        stmt = select(Dataset).order_by(Dataset.created_at.desc())
        if hunt_id:
            stmt = stmt.where(Dataset.hunt_id == hunt_id)
        stmt = stmt.limit(limit).offset(offset)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def count_datasets(self, hunt_id: str | None = None) -> int:
        stmt = select(func.count(Dataset.id))
        if hunt_id:
            stmt = stmt.where(Dataset.hunt_id == hunt_id)
        result = await self.session.execute(stmt)
        return result.scalar_one()

    async def delete_dataset(self, dataset_id: str) -> bool:
        ds = await self.get_dataset(dataset_id)
        if not ds:
            return False
        await self.session.delete(ds)
        await self.session.flush()
        return True

    # ── Row CRUD ──────────────────────────────────────────────────────

    async def bulk_insert_rows(
        self,
        dataset_id: str,
        rows: list[dict],
        normalized_rows: list[dict] | None = None,
        batch_size: int = 500,
    ) -> int:
        """Insert rows in batches. Returns count inserted."""
        count = 0
        for i in range(0, len(rows), batch_size):
            batch = rows[i : i + batch_size]
            norm_batch = normalized_rows[i : i + batch_size] if normalized_rows else [None] * len(batch)
            objects = [
                DatasetRow(
                    dataset_id=dataset_id,
                    row_index=i + j,
                    data=row,
                    normalized_data=norm,
                )
                for j, (row, norm) in enumerate(zip(batch, norm_batch))
            ]
            self.session.add_all(objects)
            await self.session.flush()
            count += len(objects)
        return count

    async def get_rows(
        self,
        dataset_id: str,
        limit: int = 1000,
        offset: int = 0,
    ) -> Sequence[DatasetRow]:
        stmt = (
            select(DatasetRow)
            .where(DatasetRow.dataset_id == dataset_id)
            .order_by(DatasetRow.row_index)
            .limit(limit)
            .offset(offset)
        )
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def count_rows(self, dataset_id: str) -> int:
        stmt = select(func.count(DatasetRow.id)).where(
            DatasetRow.dataset_id == dataset_id
        )
        result = await self.session.execute(stmt)
        return result.scalar_one()

    async def get_row_by_index(
        self, dataset_id: str, row_index: int
    ) -> DatasetRow | None:
        stmt = select(DatasetRow).where(
            DatasetRow.dataset_id == dataset_id,
            DatasetRow.row_index == row_index,
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def delete_rows(self, dataset_id: str) -> int:
        result = await self.session.execute(
            delete(DatasetRow).where(DatasetRow.dataset_id == dataset_id)
        )
        return result.rowcount  # type: ignore[return-value]
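# Usage sketch (illustrative, not part of the module): a typical upload flow —
# create the dataset, bulk-insert its parsed rows, then commit once at the
# caller. `session` and the parsed `rows` list are assumptions for the example.
async def _example_store_upload(session: AsyncSession, hunt_id: str, rows: list[dict]):
    repo = DatasetRepository(session)
    ds = await repo.create_dataset(name="velociraptor-export.csv", hunt_id=hunt_id)
    inserted = await repo.bulk_insert_rows(ds.id, rows)
    await session.commit()  # repository methods only flush; committing is the caller's job
    return ds, inserted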
@@ -1,28 +1,79 @@
"""ThreatHunt backend application.

Wires together: database, CORS, agent routes, dataset routes, hunt routes,
annotation/hypothesis routes. DB tables are auto-created on startup.
"""

import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.api.routes import agent
from app.config import settings
from app.db import init_db, dispose_db
from app.api.routes.agent_v2 import router as agent_router
from app.api.routes.datasets import router as datasets_router
from app.api.routes.hunts import router as hunts_router
from app.api.routes.annotations import ann_router, hyp_router
from app.api.routes.enrichment import router as enrichment_router
from app.api.routes.correlation import router as correlation_router
from app.api.routes.reports import router as reports_router
from app.api.routes.auth import router as auth_router
from app.api.routes.keywords import router as keywords_router

logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup / shutdown lifecycle."""
    logger.info("Starting ThreatHunt API …")
    await init_db()
    logger.info("Database initialised")
    # Seed default AUP keyword themes
    from app.db import async_session_factory
    from app.services.keyword_defaults import seed_defaults
    async with async_session_factory() as seed_db:
        await seed_defaults(seed_db)
    logger.info("AUP keyword defaults checked")
    yield
    logger.info("Shutting down …")
    from app.agents.providers_v2 import cleanup_client
    from app.services.enrichment import enrichment_engine
    await cleanup_client()
    await enrichment_engine.cleanup()
    await dispose_db()


# Create FastAPI application
app = FastAPI(
    title="ThreatHunt API",
    description="Analyst-assist threat hunting platform powered by Wile & Roadrunner LLM cluster",
    version="0.3.0",
    lifespan=lifespan,
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include routes
app.include_router(agent.router)
app.include_router(auth_router)
app.include_router(agent_router)
app.include_router(datasets_router)
app.include_router(hunts_router)
app.include_router(ann_router)
app.include_router(hyp_router)
app.include_router(enrichment_router)
app.include_router(correlation_router)
app.include_router(reports_router)
app.include_router(keywords_router)


@app.get("/", tags=["health"])
@@ -30,6 +81,12 @@ async def root():
    """API health check."""
    return {
        "service": "ThreatHunt API",
        "version": settings.APP_VERSION,
        "status": "running",
        "docs": "/docs",
        "cluster": {
            "wile": settings.wile_url,
            "roadrunner": settings.roadrunner_url,
            "openwebui": settings.OPENWEBUI_URL,
        },
    }
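# Launch sketch (illustrative, not part of the commit): running the app
# programmatically. Assumes uvicorn is installed and the module path is
# app.main — neither is confirmed by this diff.
if __name__ == "__main__":  # pragma: no cover
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)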
1
backend/app/services/__init__.py
Normal file
@@ -0,0 +1 @@
"""Services package."""
201
backend/app/services/auth.py
Normal file
@@ -0,0 +1,201 @@
"""Authentication & security — JWT tokens, password hashing, role-based access.

Provides:
- Password hashing (bcrypt via passlib)
- JWT access/refresh token creation and verification
- FastAPI dependency for protecting routes
- Role-based enforcement (analyst, admin, viewer)
"""

import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.db import get_db
from app.db.models import User

logger = logging.getLogger(__name__)

# ── Password hashing ─────────────────────────────────────────────────

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def hash_password(password: str) -> str:
    return pwd_context.hash(password)


def verify_password(plain: str, hashed: str) -> bool:
    return pwd_context.verify(plain, hashed)

# ── JWT tokens ────────────────────────────────────────────────────────

ALGORITHM = "HS256"

security = HTTPBearer(auto_error=False)


class TokenPair(BaseModel):
    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    expires_in: int  # seconds


class TokenPayload(BaseModel):
    sub: str  # user_id
    role: str
    exp: datetime
    type: str  # "access" or "refresh"


def create_access_token(user_id: str, role: str) -> str:
    expires = datetime.now(timezone.utc) + timedelta(
        minutes=settings.JWT_ACCESS_TOKEN_MINUTES
    )
    payload = {
        "sub": user_id,
        "role": role,
        "exp": expires,
        "type": "access",
    }
    return jwt.encode(payload, settings.JWT_SECRET, algorithm=ALGORITHM)


def create_refresh_token(user_id: str, role: str) -> str:
    expires = datetime.now(timezone.utc) + timedelta(
        days=settings.JWT_REFRESH_TOKEN_DAYS
    )
    payload = {
        "sub": user_id,
        "role": role,
        "exp": expires,
        "type": "refresh",
    }
    return jwt.encode(payload, settings.JWT_SECRET, algorithm=ALGORITHM)


def create_token_pair(user_id: str, role: str) -> TokenPair:
    return TokenPair(
        access_token=create_access_token(user_id, role),
        refresh_token=create_refresh_token(user_id, role),
        expires_in=settings.JWT_ACCESS_TOKEN_MINUTES * 60,
    )


def decode_token(token: str) -> TokenPayload:
    """Decode and validate a JWT token."""
    try:
        payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[ALGORITHM])
        return TokenPayload(**payload)
    except JWTError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid token: {e}",
            headers={"WWW-Authenticate": "Bearer"},
        )

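# Round-trip sketch (illustrative, not part of the module): issue a pair, then
# validate the access token. With the default settings the decoded payload
# carries sub/role/type; decode_token raises HTTPException(401) on tampering
# or expiry.
def _example_token_roundtrip(user_id: str = "u123", role: str = "analyst") -> bool:
    pair = create_token_pair(user_id, role)
    payload = decode_token(pair.access_token)
    return payload.sub == user_id and payload.type == "access"
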
# ── FastAPI dependencies ──────────────────────────────────────────────


async def get_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Extract and validate the current user from JWT.

    When auth is disabled (no JWT secret configured), returns a default analyst user.
    """
    # If auth is disabled (dev mode), return a default user
    if settings.JWT_SECRET == "CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET":
        return User(
            id="dev-user",
            username="analyst",
            email="analyst@local",
            role="analyst",
            display_name="Dev Analyst",
        )

    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token_data = decode_token(credentials.credentials)

    if token_data.type != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type — use access token",
        )

    result = await db.execute(select(User).where(User.id == token_data.sub))
    user = result.scalar_one_or_none()

    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found",
        )

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User account is disabled",
        )

    return user


async def get_optional_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> Optional[User]:
    """Like get_current_user, but returns None instead of raising if no token."""
    if not credentials:
        if settings.JWT_SECRET == "CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET":
            return User(
                id="dev-user",
                username="analyst",
                email="analyst@local",
                role="analyst",
                display_name="Dev Analyst",
            )
        return None

    try:
        return await get_current_user(credentials, db)
    except HTTPException:
        return None


def require_role(*roles: str):
    """Dependency factory that requires the current user to have one of the specified roles."""

    async def _check(user: User = Depends(get_current_user)) -> User:
        if user.role not in roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Requires one of roles: {', '.join(roles)}. You have: {user.role}",
            )
        return user

    return _check


# Convenience dependencies
require_analyst = require_role("analyst", "admin")
require_admin = require_role("admin")
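# Usage sketch (illustrative, not part of the module): protecting a route with
# the role dependency. The router and path here are assumptions.
from fastapi import APIRouter as _APIRouter  # local alias to keep the sketch self-contained

_example_router = _APIRouter()


@_example_router.delete("/api/hunts/{hunt_id}")
async def _example_delete_hunt(hunt_id: str, user: User = Depends(require_admin)):
    """Only reached when the validated JWT carries role == "admin"."""
    return {"deleted": hunt_id, "by": user.username}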
400
backend/app/services/correlation.py
Normal file
@@ -0,0 +1,400 @@
"""Cross-hunt correlation engine — find IOC overlaps, timeline patterns, and shared TTPs.

Identifies connections between hunts by analyzing:
1. Shared IOC values across datasets
2. Overlapping time ranges and temporal proximity
3. Common MITRE ATT&CK techniques across hypotheses
4. Host-to-host lateral movement patterns
"""

import logging
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional

from sqlalchemy import select, func, text
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import Dataset, DatasetRow, Hunt, Hypothesis, EnrichmentResult

logger = logging.getLogger(__name__)


@dataclass
class IOCOverlap:
    """Shared IOC between two or more hunts/datasets."""
    ioc_value: str
    ioc_type: str
    datasets: list[dict] = field(default_factory=list)  # [{dataset_id, hunt_id, name}]
    hunt_ids: list[str] = field(default_factory=list)
    count: int = 0
    enrichment_verdict: str = ""


@dataclass
class TimeOverlap:
    """Overlapping time window between datasets."""
    dataset_a: dict = field(default_factory=dict)
    dataset_b: dict = field(default_factory=dict)
    overlap_start: str = ""
    overlap_end: str = ""
    overlap_hours: float = 0.0


@dataclass
class TechniqueOverlap:
    """Shared MITRE ATT&CK technique across hunts."""
    technique_id: str
    technique_name: str = ""
    hypotheses: list[dict] = field(default_factory=list)
    hunt_ids: list[str] = field(default_factory=list)


@dataclass
class CorrelationResult:
    """Complete correlation analysis result."""
    hunt_ids: list[str]
    ioc_overlaps: list[IOCOverlap] = field(default_factory=list)
    time_overlaps: list[TimeOverlap] = field(default_factory=list)
    technique_overlaps: list[TechniqueOverlap] = field(default_factory=list)
    host_overlaps: list[dict] = field(default_factory=list)
    summary: str = ""
    total_correlations: int = 0


class CorrelationEngine:
    """Engine for finding correlations across hunts and datasets."""

    async def correlate_hunts(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> CorrelationResult:
        """Run full correlation analysis across specified hunts."""
        result = CorrelationResult(hunt_ids=hunt_ids)

        # Run all correlation types
        result.ioc_overlaps = await self._find_ioc_overlaps(hunt_ids, db)
        result.time_overlaps = await self._find_time_overlaps(hunt_ids, db)
        result.technique_overlaps = await self._find_technique_overlaps(hunt_ids, db)
        result.host_overlaps = await self._find_host_overlaps(hunt_ids, db)

        result.total_correlations = (
            len(result.ioc_overlaps)
            + len(result.time_overlaps)
            + len(result.technique_overlaps)
            + len(result.host_overlaps)
        )

        result.summary = self._build_summary(result)
        return result

    async def correlate_all(self, db: AsyncSession) -> CorrelationResult:
        """Correlate across ALL hunts in the system."""
        stmt = select(Hunt.id)
        result = await db.execute(stmt)
        hunt_ids = [row[0] for row in result.fetchall()]

        if len(hunt_ids) < 2:
            return CorrelationResult(
                hunt_ids=hunt_ids,
                summary="Need at least 2 hunts for correlation analysis.",
            )

        return await self.correlate_hunts(hunt_ids, db)

    async def find_ioc_across_hunts(
        self,
        ioc_value: str,
        db: AsyncSession,
    ) -> list[dict]:
        """Find all occurrences of a specific IOC across all datasets/hunts."""
        # Scan a bounded sample of dataset rows and compare values in Python;
        # both raw and normalized data are checked column by column.
        stmt = select(DatasetRow, Dataset).join(
            Dataset, DatasetRow.dataset_id == Dataset.id
        )
        result = await db.execute(stmt.limit(5000))
        rows = result.all()

        occurrences = []
        for row, dataset in rows:
            data = row.data or {}
            normalized = row.normalized_data or {}

            # Search both raw and normalized data
            for col, val in {**data, **normalized}.items():
                if str(val) == ioc_value:
                    occurrences.append({
                        "dataset_id": dataset.id,
                        "dataset_name": dataset.name,
                        "hunt_id": dataset.hunt_id,
                        "row_index": row.row_index,
                        "column": col,
                    })
                    break

        return occurrences

    # ── IOC overlap detection ─────────────────────────────────────────

    async def _find_ioc_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[IOCOverlap]:
        """Find IOC values that appear in datasets from different hunts."""
        # Get all datasets for the specified hunts
        stmt = select(Dataset).where(Dataset.hunt_id.in_(hunt_ids))
        result = await db.execute(stmt)
        datasets = result.scalars().all()

        if len(datasets) < 2:
            return []

        # Build IOC → dataset mapping
        ioc_map: dict[str, list[dict]] = defaultdict(list)

        for dataset in datasets:
            if not dataset.ioc_columns:
                continue

            ioc_cols = list(dataset.ioc_columns.keys())
            rows_stmt = select(DatasetRow).where(
                DatasetRow.dataset_id == dataset.id
            ).limit(2000)
            rows_result = await db.execute(rows_stmt)
            rows = rows_result.scalars().all()

            for row in rows:
                data = row.data or {}
                for col in ioc_cols:
                    val = data.get(col, "")
                    if val and str(val).strip():
                        ioc_map[str(val).strip()].append({
                            "dataset_id": dataset.id,
                            "dataset_name": dataset.name,
                            "hunt_id": dataset.hunt_id,
                            "column": col,
                            "ioc_type": dataset.ioc_columns.get(col, "unknown"),
                        })

        # Filter to IOCs appearing in multiple hunts
        overlaps = []
        for ioc_value, appearances in ioc_map.items():
            hunt_set = set(a["hunt_id"] for a in appearances if a["hunt_id"])
            if len(hunt_set) >= 2:
                # Check for enrichment data
                enrich_stmt = select(EnrichmentResult).where(
                    EnrichmentResult.ioc_value == ioc_value
                ).limit(1)
                enrich_result = await db.execute(enrich_stmt)
                enrichment = enrich_result.scalar_one_or_none()

                overlaps.append(IOCOverlap(
                    ioc_value=ioc_value,
                    ioc_type=appearances[0].get("ioc_type", "unknown"),
                    datasets=appearances,
                    hunt_ids=sorted(hunt_set),
                    count=len(appearances),
                    enrichment_verdict=enrichment.verdict if enrichment else "",
                ))

        # Sort by count descending
        overlaps.sort(key=lambda x: x.count, reverse=True)
        return overlaps[:100]  # Limit results

    # ── Time window overlap ───────────────────────────────────────────

    async def _find_time_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[TimeOverlap]:
        """Find datasets across hunts with overlapping time ranges."""
        stmt = select(Dataset).where(
            Dataset.hunt_id.in_(hunt_ids),
            Dataset.time_range_start.isnot(None),
            Dataset.time_range_end.isnot(None),
        )
        result = await db.execute(stmt)
        datasets = result.scalars().all()

        overlaps = []
        for i, ds_a in enumerate(datasets):
            for ds_b in datasets[i + 1:]:
                if ds_a.hunt_id == ds_b.hunt_id:
                    continue  # Same hunt, skip

                try:
                    a_start = datetime.fromisoformat(ds_a.time_range_start)
                    a_end = datetime.fromisoformat(ds_a.time_range_end)
                    b_start = datetime.fromisoformat(ds_b.time_range_start)
                    b_end = datetime.fromisoformat(ds_b.time_range_end)
                except (ValueError, TypeError):
                    continue

                # Check overlap
                overlap_start = max(a_start, b_start)
                overlap_end = min(a_end, b_end)

                if overlap_start < overlap_end:
                    hours = (overlap_end - overlap_start).total_seconds() / 3600
                    overlaps.append(TimeOverlap(
                        dataset_a={
                            "id": ds_a.id,
                            "name": ds_a.name,
                            "hunt_id": ds_a.hunt_id,
                            "start": ds_a.time_range_start,
                            "end": ds_a.time_range_end,
                        },
                        dataset_b={
                            "id": ds_b.id,
                            "name": ds_b.name,
                            "hunt_id": ds_b.hunt_id,
                            "start": ds_b.time_range_start,
                            "end": ds_b.time_range_end,
                        },
                        overlap_start=overlap_start.isoformat(),
                        overlap_end=overlap_end.isoformat(),
                        overlap_hours=round(hours, 2),
                    ))

        overlaps.sort(key=lambda x: x.overlap_hours, reverse=True)
        return overlaps[:50]

    # ── MITRE technique overlap ───────────────────────────────────────

    async def _find_technique_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[TechniqueOverlap]:
        """Find MITRE ATT&CK techniques shared across hunts."""
        stmt = select(Hypothesis).where(
            Hypothesis.hunt_id.in_(hunt_ids),
            Hypothesis.mitre_technique.isnot(None),
        )
        result = await db.execute(stmt)
        hypotheses = result.scalars().all()

        technique_map: dict[str, list[dict]] = defaultdict(list)
        for hyp in hypotheses:
            technique = hyp.mitre_technique.strip()
            if technique:
                technique_map[technique].append({
                    "hypothesis_id": hyp.id,
                    "hypothesis_title": hyp.title,
                    "hunt_id": hyp.hunt_id,
                    "status": hyp.status,
                })

        overlaps = []
        for technique, hyps in technique_map.items():
            hunt_set = set(h["hunt_id"] for h in hyps if h["hunt_id"])
            if len(hunt_set) >= 2:
                overlaps.append(TechniqueOverlap(
                    technique_id=technique,
                    hypotheses=hyps,
                    hunt_ids=sorted(hunt_set),
                ))

        return overlaps

    # ── Host overlap ──────────────────────────────────────────────────

    async def _find_host_overlaps(
        self,
        hunt_ids: list[str],
        db: AsyncSession,
    ) -> list[dict]:
        """Find hostnames that appear in datasets from different hunts.

        Useful for detecting lateral movement patterns.
        """
        stmt = select(Dataset).where(Dataset.hunt_id.in_(hunt_ids))
        result = await db.execute(stmt)
        datasets = result.scalars().all()

        host_map: dict[str, list[dict]] = defaultdict(list)

        for dataset in datasets:
            norm_cols = dataset.normalized_columns or {}
            # Look for hostname columns
            hostname_cols = [
                orig for orig, canon in norm_cols.items()
                if canon in ("hostname", "host", "computer_name", "src_host", "dst_host")
            ]
            if not hostname_cols:
                continue

            rows_stmt = select(DatasetRow).where(
                DatasetRow.dataset_id == dataset.id
            ).limit(2000)
            rows_result = await db.execute(rows_stmt)
            rows = rows_result.scalars().all()

            for row in rows:
                data = row.data or {}
                for col in hostname_cols:
                    val = data.get(col, "")
                    if val and str(val).strip():
                        host_name = str(val).strip().upper()
                        host_map[host_name].append({
                            "dataset_id": dataset.id,
                            "dataset_name": dataset.name,
                            "hunt_id": dataset.hunt_id,
                        })

        # Filter to hosts appearing in multiple hunts
        overlaps = []
        for host, appearances in host_map.items():
            hunt_set = set(a["hunt_id"] for a in appearances if a["hunt_id"])
            if len(hunt_set) >= 2:
                overlaps.append({
                    "hostname": host,
                    "hunt_ids": sorted(hunt_set),
                    "dataset_count": len(appearances),
                    "datasets": appearances[:10],
                })

        overlaps.sort(key=lambda x: x["dataset_count"], reverse=True)
        return overlaps[:50]

    # ── Summary builder ───────────────────────────────────────────────

    def _build_summary(self, result: CorrelationResult) -> str:
        """Build a human-readable summary of correlations."""
        parts = [f"Correlation analysis across {len(result.hunt_ids)} hunts:"]

        if result.ioc_overlaps:
            malicious = [o for o in result.ioc_overlaps if o.enrichment_verdict == "malicious"]
            parts.append(
                f" - {len(result.ioc_overlaps)} shared IOCs "
                f"({len(malicious)} flagged malicious)"
            )
        else:
            parts.append(" - No shared IOCs found")

        if result.time_overlaps:
            parts.append(f" - {len(result.time_overlaps)} overlapping time windows")

        if result.technique_overlaps:
            parts.append(
                f" - {len(result.technique_overlaps)} shared MITRE techniques"
            )

        if result.host_overlaps:
            parts.append(
                f" - {len(result.host_overlaps)} hosts appearing in multiple hunts "
                "(potential lateral movement)"
            )

        if result.total_correlations == 0:
            parts.append(" No significant correlations detected.")

        return "\n".join(parts)


# Singleton
correlation_engine = CorrelationEngine()
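# Usage sketch (illustrative, not part of the module): running a scoped
# correlation from a route handler. `db` is an assumed AsyncSession from the
# get_db dependency.
async def _example_correlate(db: AsyncSession, hunt_a: str, hunt_b: str) -> str:
    result = await correlation_engine.correlate_hunts([hunt_a, hunt_b], db)
    # e.g. surface only IOCs already flagged malicious by enrichment
    flagged = [o.ioc_value for o in result.ioc_overlaps if o.enrichment_verdict == "malicious"]
    logger.info("correlations=%d flagged=%d", result.total_correlations, len(flagged))
    return result.summary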
165
backend/app/services/csv_parser.py
Normal file
@@ -0,0 +1,165 @@
"""CSV parsing engine with encoding detection, delimiter sniffing, and streaming.

Handles large Velociraptor CSV exports with resilience to encoding issues,
varied delimiters, and malformed rows.
"""

import csv
import io
import logging
from pathlib import Path
from typing import AsyncIterator

import chardet

logger = logging.getLogger(__name__)

# Reasonable defaults
MAX_FIELD_SIZE = 1024 * 1024  # 1 MB per field
csv.field_size_limit(MAX_FIELD_SIZE)


def detect_encoding(file_bytes: bytes, sample_size: int = 65536) -> str:
    """Detect file encoding from a sample of bytes."""
    result = chardet.detect(file_bytes[:sample_size])
    encoding = result.get("encoding", "utf-8") or "utf-8"
    confidence = result.get("confidence", 0)
    logger.info(f"Detected encoding: {encoding} (confidence: {confidence:.2f})")
    # Fall back to utf-8 if confidence is very low
    if confidence < 0.5:
        encoding = "utf-8"
    return encoding


def detect_delimiter(text_sample: str) -> str:
    """Sniff the CSV delimiter from a text sample."""
    try:
        dialect = csv.Sniffer().sniff(text_sample, delimiters=",\t;|")
        return dialect.delimiter
    except csv.Error:
        return ","

def infer_column_types(rows: list[dict], sample_size: int = 100) -> dict[str, str]:
    """Infer column types from a sample of rows.

    Returns a mapping of column_name → type_hint where type_hint is one of:
    timestamp, integer, float, ip, hash_md5, hash_sha1, hash_sha256, domain, path, string
    """
    import re

    type_map: dict[str, dict[str, int]] = {}
    sample = rows[:sample_size]

    patterns = {
        "ip": re.compile(
            r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$"
        ),
        "hash_md5": re.compile(r"^[a-fA-F0-9]{32}$"),
        "hash_sha1": re.compile(r"^[a-fA-F0-9]{40}$"),
        "hash_sha256": re.compile(r"^[a-fA-F0-9]{64}$"),
        "integer": re.compile(r"^-?\d+$"),
        "float": re.compile(r"^-?\d+\.\d+$"),
        "timestamp": re.compile(
            r"^\d{4}[-/]\d{2}[-/]\d{2}[T ]\d{2}:\d{2}"
        ),
        "domain": re.compile(
            r"^[a-zA-Z0-9]([a-zA-Z0-9-]*[a-zA-Z0-9])?(\.[a-zA-Z]{2,})+$"
        ),
        "path": re.compile(r"^([A-Z]:\\|/)", re.IGNORECASE),
    }

    for row in sample:
        for col, val in row.items():
            if col not in type_map:
                type_map[col] = {}
            val_str = str(val).strip()
            if not val_str:
                continue
            matched = False
            for type_name, pattern in patterns.items():
                if pattern.match(val_str):
                    type_map[col][type_name] = type_map[col].get(type_name, 0) + 1
                    matched = True
                    break
            if not matched:
                type_map[col]["string"] = type_map[col].get("string", 0) + 1

    result: dict[str, str] = {}
    for col, counts in type_map.items():
        if counts:
            result[col] = max(counts, key=counts.get)  # type: ignore[arg-type]
        else:
            result[col] = "string"

    return result

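# Worked example (illustrative, not part of the module): the result is a
# majority vote per column over the sample, with blank values skipped rather
# than counted as "string".
_example_rows = [
    {"SrcIP": "10.0.0.5", "Md5": "d41d8cd98f00b204e9800998ecf8427e", "Note": "ok"},
    {"SrcIP": "10.0.0.9", "Md5": "", "Note": "seen twice"},
]
# infer_column_types(_example_rows)
#   → {"SrcIP": "ip", "Md5": "hash_md5", "Note": "string"}
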
def parse_csv_bytes(
    raw_bytes: bytes,
    max_rows: int | None = None,
) -> tuple[list[dict], dict]:
    """Parse a CSV file from raw bytes.

    Returns:
        (rows, metadata) where metadata contains encoding, delimiter, columns, etc.
    """
    encoding = detect_encoding(raw_bytes)

    try:
        text = raw_bytes.decode(encoding, errors="replace")
    except (UnicodeDecodeError, LookupError):
        text = raw_bytes.decode("utf-8", errors="replace")
        encoding = "utf-8"

    # Detect delimiter from first few KB
    delimiter = detect_delimiter(text[:8192])

    reader = csv.DictReader(io.StringIO(text), delimiter=delimiter)
    columns = reader.fieldnames or []

    rows: list[dict] = []
    for i, row in enumerate(reader):
        if max_rows is not None and i >= max_rows:
            break
        rows.append(dict(row))

    column_types = infer_column_types(rows) if rows else {}

    metadata = {
        "encoding": encoding,
        "delimiter": delimiter,
        "columns": columns,
        "column_types": column_types,
        "row_count": len(rows),
        # Only equal to the true file total when max_rows is None; with a cap
        # this reflects the parsed subset, not the whole file.
        "total_rows_in_file": len(rows),
    }

    return rows, metadata

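# Usage sketch (illustrative, not part of the module): parsing an uploaded
# file's bytes with a row cap, returning only what the UI needs to render the
# dataset header.
def _example_parse(raw: bytes) -> dict:
    rows, meta = parse_csv_bytes(raw, max_rows=10_000)
    return {
        "columns": meta["columns"],
        "types": meta["column_types"],
        "delimiter": meta["delimiter"],
        "rows_parsed": meta["row_count"],
    }
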
async def parse_csv_streaming(
    file_path: Path,
    chunk_size: int = 8192,  # reserved; the current implementation does not chunk
) -> AsyncIterator[tuple[int, dict]]:
    """Stream-parse a CSV file yielding (row_index, row_dict) tuples.

    Rows are yielded incrementally to callers, but note that this
    implementation still reads the whole file into memory before parsing
    (csv.DictReader needs a complete text buffer here).
    """
    import aiofiles  # type: ignore[import-untyped]

    # Read a sample for encoding/delimiter detection
    with open(file_path, "rb") as f:
        sample_bytes = f.read(65536)

    encoding = detect_encoding(sample_bytes)
    text_sample = sample_bytes.decode(encoding, errors="replace")
    delimiter = detect_delimiter(text_sample[:8192])

    # Now read the file asynchronously and parse
    async with aiofiles.open(file_path, mode="r", encoding=encoding, errors="replace") as f:
        content = await f.read()  # For DictReader compatibility

    reader = csv.DictReader(io.StringIO(content), delimiter=delimiter)
    for i, row in enumerate(reader):
        yield i, dict(row)
655
backend/app/services/enrichment.py
Normal file
@@ -0,0 +1,655 @@
"""IOC Enrichment Engine — VirusTotal, AbuseIPDB, Shodan integrations.

Provides automated IOC enrichment with caching and rate limiting.
Enriches IPs, hashes, domains with threat intelligence verdicts.
"""

import asyncio
import hashlib
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Optional

import httpx
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.db.models import EnrichmentResult as EnrichmentDB

logger = logging.getLogger(__name__)


class IOCType(str, Enum):
    IP = "ip"
    DOMAIN = "domain"
    HASH_MD5 = "hash_md5"
    HASH_SHA1 = "hash_sha1"
    HASH_SHA256 = "hash_sha256"
    URL = "url"


class Verdict(str, Enum):
    CLEAN = "clean"
    SUSPICIOUS = "suspicious"
    MALICIOUS = "malicious"
    UNKNOWN = "unknown"
    ERROR = "error"


@dataclass
class EnrichmentResultData:
    """Enrichment result from a provider."""
    ioc_value: str
    ioc_type: IOCType
    source: str
    verdict: Verdict
    score: float = 0.0  # 0-100 normalized threat score
    raw_data: dict = field(default_factory=dict)
    tags: list[str] = field(default_factory=list)
    country: str = ""
    asn: str = ""
    org: str = ""
    last_seen: str = ""
    error: str = ""
    latency_ms: int = 0

# ── Rate limiter ──────────────────────────────────────────────────────


class RateLimiter:
    """Simple rate limiter that spaces API calls at a fixed minimum interval.

    (Not a full token bucket: there is no burst allowance — each call waits
    until `interval` seconds have elapsed since the previous one.)
    """

    def __init__(self, calls_per_minute: int = 4):
        self.calls_per_minute = calls_per_minute
        self.interval = 60.0 / calls_per_minute
        self._last_call: float = 0.0
        self._lock = asyncio.Lock()

    async def acquire(self):
        async with self._lock:
            now = time.monotonic()
            elapsed = now - self._last_call
            if elapsed < self.interval:
                await asyncio.sleep(self.interval - elapsed)
            self._last_call = time.monotonic()

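# Behaviour sketch (illustrative, not part of the module): at 4 calls/minute
# the limiter enforces a 15-second gap between consecutive acquire() calls,
# serialized through its internal lock.
async def _example_paced_calls():
    limiter = RateLimiter(calls_per_minute=4)
    for i in range(3):
        await limiter.acquire()  # 2nd and 3rd iterations each sleep ~15 s
        print(f"call {i} at {time.monotonic():.1f}")
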
# ── Provider base ─────────────────────────────────────────────────────


class EnrichmentProvider:
    """Base class for enrichment providers."""

    name: str = "base"

    def __init__(self, api_key: str = "", rate_limit: int = 4):
        self.api_key = api_key
        self.rate_limiter = RateLimiter(rate_limit)
        self._client: httpx.AsyncClient | None = None

    def _get_client(self) -> httpx.AsyncClient:
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(connect=10, read=30, write=10, pool=5),
            )
        return self._client

    async def cleanup(self):
        if self._client and not self._client.is_closed:
            await self._client.aclose()

    @property
    def is_configured(self) -> bool:
        return bool(self.api_key)

    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        raise NotImplementedError

# ── VirusTotal ────────────────────────────────────────────────────────


class VirusTotalProvider(EnrichmentProvider):
    """VirusTotal v3 API provider."""

    name = "virustotal"
    BASE_URL = "https://www.virustotal.com/api/v3"

    def __init__(self):
        super().__init__(api_key=settings.VIRUSTOTAL_API_KEY, rate_limit=4)

    def _headers(self) -> dict:
        return {"x-apikey": self.api_key}

    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        if not self.is_configured:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="VirusTotal API key not configured",
            )

        await self.rate_limiter.acquire()
        start = time.monotonic()

        try:
            endpoint = self._get_endpoint(ioc_value, ioc_type)
            if not endpoint:
                return EnrichmentResultData(
                    ioc_value=ioc_value, ioc_type=ioc_type,
                    source=self.name, verdict=Verdict.ERROR,
                    error=f"Unsupported IOC type: {ioc_type}",
                )

            client = self._get_client()
            resp = await client.get(endpoint, headers=self._headers())
            latency_ms = int((time.monotonic() - start) * 1000)

            if resp.status_code == 404:
                return EnrichmentResultData(
                    ioc_value=ioc_value, ioc_type=ioc_type,
                    source=self.name, verdict=Verdict.UNKNOWN,
                    latency_ms=latency_ms,
                )

            resp.raise_for_status()
            data = resp.json()
            attrs = data.get("data", {}).get("attributes", {})
            stats = attrs.get("last_analysis_stats", {})

            malicious = stats.get("malicious", 0)
            suspicious = stats.get("suspicious", 0)
            total = sum(stats.values()) if stats else 0

            # Determine verdict
            if malicious > 3:
                verdict = Verdict.MALICIOUS
            elif malicious > 0 or suspicious > 2:
                verdict = Verdict.SUSPICIOUS
            elif total > 0:
                verdict = Verdict.CLEAN
            else:
                verdict = Verdict.UNKNOWN

            score = (malicious / total * 100) if total > 0 else 0

            tags = attrs.get("tags", [])
            if attrs.get("type_description"):
                tags.append(attrs["type_description"])

            return EnrichmentResultData(
                ioc_value=ioc_value,
                ioc_type=ioc_type,
                source=self.name,
                verdict=verdict,
                score=round(score, 1),
                raw_data={
                    "stats": stats,
                    "reputation": attrs.get("reputation", 0),
                    "type_description": attrs.get("type_description", ""),
                    "names": attrs.get("names", [])[:5],
                },
                tags=tags[:10],
                country=attrs.get("country", ""),
                asn=str(attrs.get("asn", "")),
                org=attrs.get("as_owner", ""),
                # VT returns last_analysis_date as a unix timestamp (int);
                # stringify to match the str-typed field.
                last_seen=str(attrs.get("last_analysis_date", "")),
                latency_ms=latency_ms,
            )

        except httpx.HTTPStatusError as e:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=f"HTTP {e.response.status_code}",
                latency_ms=int((time.monotonic() - start) * 1000),
            )
        except Exception as e:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=str(e),
                latency_ms=int((time.monotonic() - start) * 1000),
            )

    def _get_endpoint(self, ioc_value: str, ioc_type: IOCType) -> str | None:
        if ioc_type == IOCType.IP:
            return f"{self.BASE_URL}/ip_addresses/{ioc_value}"
        elif ioc_type == IOCType.DOMAIN:
            return f"{self.BASE_URL}/domains/{ioc_value}"
        elif ioc_type in (IOCType.HASH_MD5, IOCType.HASH_SHA1, IOCType.HASH_SHA256):
            return f"{self.BASE_URL}/files/{ioc_value}"
        elif ioc_type == IOCType.URL:
            # the SHA-256 of the URL serves as the VT v3 URL identifier
            url_id = hashlib.sha256(ioc_value.encode()).hexdigest()
            return f"{self.BASE_URL}/urls/{url_id}"
        return None

# ── AbuseIPDB ─────────────────────────────────────────────────────────


class AbuseIPDBProvider(EnrichmentProvider):
    """AbuseIPDB API provider — IP reputation."""

    name = "abuseipdb"
    BASE_URL = "https://api.abuseipdb.com/api/v2"

    def __init__(self):
        super().__init__(api_key=settings.ABUSEIPDB_API_KEY, rate_limit=10)

    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        if ioc_type != IOCType.IP:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="AbuseIPDB only supports IP lookups",
            )

        if not self.is_configured:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="AbuseIPDB API key not configured",
            )

        await self.rate_limiter.acquire()
        start = time.monotonic()

        try:
            client = self._get_client()
            resp = await client.get(
                f"{self.BASE_URL}/check",
                params={"ipAddress": ioc_value, "maxAgeInDays": 90, "verbose": "true"},
                headers={"Key": self.api_key, "Accept": "application/json"},
            )
            latency_ms = int((time.monotonic() - start) * 1000)
            resp.raise_for_status()
            data = resp.json().get("data", {})

            abuse_score = data.get("abuseConfidenceScore", 0)
            total_reports = data.get("totalReports", 0)

            if abuse_score >= 75:
                verdict = Verdict.MALICIOUS
            elif abuse_score >= 25 or total_reports > 5:
                verdict = Verdict.SUSPICIOUS
            elif total_reports == 0:
                verdict = Verdict.UNKNOWN
            else:
                verdict = Verdict.CLEAN

            categories = data.get("reports", [])
            tags = []
            for report in categories[:10]:
                for cat_id in report.get("categories", []):
                    tag = self._category_name(cat_id)
                    if tag and tag not in tags:
                        tags.append(tag)

            return EnrichmentResultData(
                ioc_value=ioc_value,
                ioc_type=ioc_type,
                source=self.name,
                verdict=verdict,
                score=float(abuse_score),
                raw_data={
                    "abuse_confidence_score": abuse_score,
                    "total_reports": total_reports,
                    "is_whitelisted": data.get("isWhitelisted"),
                    "is_tor": data.get("isTor", False),
                    "usage_type": data.get("usageType", ""),
                    "isp": data.get("isp", ""),
                },
                tags=tags[:10],
                country=data.get("countryCode", ""),
                org=data.get("isp", ""),
                last_seen=data.get("lastReportedAt", ""),
                latency_ms=latency_ms,
            )

        except Exception as e:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=str(e),
                latency_ms=int((time.monotonic() - start) * 1000),
            )

    @staticmethod
    def _category_name(cat_id: int) -> str:
        categories = {
            1: "DNS Compromise", 2: "DNS Poisoning", 3: "Fraud Orders",
            4: "DDoS Attack", 5: "FTP Brute-Force", 6: "Ping of Death",
            7: "Phishing", 8: "Fraud VoIP", 9: "Open Proxy",
            10: "Web Spam", 11: "Email Spam", 12: "Blog Spam",
            13: "VPN IP", 14: "Port Scan", 15: "Hacking",
            16: "SQL Injection", 17: "Spoofing", 18: "Brute-Force",
            19: "Bad Web Bot", 20: "Exploited Host", 21: "Web App Attack",
            22: "SSH", 23: "IoT Targeted",
        }
        return categories.get(cat_id, "")

# ── Shodan ────────────────────────────────────────────────────────────


class ShodanProvider(EnrichmentProvider):
    """Shodan API provider — infrastructure intelligence."""

    name = "shodan"
    BASE_URL = "https://api.shodan.io"

    def __init__(self):
        super().__init__(api_key=settings.SHODAN_API_KEY, rate_limit=1)

    async def enrich(self, ioc_value: str, ioc_type: IOCType) -> EnrichmentResultData:
        if ioc_type != IOCType.IP:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="Shodan only supports IP lookups",
            )

        if not self.is_configured:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error="Shodan API key not configured",
            )

        await self.rate_limiter.acquire()
        start = time.monotonic()

        try:
            client = self._get_client()
            resp = await client.get(
                f"{self.BASE_URL}/shodan/host/{ioc_value}",
                params={"key": self.api_key, "minify": "true"},
            )
            latency_ms = int((time.monotonic() - start) * 1000)

            if resp.status_code == 404:
                return EnrichmentResultData(
                    ioc_value=ioc_value, ioc_type=ioc_type,
                    source=self.name, verdict=Verdict.UNKNOWN,
                    latency_ms=latency_ms,
                )

            resp.raise_for_status()
            data = resp.json()

            ports = data.get("ports", [])
            vulns = data.get("vulns", [])
            tags_raw = data.get("tags", [])

            # Determine verdict based on open ports and vulns
            if vulns:
                verdict = Verdict.SUSPICIOUS
                score = min(len(vulns) * 15, 100.0)
            elif len(ports) > 20:
                verdict = Verdict.SUSPICIOUS
                score = 40.0
            else:
                verdict = Verdict.CLEAN
                score = 0.0

            tags = tags_raw[:10]
            if vulns:
                tags.extend([f"CVE: {v}" for v in vulns[:5]])

            return EnrichmentResultData(
                ioc_value=ioc_value,
                ioc_type=ioc_type,
                source=self.name,
                verdict=verdict,
                score=score,
                raw_data={
                    "ports": ports[:20],
                    "vulns": vulns[:10],
                    "os": data.get("os"),
                    "hostnames": data.get("hostnames", [])[:5],
                    "domains": data.get("domains", [])[:5],
                    "last_update": data.get("last_update", ""),
                },
                tags=tags[:15],
                country=data.get("country_code", ""),
                asn=data.get("asn", ""),
                org=data.get("org", ""),
                last_seen=data.get("last_update", ""),
                latency_ms=latency_ms,
            )

        except Exception as e:
            return EnrichmentResultData(
                ioc_value=ioc_value, ioc_type=ioc_type,
                source=self.name, verdict=Verdict.ERROR,
                error=str(e),
                latency_ms=int((time.monotonic() - start) * 1000),
            )

# ── Enrichment Engine (orchestrator) ──────────────────────────────────
|
||||
|
||||
|
||||
class EnrichmentEngine:
|
||||
"""Orchestrates IOC enrichment across all providers with caching."""
|
||||
|
||||
CACHE_TTL_HOURS = 24
|
||||
|
||||
def __init__(self):
|
||||
self.providers: list[EnrichmentProvider] = [
|
||||
VirusTotalProvider(),
|
||||
AbuseIPDBProvider(),
|
||||
ShodanProvider(),
|
||||
]
|
||||
|
||||
@property
|
||||
def configured_providers(self) -> list[EnrichmentProvider]:
|
||||
return [p for p in self.providers if p.is_configured]
|
||||
|
||||
async def enrich_ioc(
|
||||
self,
|
||||
ioc_value: str,
|
||||
ioc_type: IOCType,
|
||||
db: AsyncSession | None = None,
|
||||
skip_cache: bool = False,
|
||||
) -> list[EnrichmentResultData]:
|
||||
"""Enrich a single IOC across all configured providers.
|
||||
|
||||
Uses cached results from DB when available.
|
||||
"""
|
||||
results: list[EnrichmentResultData] = []
|
||||
|
||||
# Check cache first
|
||||
if db and not skip_cache:
|
||||
cached = await self._get_cached(db, ioc_value, ioc_type)
|
||||
if cached:
|
||||
logger.info(f"Cache hit for {ioc_type.value}:{ioc_value} ({len(cached)} results)")
|
||||
return cached
|
||||
|
||||
# Query all applicable providers in parallel
|
||||
tasks = []
|
||||
for provider in self.configured_providers:
|
||||
# Skip providers that don't support this IOC type
|
||||
if ioc_type in (IOCType.DOMAIN,) and provider.name in ("abuseipdb", "shodan"):
|
||||
continue
|
||||
if ioc_type == IOCType.IP and provider.name == "virustotal":
|
||||
tasks.append(provider.enrich(ioc_value, ioc_type))
|
||||
elif ioc_type == IOCType.IP:
|
||||
tasks.append(provider.enrich(ioc_value, ioc_type))
|
||||
elif ioc_type in (IOCType.HASH_MD5, IOCType.HASH_SHA1, IOCType.HASH_SHA256):
|
||||
if provider.name == "virustotal":
|
||||
tasks.append(provider.enrich(ioc_value, ioc_type))
|
||||
elif ioc_type == IOCType.DOMAIN:
|
||||
if provider.name == "virustotal":
|
||||
tasks.append(provider.enrich(ioc_value, ioc_type))
|
||||
elif ioc_type == IOCType.URL:
|
||||
if provider.name == "virustotal":
|
||||
tasks.append(provider.enrich(ioc_value, ioc_type))
|
||||

        if tasks:
            results = list(await asyncio.gather(*tasks, return_exceptions=False))

        # Cache results
        if db and results:
            await self._cache_results(db, results)

        return results

    async def enrich_batch(
        self,
        iocs: list[tuple[str, IOCType]],
        db: AsyncSession | None = None,
        concurrency: int = 3,
    ) -> dict[str, list[EnrichmentResultData]]:
        """Enrich a batch of IOCs with controlled concurrency."""
        sem = asyncio.Semaphore(concurrency)
        all_results: dict[str, list[EnrichmentResultData]] = {}

        async def _enrich_one(value: str, ioc_type: IOCType):
            async with sem:
                result = await self.enrich_ioc(value, ioc_type, db=db)
                all_results[value] = result

        tasks = [_enrich_one(v, t) for v, t in iocs]
        await asyncio.gather(*tasks, return_exceptions=True)
        return all_results

    async def enrich_dataset_iocs(
        self,
        rows: list[dict],
        ioc_columns: dict,
        db: AsyncSession | None = None,
        max_iocs: int = 50,
    ) -> dict[str, list[EnrichmentResultData]]:
        """Auto-enrich IOCs found in a dataset.

        Extracts unique IOC values from the identified columns and enriches them.
        """
        iocs_to_enrich: list[tuple[str, IOCType]] = []
        seen = set()

        for col_name, col_type in ioc_columns.items():
            ioc_type = self._map_column_type(col_type)
            if not ioc_type:
                continue

            for row in rows:
                value = row.get(col_name, "")
                if value and value not in seen:
                    seen.add(value)
                    iocs_to_enrich.append((str(value), ioc_type))

                if len(iocs_to_enrich) >= max_iocs:
                    break

            if len(iocs_to_enrich) >= max_iocs:
                break

        if iocs_to_enrich:
            return await self.enrich_batch(iocs_to_enrich, db=db)
        return {}

    async def _get_cached(
        self,
        db: AsyncSession,
        ioc_value: str,
        ioc_type: IOCType,
    ) -> list[EnrichmentResultData] | None:
        """Check for cached enrichment results."""
        cutoff = datetime.now(timezone.utc) - timedelta(hours=self.CACHE_TTL_HOURS)
        stmt = (
            select(EnrichmentDB)
            .where(
                EnrichmentDB.ioc_value == ioc_value,
                EnrichmentDB.ioc_type == ioc_type.value,
                EnrichmentDB.cached_at >= cutoff,
            )
        )
        result = await db.execute(stmt)
        cached = result.scalars().all()

        if not cached:
            return None

        return [
            EnrichmentResultData(
                ioc_value=c.ioc_value,
                ioc_type=IOCType(c.ioc_type),
                source=c.source,
                verdict=Verdict(c.verdict),
                score=c.score or 0.0,
                raw_data=c.raw_data or {},
                tags=c.tags or [],
                country=c.country or "",
                asn=c.asn or "",
                org=c.org or "",
            )
            for c in cached
        ]

    async def _cache_results(
        self,
        db: AsyncSession,
        results: list[EnrichmentResultData],
    ):
        """Cache enrichment results in the database."""
        for r in results:
            if r.verdict == Verdict.ERROR:
                continue  # Don't cache errors
            entry = EnrichmentDB(
                ioc_value=r.ioc_value,
                ioc_type=r.ioc_type.value,
                source=r.source,
                verdict=r.verdict.value,
                score=r.score,
                raw_data=r.raw_data,
                tags=r.tags,
                country=r.country,
                asn=r.asn,
                org=r.org,
            )
            db.add(entry)
        try:
            await db.flush()
        except Exception as e:
            logger.warning(f"Failed to cache enrichment: {e}")

    @staticmethod
    def _map_column_type(col_type: str) -> IOCType | None:
        """Map column type from normalizer to IOCType."""
        mapping = {
            "ip": IOCType.IP,
            "ip_address": IOCType.IP,
            "src_ip": IOCType.IP,
            "dst_ip": IOCType.IP,
            "domain": IOCType.DOMAIN,
            "hash_md5": IOCType.HASH_MD5,
            "hash_sha1": IOCType.HASH_SHA1,
            "hash_sha256": IOCType.HASH_SHA256,
            "url": IOCType.URL,
        }
        return mapping.get(col_type)

    async def cleanup(self):
        for provider in self.providers:
            await provider.cleanup()

    def status(self) -> dict:
        """Return enrichment engine status."""
        return {
            "providers": {
                p.name: {"configured": p.is_configured}
                for p in self.providers
            },
            "cache_ttl_hours": self.CACHE_TTL_HOURS,
        }


# Singleton
enrichment_engine = EnrichmentEngine()
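A minimal usage sketch of the orchestrator; the caller shown is hypothetical and assumes an AsyncSession `db` from the app's get_db dependency:

    # Hypothetical caller (not part of this file) — IOCType comes from the
    # enrichment service module above.
    results = await enrichment_engine.enrich_ioc("203.0.113.5", IOCType.IP, db=db)
    for r in results:
        print(r.source, r.verdict.value, r.score)

    # Batch mode bounds provider pressure with a semaphore (default concurrency=3).
    by_value = await enrichment_engine.enrich_batch(
        [("203.0.113.5", IOCType.IP), ("example.org", IOCType.DOMAIN)], db=db,
    )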
145  backend/app/services/keyword_defaults.py  Normal file
@@ -0,0 +1,145 @@
"""Default AUP keyword themes and their seed keywords.

Called once on startup — only inserts themes that don't already exist,
so user edits are never overwritten.
"""

import logging

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import KeywordTheme, Keyword

logger = logging.getLogger(__name__)

# ── Default themes + keywords ─────────────────────────────────────────

DEFAULTS: dict[str, dict] = {
    "Gambling": {
        "color": "#f44336",
        "keywords": [
            "poker", "casino", "blackjack", "roulette", "sportsbook",
            "sports betting", "bet365", "draftkings", "fanduel", "bovada",
            "betonline", "mybookie", "slots", "slot machine", "parlay",
            "wager", "bookie", "betway", "888casino", "pokerstars",
            "william hill", "ladbrokes", "betfair", "unibet", "pinnacle",
        ],
    },
    "Gaming": {
        "color": "#9c27b0",
        "keywords": [
            "steam", "steamcommunity", "steampowered", "epic games",
            "epicgames", "origin.com", "battle.net", "blizzard",
            "roblox", "minecraft", "fortnite", "valorant", "league of legends",
            "twitch", "twitch.tv", "discord", "discord.gg", "xbox live",
            "playstation network", "gog.com", "itch.io", "gamepass",
            "riot games", "ubisoft", "ea.com",
        ],
    },
    "Streaming": {
        "color": "#ff9800",
        "keywords": [
            "netflix", "hulu", "disney+", "disneyplus", "hbomax",
            "amazon prime video", "peacock", "paramount+", "crunchyroll",
            "funimation", "spotify", "pandora", "soundcloud", "deezer",
            "tidal", "apple music", "youtube music", "pluto tv",
            "tubi", "vudu", "plex",
        ],
    },
    "Downloads / Piracy": {
        "color": "#ff5722",
        "keywords": [
            "torrent", "bittorrent", "utorrent", "qbittorrent", "piratebay",
            "thepiratebay", "1337x", "rarbg", "yts", "kickass",
            "limewire", "frostwire", "mega.nz", "rapidshare", "mediafire",
            "zippyshare", "uploadhaven", "fitgirl", "repack", "crack",
            "keygen", "warez", "nulled", "pirate", "magnet:",
        ],
    },
    "Adult Content": {
        "color": "#e91e63",
        "keywords": [
            "pornhub", "xvideos", "xhamster", "onlyfans", "chaturbate",
            "livejasmin", "brazzers", "redtube", "youporn", "xnxx",
            "porn", "xxx", "nsfw", "adult content", "cam site",
            "stripchat", "bongacams",
        ],
    },
    "Social Media": {
        "color": "#2196f3",
        "keywords": [
            "facebook", "instagram", "tiktok", "snapchat", "pinterest",
            "reddit", "tumblr", "myspace", "whatsapp web", "telegram web",
            "signal web", "wechat web", "twitter.com", "x.com",
            "threads.net", "mastodon", "bluesky",
        ],
    },
    "Job Search": {
        "color": "#4caf50",
        "keywords": [
            "indeed", "linkedin jobs", "glassdoor", "monster.com",
            "ziprecruiter", "careerbuilder", "dice.com", "hired.com",
            "angel.co", "wellfound", "levels.fyi", "salary.com",
            "payscale", "resume", "cover letter", "job application",
        ],
    },
    "Shopping": {
        "color": "#00bcd4",
        "keywords": [
            "amazon.com", "ebay", "etsy", "walmart.com", "target.com",
            "bestbuy", "aliexpress", "wish.com", "shein", "temu",
            "wayfair", "overstock", "newegg", "zappos", "coupon",
            "promo code", "add to cart",
        ],
    },
}


async def seed_defaults(db: AsyncSession) -> int:
    """Insert default themes + keywords for any theme name not already in DB.

    Returns the number of themes inserted (0 if all already exist).
    """
    # Rename legacy theme names
    _renames = [("Social Media (Personal)", "Social Media")]
    for old_name, new_name in _renames:
        old = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == old_name))
        if old:
            await db.execute(
                KeywordTheme.__table__.update()
                .where(KeywordTheme.name == old_name)
                .values(name=new_name)
            )
            await db.commit()
            logger.info("Renamed AUP theme '%s' → '%s'", old_name, new_name)

    inserted = 0
    for theme_name, meta in DEFAULTS.items():
        exists = await db.scalar(
            select(KeywordTheme.id).where(KeywordTheme.name == theme_name)
        )
        if exists:
            continue

        theme = KeywordTheme(
            name=theme_name,
            color=meta["color"],
            enabled=True,
            is_builtin=True,
        )
        db.add(theme)
        await db.flush()  # get theme.id

        for kw in meta["keywords"]:
            db.add(Keyword(theme_id=theme.id, value=kw))

        inserted += 1
        logger.info("Seeded AUP theme '%s' with %d keywords", theme_name, len(meta["keywords"]))

    if inserted:
        await db.commit()
        logger.info("Seeded %d AUP keyword themes", inserted)
    else:
        logger.debug("All default AUP themes already present")

    return inserted
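A sketch of how the seeder is typically invoked at startup; the session factory name below is an assumption for illustration, not something this diff defines:

    # Hypothetical startup hook — async_session_factory is assumed to exist
    # in app.db.engine alongside get_db.
    async def seed_on_startup():
        async with async_session_factory() as db:
            inserted = await seed_defaults(db)  # idempotent: existing themes are skipped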
196  backend/app/services/normalizer.py  Normal file
@@ -0,0 +1,196 @@
"""Artifact normalizer — maps Velociraptor and common tool columns to canonical schema.

The canonical schema provides consistent field names regardless of which tool
exported the CSV (Velociraptor, OSQuery, Sysmon, etc.).
"""

import logging
import re
from datetime import datetime
from typing import Any

logger = logging.getLogger(__name__)

# ── Column mapping: source_column_pattern → canonical_name ─────────────
# Patterns are case-insensitive regexes matched against column names.

COLUMN_MAPPINGS: list[tuple[str, str]] = [
    # Timestamps
    (r"^(timestamp|time|event_?time|date_?time|created?_?(at|time|date)|modified_?(at|time|date)|mtime|ctime|atime|start_?time|end_?time)$", "timestamp"),
    (r"^(eventtime|system\.timecreated)$", "timestamp"),
    # Host identifiers
    (r"^(hostname|host|fqdn|computer_?name|system_?name|machinename|clientid)$", "hostname"),
    # Operating system
    (r"^(os|operating_?system|os_?version|os_?name|platform|os_?type)$", "os"),
    # Source / destination IPs
    (r"^(source_?ip|src_?ip|srcaddr|local_?address|sourceaddress)$", "src_ip"),
    (r"^(dest_?ip|dst_?ip|dstaddr|remote_?address|destinationaddress|destaddress)$", "dst_ip"),
    (r"^(ip_?address|ipaddress|ip)$", "ip_address"),
    # Ports
    (r"^(source_?port|src_?port|localport)$", "src_port"),
    (r"^(dest_?port|dst_?port|remoteport|destinationport)$", "dst_port"),
    # Process info
    (r"^(process_?name|name|image|exe|executable|binary)$", "process_name"),
    (r"^(pid|process_?id)$", "pid"),
    (r"^(ppid|parent_?pid|parentprocessid)$", "ppid"),
    (r"^(command_?line|cmdline|commandline|cmd)$", "command_line"),
    (r"^(parent_?command_?line|parentcommandline)$", "parent_command_line"),
    # User info
    (r"^(user|username|user_?name|account_?name|subjectusername)$", "username"),
    (r"^(user_?id|uid|sid|subjectusersid)$", "user_id"),
    # File info
    (r"^(file_?path|fullpath|full_?name|path|filepath)$", "file_path"),
    (r"^(file_?name|filename|name)$", "file_name"),
    (r"^(file_?size|size|bytes|length)$", "file_size"),
    (r"^(extension|file_?ext)$", "file_extension"),
    # Hashes
    (r"^(md5|md5hash|hash_?md5)$", "hash_md5"),
    (r"^(sha1|sha1hash|hash_?sha1)$", "hash_sha1"),
    (r"^(sha256|sha256hash|hash_?sha256|hash|filehash)$", "hash_sha256"),
    # Network
    (r"^(protocol|proto)$", "protocol"),
    (r"^(domain|dns_?name|query_?name|queriedname)$", "domain"),
    (r"^(url|uri|request_?url)$", "url"),
    # Event info
    (r"^(event_?id|eventid|eid)$", "event_id"),
    (r"^(event_?type|eventtype|category|action)$", "event_type"),
    (r"^(description|message|msg|detail)$", "description"),
    (r"^(severity|level|priority)$", "severity"),
    # Registry
    (r"^(reg_?key|registry_?key|targetobject)$", "registry_key"),
    (r"^(reg_?value|registry_?value)$", "registry_value"),
]


def normalize_columns(columns: list[str]) -> dict[str, str]:
    """Map raw column names to canonical names.

    Returns:
        Dict of {raw_column_name: canonical_column_name}.
        Columns with no match map to themselves (lowered + underscored).
    """
    mapping: dict[str, str] = {}
    used_canonical: set[str] = set()

    for col in columns:
        col_lower = col.strip().lower()
        matched = False
        for pattern, canonical in COLUMN_MAPPINGS:
            if re.match(pattern, col_lower, re.IGNORECASE):
                # Avoid duplicate canonical names
                if canonical not in used_canonical:
                    mapping[col] = canonical
                    used_canonical.add(canonical)
                    matched = True
                break
        if not matched:
            # Produce a clean snake_case version
            clean = re.sub(r"[^a-z0-9]+", "_", col_lower).strip("_")
            mapping[col] = clean or col

    return mapping


def normalize_row(row: dict[str, Any], column_mapping: dict[str, str]) -> dict[str, Any]:
    """Apply column mapping to a single row."""
    return {column_mapping.get(k, k): v for k, v in row.items()}


def normalize_rows(rows: list[dict], column_mapping: dict[str, str]) -> list[dict]:
    """Apply column mapping to all rows."""
    return [normalize_row(row, column_mapping) for row in rows]


def detect_ioc_columns(
    columns: list[str],
    column_types: dict[str, str],
    column_mapping: dict[str, str],
) -> dict[str, str]:
    """Detect which columns contain IOCs (IPs, hashes, domains).

    Returns:
        Dict of {column_name: ioc_type}.
    """
    ioc_columns: dict[str, str] = {}
    ioc_type_map = {
        "ip": "ip",
        "hash_md5": "hash_md5",
        "hash_sha1": "hash_sha1",
        "hash_sha256": "hash_sha256",
        "domain": "domain",
    }

    for col in columns:
        col_type = column_types.get(col)
        if col_type in ioc_type_map:
            ioc_columns[col] = ioc_type_map[col_type]

        # Also check canonical name
        canonical = column_mapping.get(col, "")
        if canonical in ("src_ip", "dst_ip", "ip_address"):
            ioc_columns[col] = "ip"
        elif canonical == "hash_md5":
            ioc_columns[col] = "hash_md5"
        elif canonical == "hash_sha1":
            ioc_columns[col] = "hash_sha1"
        elif canonical == "hash_sha256":
            ioc_columns[col] = "hash_sha256"
        elif canonical == "domain":
            ioc_columns[col] = "domain"
        elif canonical == "url":
            ioc_columns[col] = "url"

    return ioc_columns


def detect_time_range(
    rows: list[dict],
    column_mapping: dict[str, str],
) -> tuple[datetime | None, datetime | None]:
    """Find the earliest and latest timestamps in the dataset."""
    ts_col = None
    for raw_col, canonical in column_mapping.items():
        if canonical == "timestamp":
            ts_col = raw_col
            break

    if not ts_col:
        return None, None

    timestamps: list[datetime] = []
    for row in rows:
        val = row.get(ts_col)
        if not val:
            continue
        try:
            dt = _parse_timestamp(str(val))
            if dt:
                timestamps.append(dt)
        except (ValueError, TypeError):
            continue

    if not timestamps:
        return None, None

    return min(timestamps), max(timestamps)


def _parse_timestamp(value: str) -> datetime | None:
    """Try multiple timestamp formats."""
    formats = [
        "%Y-%m-%dT%H:%M:%S.%fZ",
        "%Y-%m-%dT%H:%M:%SZ",
        "%Y-%m-%dT%H:%M:%S.%f",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%d %H:%M:%S.%f",
        "%Y-%m-%d %H:%M:%S",
        "%Y/%m/%d %H:%M:%S",
        "%m/%d/%Y %H:%M:%S",
        "%d/%m/%Y %H:%M:%S",
    ]
    for fmt in formats:
        try:
            return datetime.strptime(value.strip(), fmt)
        except ValueError:
            continue
    return None
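A quick sketch of the normalizer on Velociraptor-style headers (the column names below are illustrative sample data, not from a real export):

    raw_cols = ["EventTime", "ClientId", "SourceAddress", "Sha256", "CommandLine"]
    mapping = normalize_columns(raw_cols)
    # {'EventTime': 'timestamp', 'ClientId': 'hostname',
    #  'SourceAddress': 'src_ip', 'Sha256': 'hash_sha256',
    #  'CommandLine': 'command_line'}
    iocs = detect_ioc_columns(raw_cols, column_types={}, column_mapping=mapping)
    # {'SourceAddress': 'ip', 'Sha256': 'hash_sha256'}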
425  backend/app/services/reports.py  Normal file
@@ -0,0 +1,425 @@
"""Report generation — JSON, HTML, and CSV export for hunt investigations.

Generates comprehensive investigation reports including:
- Hunt metadata and status
- Dataset summaries with IOC counts
- Hypotheses and their evidence
- Annotations timeline
- Enrichment verdicts
- Agent conversation history
- Cross-hunt correlations
"""

import csv
import io
import logging
from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import (
    Hunt, Dataset, DatasetRow, Hypothesis,
    Annotation, Conversation, Message, EnrichmentResult,
)

logger = logging.getLogger(__name__)


class ReportGenerator:
    """Generates exportable investigation reports."""

    async def generate_hunt_report(
        self,
        hunt_id: str,
        db: AsyncSession,
        format: str = "json",
        include_rows: bool = False,
        max_rows: int = 500,
    ) -> dict | str:
        """Generate a comprehensive report for a hunt investigation."""

        # Gather all hunt data
        report_data = await self._gather_hunt_data(
            hunt_id, db, include_rows=include_rows, max_rows=max_rows,
        )

        if not report_data:
            return {"error": "Hunt not found"}

        if format == "json":
            return report_data
        elif format == "html":
            return self._render_html(report_data)
        elif format == "csv":
            return self._render_csv(report_data)
        else:
            return report_data

    async def _gather_hunt_data(
        self,
        hunt_id: str,
        db: AsyncSession,
        include_rows: bool = False,
        max_rows: int = 500,
    ) -> dict | None:
        """Gather all data for a hunt report."""

        # Hunt metadata
        result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
        hunt = result.scalar_one_or_none()
        if not hunt:
            return None

        # Datasets
        ds_result = await db.execute(
            select(Dataset).where(Dataset.hunt_id == hunt_id)
        )
        datasets = ds_result.scalars().all()

        dataset_summaries = []
        all_iocs = {}
        for ds in datasets:
            summary = {
                "id": ds.id,
                "name": ds.name,
                "filename": ds.filename,
                "source_tool": ds.source_tool,
                "row_count": ds.row_count,
                "columns": list((ds.column_schema or {}).keys()),
                "ioc_columns": ds.ioc_columns or {},
                "time_range": {
                    "start": ds.time_range_start,
                    "end": ds.time_range_end,
                },
                "created_at": ds.created_at.isoformat(),
            }

            if include_rows:
                rows_result = await db.execute(
                    select(DatasetRow)
                    .where(DatasetRow.dataset_id == ds.id)
                    .order_by(DatasetRow.row_index)
                    .limit(max_rows)
                )
                rows = rows_result.scalars().all()
                summary["rows"] = [r.data for r in rows]

            dataset_summaries.append(summary)

            # Collect IOCs for enrichment lookup
            if ds.ioc_columns:
                all_iocs.update(ds.ioc_columns)

        # Hypotheses
        hyp_result = await db.execute(
            select(Hypothesis).where(Hypothesis.hunt_id == hunt_id)
        )
        hypotheses = hyp_result.scalars().all()

        hypotheses_data = [
            {
                "id": h.id,
                "title": h.title,
                "description": h.description,
                "mitre_technique": h.mitre_technique,
                "status": h.status,
                "evidence_row_ids": h.evidence_row_ids,
                "evidence_notes": h.evidence_notes,
                "created_at": h.created_at.isoformat(),
                "updated_at": h.updated_at.isoformat(),
            }
            for h in hypotheses
        ]

        # Annotations (across all datasets in this hunt)
        dataset_ids = [ds.id for ds in datasets]
        annotations_data = []
        if dataset_ids:
            ann_result = await db.execute(
                select(Annotation)
                .where(Annotation.dataset_id.in_(dataset_ids))
                .order_by(Annotation.created_at)
            )
            annotations = ann_result.scalars().all()
            annotations_data = [
                {
                    "id": a.id,
                    "dataset_id": a.dataset_id,
                    "row_id": a.row_id,
                    "text": a.text,
                    "severity": a.severity,
                    "tag": a.tag,
                    "created_at": a.created_at.isoformat(),
                }
                for a in annotations
            ]

        # Conversations
        conv_result = await db.execute(
            select(Conversation).where(Conversation.hunt_id == hunt_id)
        )
        conversations = conv_result.scalars().all()

        conversations_data = []
        for conv in conversations:
            msg_result = await db.execute(
                select(Message)
                .where(Message.conversation_id == conv.id)
                .order_by(Message.created_at)
            )
            messages = msg_result.scalars().all()
            conversations_data.append({
                "id": conv.id,
                "title": conv.title,
                "messages": [
                    {
                        "role": m.role,
                        "content": m.content,
                        "model_used": m.model_used,
                        "node_used": m.node_used,
                        "latency_ms": m.latency_ms,
                        "created_at": m.created_at.isoformat(),
                    }
                    for m in messages
                ],
            })
|
||||
enrichment_data = []
|
||||
for ds in datasets:
|
||||
if not ds.ioc_columns:
|
||||
continue
|
||||
# Get unique enriched IOCs for this dataset
|
||||
for col_name in ds.ioc_columns.keys():
|
||||
enrich_result = await db.execute(
|
||||
select(EnrichmentResult)
|
||||
.where(EnrichmentResult.source.isnot(None))
|
||||
.limit(100)
|
||||
)
|
||||
enrichments = enrich_result.scalars().all()
|
||||
for e in enrichments:
|
||||
enrichment_data.append({
|
||||
"ioc_value": e.ioc_value,
|
||||
"ioc_type": e.ioc_type,
|
||||
"source": e.source,
|
||||
"verdict": e.verdict,
|
||||
"score": e.score,
|
||||
"tags": e.tags,
|
||||
"country": e.country,
|
||||
})
|
||||
break # Only query once
|
||||

        # Build report
        now = datetime.now(timezone.utc)
        return {
            "report_metadata": {
                "generated_at": now.isoformat(),
                "format_version": "1.0",
                "generator": "ThreatHunt Report Engine",
            },
            "hunt": {
                "id": hunt.id,
                "name": hunt.name,
                "description": hunt.description,
                "status": hunt.status,
                "created_at": hunt.created_at.isoformat(),
                "updated_at": hunt.updated_at.isoformat(),
            },
            "summary": {
                "dataset_count": len(datasets),
                "total_rows": sum(ds.row_count for ds in datasets),
                "hypothesis_count": len(hypotheses),
                "confirmed_hypotheses": len([h for h in hypotheses if h.status == "confirmed"]),
                "annotation_count": len(annotations_data),
                "critical_annotations": len([a for a in annotations_data if a["severity"] == "critical"]),
                "conversation_count": len(conversations_data),
                "enrichment_count": len(enrichment_data),
                "malicious_iocs": len([e for e in enrichment_data if e["verdict"] == "malicious"]),
            },
            "datasets": dataset_summaries,
            "hypotheses": hypotheses_data,
            "annotations": annotations_data,
            "conversations": conversations_data,
            "enrichments": enrichment_data[:100],
        }

    def _render_html(self, data: dict) -> str:
        """Render report as self-contained HTML."""
        hunt = data.get("hunt", {})
        summary = data.get("summary", {})
        hypotheses = data.get("hypotheses", [])
        annotations = data.get("annotations", [])
        datasets = data.get("datasets", [])
        enrichments = data.get("enrichments", [])
        meta = data.get("report_metadata", {})

        html = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>ThreatHunt Report: {hunt.get('name', 'Unknown')}</title>
<style>
  :root {{ --bg: #0d1117; --surface: #161b22; --border: #30363d; --text: #c9d1d9; --accent: #58a6ff; --red: #f85149; --orange: #d29922; --green: #3fb950; }}
  * {{ box-sizing: border-box; margin: 0; padding: 0; }}
  body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Helvetica, Arial, sans-serif; background: var(--bg); color: var(--text); line-height: 1.6; padding: 2rem; }}
  .container {{ max-width: 1200px; margin: 0 auto; }}
  h1 {{ color: var(--accent); border-bottom: 2px solid var(--border); padding-bottom: 0.5rem; margin-bottom: 1rem; }}
  h2 {{ color: var(--accent); margin: 1.5rem 0 0.75rem; }}
  h3 {{ color: var(--text); margin: 1rem 0 0.5rem; }}
  .card {{ background: var(--surface); border: 1px solid var(--border); border-radius: 8px; padding: 1rem; margin: 0.75rem 0; }}
  .stat-grid {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); gap: 0.75rem; }}
  .stat {{ background: var(--surface); border: 1px solid var(--border); border-radius: 8px; padding: 1rem; text-align: center; }}
  .stat .value {{ font-size: 2rem; font-weight: 700; color: var(--accent); }}
  .stat .label {{ font-size: 0.85rem; color: #8b949e; }}
  table {{ width: 100%; border-collapse: collapse; margin: 0.5rem 0; }}
  th, td {{ padding: 0.5rem 0.75rem; border: 1px solid var(--border); text-align: left; }}
  th {{ background: var(--surface); color: var(--accent); }}
  .badge {{ display: inline-block; padding: 0.15rem 0.5rem; border-radius: 999px; font-size: 0.8rem; font-weight: 600; }}
  .badge-malicious {{ background: var(--red); color: white; }}
  .badge-suspicious {{ background: var(--orange); color: #000; }}
  .badge-clean {{ background: var(--green); color: #000; }}
  .badge-critical {{ background: var(--red); color: white; }}
  .badge-high {{ background: #da3633; color: white; }}
  .badge-medium {{ background: var(--orange); color: #000; }}
  .badge-confirmed {{ background: var(--green); color: #000; }}
  .badge-active {{ background: var(--accent); color: #000; }}
  .footer {{ margin-top: 2rem; padding-top: 1rem; border-top: 1px solid var(--border); color: #8b949e; font-size: 0.85rem; }}
</style>
</head>
<body>
<div class="container">
<h1>🔍 ThreatHunt Report: {hunt.get('name', 'Untitled')}</h1>
<p><strong>Hunt ID:</strong> {hunt.get('id', '')}<br>
<strong>Status:</strong> {hunt.get('status', 'unknown')}<br>
<strong>Description:</strong> {hunt.get('description', 'N/A')}<br>
<strong>Created:</strong> {hunt.get('created_at', '')}</p>

<h2>Summary</h2>
<div class="stat-grid">
  <div class="stat"><div class="value">{summary.get('dataset_count', 0)}</div><div class="label">Datasets</div></div>
  <div class="stat"><div class="value">{summary.get('total_rows', 0):,}</div><div class="label">Total Rows</div></div>
  <div class="stat"><div class="value">{summary.get('hypothesis_count', 0)}</div><div class="label">Hypotheses</div></div>
  <div class="stat"><div class="value">{summary.get('confirmed_hypotheses', 0)}</div><div class="label">Confirmed</div></div>
  <div class="stat"><div class="value">{summary.get('annotation_count', 0)}</div><div class="label">Annotations</div></div>
  <div class="stat"><div class="value">{summary.get('malicious_iocs', 0)}</div><div class="label">Malicious IOCs</div></div>
</div>
"""

        # Hypotheses section
        if hypotheses:
            html += "<h2>Hypotheses</h2>\n"
            html += "<table><tr><th>Title</th><th>MITRE</th><th>Status</th><th>Description</th></tr>\n"
            for h in hypotheses:
                status_class = f"badge-{h['status']}" if h['status'] in ('confirmed', 'active') else ""
                html += (
                    f"<tr><td>{h['title']}</td>"
                    f"<td>{h.get('mitre_technique', 'N/A')}</td>"
                    f"<td><span class='badge {status_class}'>{h['status']}</span></td>"
                    f"<td>{h.get('description', '') or ''}</td></tr>\n"
                )
            html += "</table>\n"

        # Datasets section
        if datasets:
            html += "<h2>Datasets</h2>\n"
            for ds in datasets:
                html += f"""<div class="card">
<h3>{ds['name']} ({ds.get('filename', '')})</h3>
<p><strong>Source:</strong> {ds.get('source_tool', 'N/A')} |
<strong>Rows:</strong> {ds['row_count']:,} |
<strong>IOC Columns:</strong> {len(ds.get('ioc_columns', {}))} |
<strong>Time Range:</strong> {ds.get('time_range', {}).get('start', 'N/A')} to {ds.get('time_range', {}).get('end', 'N/A')}</p>
</div>\n"""

        # Annotations
        if annotations:
            critical = [a for a in annotations if a['severity'] in ('critical', 'high')]
            html += f"<h2>Annotations ({len(annotations)} total, {len(critical)} critical/high)</h2>\n"
            html += "<table><tr><th>Severity</th><th>Tag</th><th>Text</th><th>Created</th></tr>\n"
            for a in annotations[:50]:
                sev_class = f"badge-{a['severity']}" if a['severity'] in ('critical', 'high', 'medium') else ""
                html += (
                    f"<tr><td><span class='badge {sev_class}'>{a['severity']}</span></td>"
                    f"<td>{a.get('tag', 'N/A')}</td>"
                    f"<td>{a['text'][:200]}</td>"
                    f"<td>{a['created_at'][:19]}</td></tr>\n"
                )
            html += "</table>\n"

        # Enrichments
        if enrichments:
            malicious = [e for e in enrichments if e['verdict'] == 'malicious']
            html += f"<h2>IOC Enrichment ({len(enrichments)} results, {len(malicious)} malicious)</h2>\n"
            html += "<table><tr><th>IOC</th><th>Type</th><th>Source</th><th>Verdict</th><th>Score</th></tr>\n"
            for e in enrichments[:50]:
                verdict_class = f"badge-{e['verdict']}"
                html += (
                    f"<tr><td><code>{e['ioc_value']}</code></td>"
                    f"<td>{e['ioc_type']}</td>"
                    f"<td>{e['source']}</td>"
                    f"<td><span class='badge {verdict_class}'>{e['verdict']}</span></td>"
                    f"<td>{e.get('score', 0)}</td></tr>\n"
                )
            html += "</table>\n"

        html += f"""
<div class="footer">
<p>Generated by ThreatHunt Report Engine | {meta.get('generated_at', '')[:19]}</p>
</div>
</div>
</body>
</html>"""

        return html

    def _render_csv(self, data: dict) -> str:
        """Render key report data as CSV."""
        output = io.StringIO()

        # Hypotheses sheet
        output.write("=== HYPOTHESES ===\n")
        writer = csv.writer(output)
        writer.writerow(["Title", "MITRE Technique", "Status", "Description", "Evidence Notes"])
        for h in data.get("hypotheses", []):
            writer.writerow([
                h.get("title", ""),
                h.get("mitre_technique", ""),
                h.get("status", ""),
                h.get("description", ""),
                h.get("evidence_notes", ""),
            ])

        output.write("\n=== ANNOTATIONS ===\n")
        writer.writerow(["Severity", "Tag", "Text", "Dataset ID", "Row ID", "Created"])
        for a in data.get("annotations", []):
            writer.writerow([
                a.get("severity", ""),
                a.get("tag", ""),
                a.get("text", ""),
                a.get("dataset_id", ""),
                a.get("row_id", ""),
                a.get("created_at", ""),
            ])

        output.write("\n=== ENRICHMENTS ===\n")
        writer.writerow(["IOC Value", "IOC Type", "Source", "Verdict", "Score", "Country"])
        for e in data.get("enrichments", []):
            writer.writerow([
                e.get("ioc_value", ""),
                e.get("ioc_type", ""),
                e.get("source", ""),
                e.get("verdict", ""),
                e.get("score", ""),
                e.get("country", ""),
            ])

        return output.getvalue()


# Singleton
report_generator = ReportGenerator()
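A minimal sketch of calling the generator; the hunt_id and session below are placeholders supplied by whatever route invokes it:

    # Hypothetical caller — `db` is an AsyncSession, `hunt_id` an existing hunt.
    report = await report_generator.generate_hunt_report(hunt_id, db)               # dict (JSON-ready)
    html = await report_generator.generate_hunt_report(hunt_id, db, format="html")  # self-contained page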
346  backend/app/services/sans_rag.py  Normal file
@@ -0,0 +1,346 @@
"""SANS RAG service — queries the 300GB SANS courseware indexed in Open WebUI.

Provides contextual SANS references for threat hunting guidance.
Uses two approaches:
1. Open WebUI RAG pipeline (if configured with a knowledge collection)
2. Embedding-based semantic search against locally indexed SANS content
"""

import logging
import re
import time
from dataclasses import dataclass, field

from app.config import settings
from app.agents.providers_v2 import _get_client

logger = logging.getLogger(__name__)


# ── SANS course catalog for reference matching ────────────────────────

SANS_COURSES = {
    "SEC401": "Security Essentials",
    "SEC504": "Hacker Tools, Techniques, and Incident Handling",
    "SEC503": "Network Monitoring and Threat Detection In-Depth",
    "SEC505": "Securing Windows and PowerShell Automation",
    "SEC506": "Securing Linux/Unix",
    "SEC510": "Public Cloud Security: AWS, Azure, and GCP",
    "SEC511": "Continuous Monitoring and Security Operations",
    "SEC530": "Defensible Security Architecture and Engineering",
    "SEC540": "Cloud Security and DevSecOps Automation",
    "SEC555": "SIEM with Tactical Analytics",
    "SEC560": "Enterprise Penetration Testing",
    "SEC565": "Red Team Operations and Adversary Emulation",
    "SEC573": "Automating Information Security with Python",
    "SEC575": "Mobile Device Security and Ethical Hacking",
    "SEC588": "Cloud Penetration Testing",
    "SEC599": "Defeating Advanced Adversaries - Purple Team Tactics",
    "FOR408": "Windows Forensic Analysis",
    "FOR498": "Digital Acquisition and Rapid Triage",
    "FOR500": "Windows Forensic Analysis",
    "FOR508": "Advanced Incident Response, Threat Hunting, and Digital Forensics",
    "FOR509": "Enterprise Cloud Forensics and Incident Response",
    "FOR518": "Mac and iOS Forensic Analysis and Incident Response",
    "FOR572": "Advanced Network Forensics: Threat Hunting, Analysis, and Incident Response",
    "FOR578": "Cyber Threat Intelligence",
    "FOR585": "Smartphone Forensic Analysis In-Depth",
    "FOR610": "Reverse-Engineering Malware: Malware Analysis Tools and Techniques",
    "FOR710": "Reverse-Engineering Malware: Advanced Code Analysis",
    "ICS410": "ICS/SCADA Security Essentials",
    "ICS515": "ICS Visibility, Detection, and Response",
}

# Topic-to-course mapping for fallback recommendations
TOPIC_COURSE_MAP = {
    "malware": ["FOR610", "FOR710", "SEC504"],
    "reverse engineer": ["FOR610", "FOR710"],
    "incident response": ["FOR508", "SEC504"],
    "forensic": ["FOR508", "FOR500", "FOR408"],
    "windows forensic": ["FOR500", "FOR408"],
    "network forensic": ["FOR572"],
    "threat hunting": ["FOR508", "SEC504", "FOR578"],
    "threat intelligence": ["FOR578"],
    "powershell": ["SEC505", "FOR508"],
    "lateral movement": ["SEC504", "FOR508"],
    "persistence": ["FOR508", "SEC504"],
    "privilege escalation": ["SEC504", "SEC560"],
    "credential": ["SEC504", "SEC560"],
    "memory forensic": ["FOR508"],
    "disk forensic": ["FOR500", "FOR408"],
    "registry": ["FOR500", "FOR408"],
    "event log": ["FOR508", "SEC555"],
    "siem": ["SEC555"],
    "log analysis": ["SEC555", "SEC503"],
    "network monitor": ["SEC503"],
    "pcap": ["SEC503", "FOR572"],
    "cloud": ["SEC510", "SEC540", "FOR509"],
    "aws": ["SEC510", "SEC540", "FOR509"],
    "azure": ["SEC510", "FOR509"],
    "linux": ["SEC506"],
    "mobile": ["SEC575", "FOR585"],
    "penetration test": ["SEC560", "SEC565"],
    "red team": ["SEC565", "SEC599"],
    "purple team": ["SEC599"],
    "python": ["SEC573"],
    "automation": ["SEC573", "SEC540"],
    "deobfusc": ["FOR610", "SEC504"],
    "base64": ["FOR610", "SEC504"],
    "shellcode": ["FOR610", "FOR710"],
    "ransomware": ["FOR508", "FOR610"],
    "phishing": ["SEC504", "FOR578"],
    "c2": ["FOR508", "SEC504", "FOR572"],
    "command and control": ["FOR508", "SEC504"],
    "exfiltration": ["FOR508", "FOR572", "SEC503"],
    "dns": ["FOR572", "SEC503"],
    "ioc": ["FOR508", "FOR578"],
    "mitre": ["FOR508", "SEC504", "SEC599"],
    "att&ck": ["FOR508", "SEC504"],
    "velociraptor": ["FOR508"],
    "volatility": ["FOR508"],
    "scheduled task": ["FOR508", "SEC504"],
    "service": ["FOR508", "SEC504"],
    "wmi": ["FOR508", "SEC504"],
    "process": ["FOR508"],
    "dll": ["FOR610", "FOR508"],
}


@dataclass
class RAGResult:
    """Result from a RAG query."""
    query: str
    context: str  # Retrieved relevant text
    sources: list[str] = field(default_factory=list)  # Source document names
    course_references: list[str] = field(default_factory=list)  # SANS course IDs
    confidence: float = 0.0
    latency_ms: int = 0


class SANSRAGService:
    """Service for querying SANS courseware via Open WebUI RAG pipeline."""

    def __init__(self):
        self.openwebui_url = settings.OPENWEBUI_URL.rstrip("/")
        self.api_key = settings.OPENWEBUI_API_KEY
        self.rag_model = settings.DEFAULT_FAST_MODEL
        self._available: bool | None = None

    def _headers(self) -> dict:
        h = {"Content-Type": "application/json"}
        if self.api_key:
            h["Authorization"] = f"Bearer {self.api_key}"
        return h

    async def query(
        self,
        question: str,
        context: str = "",
        max_tokens: int = 1024,
    ) -> RAGResult:
        """Query SANS courseware for relevant context.

        Uses Open WebUI's RAG-enabled chat to retrieve from indexed SANS content.
        Falls back to topic-based course recommendations if RAG is unavailable.
        """
        start = time.monotonic()

        # Try Open WebUI RAG pipeline first
        try:
            result = await self._query_openwebui_rag(question, context, max_tokens)
            result.latency_ms = int((time.monotonic() - start) * 1000)

            # Enrich with course references if not already present
            if not result.course_references:
                result.course_references = self._match_courses(question)

            return result

        except Exception as e:
            logger.warning(f"RAG query failed, using fallback: {e}")
            # Fallback to topic-based matching
            courses = self._match_courses(question)
            return RAGResult(
                query=question,
                context="",
                sources=[],
                course_references=courses,
                confidence=0.3 if courses else 0.0,
                latency_ms=int((time.monotonic() - start) * 1000),
            )

    async def _query_openwebui_rag(
        self,
        question: str,
        context: str,
        max_tokens: int,
    ) -> RAGResult:
        """Query Open WebUI with RAG context retrieval.

        Open WebUI automatically retrieves from its indexed knowledge base
        when the model is configured with a knowledge collection.
        """
        client = _get_client()

        system_msg = (
            "You are a SANS cybersecurity knowledge assistant. "
            "Use your indexed SANS courseware to answer the question. "
            "Always cite the specific SANS course (e.g., FOR508, SEC504) "
            "and relevant section when referencing material. "
            "If the question relates to threat hunting procedures, "
            "reference the specific SANS methodology or framework."
        )

        messages = [
            {"role": "system", "content": system_msg},
        ]

        if context:
            messages.append({
                "role": "user",
                "content": f"Investigation context:\n{context}\n\nQuestion: {question}",
            })
        else:
            messages.append({"role": "user", "content": question})

        payload = {
            "model": self.rag_model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": 0.2,
            "stream": False,
        }

        resp = await client.post(
            f"{self.openwebui_url}/v1/chat/completions",
            json=payload,
            headers=self._headers(),
        )
        resp.raise_for_status()
        data = resp.json()

        content = ""
        if data.get("choices"):
            content = data["choices"][0].get("message", {}).get("content", "")

        # Extract course references from response
        course_refs = self._extract_course_refs(content)
        sources = self._extract_sources(data)

        return RAGResult(
            query=question,
            context=content,
            sources=sources,
            course_references=course_refs,
            confidence=0.8 if content else 0.0,
        )
    def _extract_course_refs(self, text: str) -> list[str]:
        """Extract SANS course references from response text."""
        refs = set()
        # Match full course IDs like SEC504, FOR508, ICS410. The prefix
        # group is non-capturing so findall returns the whole match.
        pattern = r'\b(?:SEC|FOR|ICS|MGT|AUD|DEV|LEG)\d{3}\b'
        for m in re.findall(pattern, text, re.IGNORECASE):
            course_id = m.upper()
            if course_id in SANS_COURSES:
                refs.add(f"{course_id}: {SANS_COURSES[course_id]}")
            else:
                refs.add(course_id)
        return sorted(refs)

    def _extract_sources(self, api_response: dict) -> list[str]:
        """Extract source document references from Open WebUI response metadata."""
        sources = []
        # Open WebUI may include source metadata in various formats
        if "sources" in api_response:
            for src in api_response["sources"]:
                if isinstance(src, dict):
                    sources.append(src.get("name", src.get("title", str(src))))
                else:
                    sources.append(str(src))
        # Check in metadata
        for choice in api_response.get("choices", []):
            meta = choice.get("metadata", {})
            if "sources" in meta:
                for src in meta["sources"]:
                    if isinstance(src, dict):
                        sources.append(src.get("name", str(src)))
                    else:
                        sources.append(str(src))
        return sources[:10]  # Limit

    def _match_courses(self, query: str) -> list[str]:
        """Match query keywords to SANS courses using topic map."""
        q = query.lower()
        matched = set()
        for topic, courses in TOPIC_COURSE_MAP.items():
            if topic in q:
                for course_id in courses:
                    if course_id in SANS_COURSES:
                        matched.add(f"{course_id}: {SANS_COURSES[course_id]}")
        return sorted(matched)[:5]

    async def get_course_context(self, course_id: str) -> str:
        """Get a brief course description for context injection."""
        course_id = course_id.upper().split(":")[0].strip()
        if course_id in SANS_COURSES:
            return f"{course_id}: {SANS_COURSES[course_id]}"
        return ""

    async def enrich_prompt(
        self,
        query: str,
        investigation_context: str = "",
    ) -> str:
        """Generate SANS-enriched context to inject into agent prompts.

        Returns a context string with relevant SANS references.
        """
        result = await self.query(query, context=investigation_context, max_tokens=512)

        parts = []
        if result.context:
            parts.append(f"SANS Reference Context:\n{result.context}")
        if result.course_references:
            parts.append(f"Relevant SANS Courses: {', '.join(result.course_references)}")
        if result.sources:
            parts.append(f"Sources: {', '.join(result.sources[:5])}")

        return "\n".join(parts) if parts else ""

    async def health_check(self) -> dict:
        """Check RAG service availability."""
        try:
            client = _get_client()
            resp = await client.get(
                f"{self.openwebui_url}/v1/models",
                headers=self._headers(),
                timeout=5,
            )
            available = resp.status_code == 200
            self._available = available
            return {
                "available": available,
                "url": self.openwebui_url,
                "model": self.rag_model,
            }
        except Exception as e:
            self._available = False
            return {
                "available": False,
                "url": self.openwebui_url,
                "error": str(e),
            }


# Singleton
sans_rag = SANSRAGService()
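The topic-map fallback is plain substring matching, so it works even with Open WebUI unreachable. A sketch of what it yields:

    # Sync helper, no network call — safe to probe directly.
    sans_rag._match_courses("suspected lateral movement via wmi")
    # ['FOR508: Advanced Incident Response, Threat Hunting, and Digital Forensics',
    #  'SEC504: Hacker Tools, Techniques, and Incident Handling']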
233  backend/app/services/scanner.py  Normal file
@@ -0,0 +1,233 @@
"""AUP Keyword Scanner — searches dataset rows, hunts, annotations, and
messages for keyword matches.

Scanning is done in Python (not SQL LIKE on JSON columns) for portability
across SQLite / PostgreSQL and to provide per-cell match context.
"""

import logging
import re
from dataclasses import dataclass, field

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import (
    KeywordTheme,
    Keyword,
    DatasetRow,
    Dataset,
    Hunt,
    Annotation,
    Message,
)

logger = logging.getLogger(__name__)

BATCH_SIZE = 500


@dataclass
class ScanHit:
    theme_name: str
    theme_color: str
    keyword: str
    source_type: str  # dataset_row | hunt | annotation | message
    source_id: str | int
    field: str
    matched_value: str
    row_index: int | None = None
    dataset_name: str | None = None


@dataclass
class ScanResult:
    total_hits: int = 0
    hits: list[ScanHit] = field(default_factory=list)
    themes_scanned: int = 0
    keywords_scanned: int = 0
    rows_scanned: int = 0


class KeywordScanner:
    """Scans multiple data sources for keyword/regex matches."""

    def __init__(self, db: AsyncSession):
        self.db = db

    # ── Public API ────────────────────────────────────────────────────

    async def scan(
        self,
        dataset_ids: list[str] | None = None,
        theme_ids: list[str] | None = None,
        scan_hunts: bool = True,
        scan_annotations: bool = True,
        scan_messages: bool = True,
    ) -> dict:
        """Run a full AUP scan and return a dict matching ScanResponse."""
        # Load themes + keywords
        themes = await self._load_themes(theme_ids)
        if not themes:
            return ScanResult().__dict__

        # Pre-compile patterns per theme
        patterns = self._compile_patterns(themes)
        result = ScanResult(
            themes_scanned=len(themes),
            keywords_scanned=sum(len(kws) for kws in patterns.values()),
        )

        # Scan dataset rows
        await self._scan_datasets(patterns, result, dataset_ids)

        # Scan hunts
        if scan_hunts:
            await self._scan_hunts(patterns, result)

        # Scan annotations
        if scan_annotations:
            await self._scan_annotations(patterns, result)

        # Scan messages
        if scan_messages:
            await self._scan_messages(patterns, result)

        result.total_hits = len(result.hits)
        return {
            "total_hits": result.total_hits,
            "hits": [h.__dict__ for h in result.hits],
            "themes_scanned": result.themes_scanned,
            "keywords_scanned": result.keywords_scanned,
            "rows_scanned": result.rows_scanned,
        }

    # ── Internal ──────────────────────────────────────────────────────

    async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
        q = select(KeywordTheme).where(KeywordTheme.enabled == True)  # noqa: E712
        if theme_ids:
            q = q.where(KeywordTheme.id.in_(theme_ids))
        result = await self.db.execute(q)
        return list(result.scalars().all())

    def _compile_patterns(
        self, themes: list[KeywordTheme]
    ) -> dict[tuple[str, str, str], list[tuple[str, re.Pattern]]]:
        """Returns {(theme_id, theme_name, theme_color): [(keyword_value, compiled_pattern), ...]}"""
        patterns: dict[tuple[str, str, str], list[tuple[str, re.Pattern]]] = {}
        for theme in themes:
            key = (theme.id, theme.name, theme.color)
            compiled = []
            for kw in theme.keywords:
                try:
                    if kw.is_regex:
                        pat = re.compile(kw.value, re.IGNORECASE)
                    else:
                        pat = re.compile(re.escape(kw.value), re.IGNORECASE)
                    compiled.append((kw.value, pat))
                except re.error:
                    logger.warning("Invalid regex pattern '%s' in theme '%s', skipping",
                                   kw.value, theme.name)
            patterns[key] = compiled
        return patterns

    def _match_text(
        self,
        text: str,
        patterns: dict,
        source_type: str,
        source_id: str | int,
        field_name: str,
        hits: list[ScanHit],
        row_index: int | None = None,
        dataset_name: str | None = None,
    ) -> None:
        """Check text against all compiled patterns, append hits."""
        if not text:
            return
        for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
            for kw_value, pat in keyword_patterns:
                if pat.search(text):
                    # Truncate matched_value for display
                    matched_preview = text[:200] + ("…" if len(text) > 200 else "")
                    hits.append(ScanHit(
                        theme_name=theme_name,
                        theme_color=theme_color,
                        keyword=kw_value,
                        source_type=source_type,
                        source_id=source_id,
                        field=field_name,
                        matched_value=matched_preview,
                        row_index=row_index,
                        dataset_name=dataset_name,
                    ))

    async def _scan_datasets(
        self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
    ) -> None:
        """Scan dataset rows in batches."""
        # Build dataset name lookup
        ds_q = select(Dataset.id, Dataset.name)
        if dataset_ids:
            ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
        ds_result = await self.db.execute(ds_q)
        ds_map = {r[0]: r[1] for r in ds_result.fetchall()}

        if not ds_map:
            return

        # Iterate rows in batches
        offset = 0
        row_q_base = select(DatasetRow).where(
            DatasetRow.dataset_id.in_(list(ds_map.keys()))
        ).order_by(DatasetRow.id)

        while True:
            rows_result = await self.db.execute(
                row_q_base.offset(offset).limit(BATCH_SIZE)
            )
            rows = rows_result.scalars().all()
            if not rows:
                break

            for row in rows:
                result.rows_scanned += 1
                data = row.data or {}
                for col_name, cell_value in data.items():
                    if cell_value is None:
                        continue
                    text = str(cell_value)
                    self._match_text(
                        text, patterns, "dataset_row", row.id,
                        col_name, result.hits,
                        row_index=row.row_index,
                        dataset_name=ds_map.get(row.dataset_id),
                    )

            offset += BATCH_SIZE
            if len(rows) < BATCH_SIZE:
                break

    async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
        """Scan hunt names and descriptions."""
        hunts_result = await self.db.execute(select(Hunt))
        for hunt in hunts_result.scalars().all():
            self._match_text(hunt.name, patterns, "hunt", hunt.id, "name", result.hits)
            if hunt.description:
                self._match_text(hunt.description, patterns, "hunt", hunt.id, "description", result.hits)

    async def _scan_annotations(self, patterns: dict, result: ScanResult) -> None:
        """Scan annotation text."""
        ann_result = await self.db.execute(select(Annotation))
        for ann in ann_result.scalars().all():
            self._match_text(ann.text, patterns, "annotation", ann.id, "text", result.hits)

    async def _scan_messages(self, patterns: dict, result: ScanResult) -> None:
        """Scan conversation messages (user messages only)."""
        msg_result = await self.db.execute(
            select(Message).where(Message.role == "user")
        )
        for msg in msg_result.scalars().all():
            self._match_text(msg.content, patterns, "message", msg.id, "content", result.hits)
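A sketch of driving a scan from application code; the session below is assumed to come from the app's get_db dependency, and the returned dict mirrors the ScanResponse shape described in scan():

    # Hypothetical caller — `db` is an AsyncSession.
    scanner = KeywordScanner(db)
    report = await scanner.scan(scan_messages=False)  # all enabled themes, skip chat logs
    print(report["total_hits"], "hits over", report["rows_scanned"], "rows")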
12  backend/pyproject.toml  Normal file
@@ -0,0 +1,12 @@
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"
filterwarnings = ["ignore::DeprecationWarning"]
addopts = "-v --tb=short"

[tool.coverage.run]
source = ["app"]
omit = ["app/agent/*"]

[tool.coverage.report]
show_missing = true
@@ -1,21 +1,29 @@
-fastapi==0.104.1
-uvicorn[standard]==0.24.0
-pydantic==2.5.0
-pydantic-settings==2.1.0
+# ── Core ──────────────────────────────────────
+fastapi>=0.104.1
+uvicorn[standard]>=0.24.0
+pydantic>=2.5.0
+pydantic-settings>=2.1.0
 
-# Optional LLM provider dependencies
-# Uncomment based on your deployment choice:
+# ── Database ──────────────────────────────────
+sqlalchemy>=2.0.23
+alembic>=1.13.0
+aiosqlite>=0.19.0
+# asyncpg>=0.29.0  # uncomment for PostgreSQL in production
 
-# For local models (GGML, Ollama, etc.)
-# llama-cpp-python==0.2.15
-# ollama==0.0.11
+# ── HTTP / LLM ───────────────────────────────
+httpx>=0.25.1
 
-# For online providers (OpenAI, Anthropic, Google)
-# openai==1.3.5
-# anthropic==0.7.1
-# google-generativeai==0.3.0
+# ── CSV / File handling ──────────────────────
+chardet>=5.2.0
+python-multipart>=0.0.6
+
+# ── Auth / Security ──────────────────────────
+python-jose[cryptography]>=3.3.0
+passlib[bcrypt]>=1.7.4
+bcrypt>=4.0.0
+
+# ── Development / Testing ────────────────────
+pytest>=7.4.3
+pytest-asyncio>=0.21.1
+coverage>=7.3.0
 
-# For development
-pytest==7.4.3
-pytest-asyncio==0.21.1
-httpx==0.25.1
1  backend/tests/__init__.py  Normal file
@@ -0,0 +1 @@
# Tests package
108  backend/tests/conftest.py  Normal file
@@ -0,0 +1,108 @@
"""Shared pytest fixtures for ThreatHunt tests.

Provides:
- Async test database (in-memory SQLite)
- Test client (httpx AsyncClient on the FastAPI app)
- Factory functions for creating test hunts, datasets, etc.
"""

import asyncio
import os
from typing import AsyncGenerator

import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

# Force test database (must be set before app modules read settings)
os.environ["TH_DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
os.environ["TH_JWT_SECRET"] = "test-secret-key-for-tests"

from app.db.engine import Base, get_db  # noqa: E402
from app.main import app  # noqa: E402


# ── Database fixtures ─────────────────────────────────────────────────
@pytest_asyncio.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create an event loop for the test session."""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def test_engine():
|
||||
"""Create test database engine."""
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
echo=False,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create a fresh database session for each test."""
|
||||
async_session = sessionmaker(
|
||||
test_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(db_session) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""Create an async test client with overridden DB dependency."""
|
||||
|
||||
async def _override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ── Factory helpers ───────────────────────────────────────────────────
|
||||
|
||||
def make_csv_bytes(
|
||||
columns: list[str],
|
||||
rows: list[list[str]],
|
||||
delimiter: str = ",",
|
||||
) -> bytes:
|
||||
"""Create CSV content as bytes for upload tests."""
|
||||
lines = [delimiter.join(columns)]
|
||||
for row in rows:
|
||||
lines.append(delimiter.join(str(v) for v in row))
|
||||
return "\n".join(lines).encode("utf-8")
|
||||
|
||||
|
||||
SAMPLE_CSV = make_csv_bytes(
|
||||
["timestamp", "hostname", "src_ip", "dst_ip", "process_name", "command_line"],
|
||||
[
|
||||
["2025-01-15T10:30:00Z", "DESKTOP-ABC", "192.168.1.100", "10.0.0.50", "cmd.exe", "cmd /c whoami"],
|
||||
["2025-01-15T10:31:00Z", "DESKTOP-ABC", "192.168.1.100", "10.0.0.51", "powershell.exe", "powershell -enc SGVsbG8="],
|
||||
["2025-01-15T10:32:00Z", "DESKTOP-XYZ", "192.168.1.101", "8.8.8.8", "chrome.exe", "chrome.exe --no-sandbox"],
|
||||
["2025-01-15T10:33:00Z", "DESKTOP-ABC", "192.168.1.100", "203.0.113.5", "svchost.exe", "svchost.exe -k netsvcs"],
|
||||
["2025-01-15T10:34:00Z", "SERVER-DC01", "10.0.0.1", "10.0.0.50", "lsass.exe", "lsass.exe"],
|
||||
],
|
||||
)
|
||||
|
||||
SAMPLE_HASH_CSV = make_csv_bytes(
|
||||
["filename", "md5", "sha256", "size"],
|
||||
[
|
||||
["malware.exe", "d41d8cd98f00b204e9800998ecf8427e", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "1024"],
|
||||
["benign.dll", "098f6bcd4621d373cade4e832627b4f6", "5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8", "2048"],
|
||||
],
|
||||
)
|
||||
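A minimal sketch of how these fixtures compose in practice. The test below is hypothetical (not part of this commit) and assumes only the `client` fixture, the `make_csv_bytes` helper, and the `/api/datasets/upload` route exercised elsewhere in this change:

import io

import pytest

from tests.conftest import make_csv_bytes


@pytest.mark.asyncio
async def test_upload_roundtrip(client):
    # Build a tiny two-row CSV with the factory helper above.
    csv_bytes = make_csv_bytes(
        ["host", "ip"],
        [["web01", "10.0.0.1"], ["web02", "10.0.0.2"]],
    )
    files = {"file": ("mini.csv", io.BytesIO(csv_bytes), "text/csv")}
    # The client fixture already targets the FastAPI app with the test DB wired in.
    resp = await client.post("/api/datasets/upload", files=files)
    assert resp.status_code == 200
    assert resp.json()["row_count"] == 2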
117 backend/tests/test_agents.py Normal file
@@ -0,0 +1,117 @@
"""Tests for model registry and task router."""

import pytest
from app.agents.registry import (
    ModelRegistry, ModelEntry, Capability, Tier, Node,
    registry, ROADRUNNER_MODELS, WILE_MODELS,
)
from app.agents.router import TaskRouter, TaskType, task_router


class TestModelRegistry:
    """Tests for the model registry."""

    def test_registry_has_models(self):
        assert len(registry.models) > 0
        assert len(ROADRUNNER_MODELS) > 0
        assert len(WILE_MODELS) > 0

    def test_find_by_capability(self):
        chat_models = registry.find(capability=Capability.CHAT)
        assert len(chat_models) > 0
        for m in chat_models:
            assert Capability.CHAT in m.capabilities

    def test_find_code_models(self):
        code_models = registry.find(capability=Capability.CODE)
        assert len(code_models) > 0

    def test_find_vision_models(self):
        vision_models = registry.find(capability=Capability.VISION)
        assert len(vision_models) > 0

    def test_find_embedding_models(self):
        embed_models = registry.find(capability=Capability.EMBEDDING)
        assert len(embed_models) > 0

    def test_find_by_node(self):
        wile_models = registry.find(node=Node.WILE)
        rr_models = registry.find(node=Node.ROADRUNNER)
        assert len(wile_models) > 0
        assert len(rr_models) > 0

    def test_find_heavy_models(self):
        heavy = registry.find(tier=Tier.HEAVY)
        assert len(heavy) > 0
        for m in heavy:
            assert m.tier == Tier.HEAVY

    def test_get_best(self):
        best = registry.get_best(Capability.CHAT, prefer_tier=Tier.FAST)
        assert best is not None
        assert Capability.CHAT in best.capabilities

    def test_get_best_vision_on_roadrunner(self):
        best = registry.get_best(Capability.VISION, prefer_node=Node.ROADRUNNER)
        assert best is not None
        assert Capability.VISION in best.capabilities

    def test_to_dict(self):
        result = registry.to_dict()
        assert isinstance(result, list)
        assert len(result) > 0
        assert "name" in result[0]
        assert "capabilities" in result[0]


class TestTaskRouter:
    """Tests for the task router."""

    def test_route_quick_chat(self):
        decision = task_router.route(TaskType.QUICK_CHAT)
        assert decision.model
        assert decision.node

    def test_route_deep_analysis(self):
        decision = task_router.route(TaskType.DEEP_ANALYSIS)
        assert decision.model
        # Deep should route to heavy model
        assert decision.task_type == TaskType.DEEP_ANALYSIS

    def test_route_code_analysis(self):
        decision = task_router.route(TaskType.CODE_ANALYSIS)
        assert decision.model
        assert "coder" in decision.model.lower() or "code" in decision.model.lower()

    def test_route_vision(self):
        decision = task_router.route(TaskType.VISION)
        assert decision.model
        assert decision.node == Node.ROADRUNNER

    def test_route_with_model_override(self):
        decision = task_router.route(TaskType.QUICK_CHAT, model_override="llama3.1:latest")
        assert decision.model == "llama3.1:latest"

    def test_route_unknown_model_to_cluster(self):
        decision = task_router.route(TaskType.QUICK_CHAT, model_override="nonexistent-model:99b")
        assert decision.node == Node.CLUSTER
        assert decision.provider_type == "openwebui"

    def test_classify_code_task(self):
        assert task_router.classify_task("deobfuscate this powershell script") == TaskType.CODE_ANALYSIS
        assert task_router.classify_task("decode this base64 payload") == TaskType.CODE_ANALYSIS

    def test_classify_deep_task(self):
        assert task_router.classify_task("detailed forensic analysis of this process tree") == TaskType.DEEP_ANALYSIS

    def test_classify_vision_task(self):
        assert task_router.classify_task("analyze this screenshot", has_image=True) == TaskType.VISION

    def test_classify_quick_task(self):
        assert task_router.classify_task("what does this process do?") == TaskType.QUICK_CHAT

    def test_debate_model_overrides(self):
        for task_type in [TaskType.DEBATE_PLANNER, TaskType.DEBATE_CRITIC, TaskType.DEBATE_PRAGMATIST, TaskType.DEBATE_JUDGE]:
            decision = task_router.route(task_type)
            assert decision.model
            assert decision.task_type == task_type
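Read together, these tests pin down the intended call pattern: classify a free-text query first, then route the resulting task type. A minimal sketch under that assumption (the field names mirror the assertions above, and nothing beyond them is guaranteed):

from app.agents.router import task_router


def pick_model(query: str, has_image: bool = False):
    # Map the analyst query onto a TaskType (code / deep / vision / quick chat).
    task_type = task_router.classify_task(query, has_image=has_image)
    # Resolve the task type to a concrete model and cluster node.
    decision = task_router.route(task_type)
    return decision.model, decision.node


# pick_model("deobfuscate this powershell script") should land on a code model.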
189 backend/tests/test_api.py Normal file
@@ -0,0 +1,189 @@
"""Tests for API endpoints — datasets, hunts, annotations."""

import io
import pytest
from tests.conftest import SAMPLE_CSV


@pytest.mark.asyncio
class TestHealthEndpoints:
    """Test basic health endpoints."""

    async def test_root(self, client):
        resp = await client.get("/")
        assert resp.status_code == 200
        data = resp.json()
        assert data["service"] == "ThreatHunt API"
        assert data["status"] == "running"

    async def test_openapi_docs(self, client):
        resp = await client.get("/openapi.json")
        assert resp.status_code == 200
        data = resp.json()
        assert "/api/agent/assist" in data["paths"]
        assert "/api/datasets/upload" in data["paths"]
        assert "/api/hunts" in data["paths"]


@pytest.mark.asyncio
class TestHuntEndpoints:
    """Test hunt CRUD operations."""

    async def test_create_hunt(self, client):
        resp = await client.post("/api/hunts", json={
            "name": "Test Hunt",
            "description": "Testing hunt creation",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["name"] == "Test Hunt"
        assert data["status"] == "active"
        assert data["id"]

    async def test_list_hunts(self, client):
        # Create a hunt first
        await client.post("/api/hunts", json={"name": "Hunt 1"})
        await client.post("/api/hunts", json={"name": "Hunt 2"})

        resp = await client.get("/api/hunts")
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] >= 2

    async def test_get_hunt(self, client):
        # Create
        create_resp = await client.post("/api/hunts", json={"name": "Specific Hunt"})
        hunt_id = create_resp.json()["id"]

        # Get
        resp = await client.get(f"/api/hunts/{hunt_id}")
        assert resp.status_code == 200
        assert resp.json()["name"] == "Specific Hunt"

    async def test_update_hunt(self, client):
        create_resp = await client.post("/api/hunts", json={"name": "Original"})
        hunt_id = create_resp.json()["id"]

        resp = await client.put(f"/api/hunts/{hunt_id}", json={
            "name": "Updated",
            "status": "closed",
        })
        assert resp.status_code == 200
        assert resp.json()["name"] == "Updated"
        assert resp.json()["status"] == "closed"

    async def test_get_nonexistent_hunt(self, client):
        resp = await client.get("/api/hunts/nonexistent-id")
        assert resp.status_code == 404


@pytest.mark.asyncio
class TestDatasetEndpoints:
    """Test dataset upload and retrieval."""

    async def test_upload_csv(self, client):
        files = {"file": ("test.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
        resp = await client.post(
            "/api/datasets/upload",
            files=files,
            params={"name": "Test Dataset"},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["name"] == "Test Dataset"
        assert data["row_count"] == 5
        assert "timestamp" in data["columns"]

    async def test_upload_invalid_extension(self, client):
        files = {"file": ("bad.exe", io.BytesIO(b"not csv"), "application/octet-stream")}
        resp = await client.post("/api/datasets/upload", files=files)
        assert resp.status_code == 400

    async def test_upload_empty_file(self, client):
        files = {"file": ("empty.csv", io.BytesIO(b""), "text/csv")}
        resp = await client.post("/api/datasets/upload", files=files)
        assert resp.status_code == 400

    async def test_list_datasets(self, client):
        # Upload first
        files = {"file": ("test.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
        await client.post("/api/datasets/upload", files=files, params={"name": "DS1"})

        resp = await client.get("/api/datasets")
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] >= 1

    async def test_get_dataset_rows(self, client):
        files = {"file": ("test.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
        upload_resp = await client.post("/api/datasets/upload", files=files, params={"name": "RowTest"})
        ds_id = upload_resp.json()["id"]

        resp = await client.get(f"/api/datasets/{ds_id}/rows")
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 5
        assert len(data["rows"]) == 5


@pytest.mark.asyncio
class TestAnnotationEndpoints:
    """Test annotation CRUD."""

    async def test_create_annotation(self, client):
        resp = await client.post("/api/annotations", json={
            "text": "Suspicious process detected",
            "severity": "high",
            "tag": "suspicious",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["text"] == "Suspicious process detected"
        assert data["severity"] == "high"

    async def test_list_annotations(self, client):
        await client.post("/api/annotations", json={"text": "Ann 1", "severity": "info"})
        await client.post("/api/annotations", json={"text": "Ann 2", "severity": "critical"})

        resp = await client.get("/api/annotations")
        assert resp.status_code == 200
        assert resp.json()["total"] >= 2

    async def test_filter_annotations_by_severity(self, client):
        await client.post("/api/annotations", json={"text": "Critical finding", "severity": "critical"})

        resp = await client.get("/api/annotations", params={"severity": "critical"})
        assert resp.status_code == 200
        for ann in resp.json()["annotations"]:
            assert ann["severity"] == "critical"


@pytest.mark.asyncio
class TestHypothesisEndpoints:
    """Test hypothesis CRUD."""

    async def test_create_hypothesis(self, client):
        resp = await client.post("/api/hypotheses", json={
            "title": "Living off the Land",
            "description": "Attacker using LOLBins for execution",
            "mitre_technique": "T1059",
            "status": "active",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["title"] == "Living off the Land"
        assert data["mitre_technique"] == "T1059"

    async def test_update_hypothesis_status(self, client):
        create_resp = await client.post("/api/hypotheses", json={
            "title": "Test Hyp",
            "status": "draft",
        })
        hyp_id = create_resp.json()["id"]

        resp = await client.put(f"/api/hypotheses/{hyp_id}", json={
            "status": "confirmed",
            "evidence_notes": "Confirmed via process tree analysis",
        })
        assert resp.status_code == 200
        assert resp.json()["status"] == "confirmed"
104 backend/tests/test_csv_parser.py Normal file
@@ -0,0 +1,104 @@
"""Tests for CSV parser and normalizer services."""

import pytest
from app.services.csv_parser import parse_csv_bytes, detect_encoding, detect_delimiter, infer_column_types
from app.services.normalizer import normalize_columns, normalize_rows, detect_ioc_columns, detect_time_range
from tests.conftest import SAMPLE_CSV, SAMPLE_HASH_CSV, make_csv_bytes


class TestCSVParser:
    """Tests for CSV parsing."""

    def test_parse_csv_basic(self):
        rows, meta = parse_csv_bytes(SAMPLE_CSV)
        assert len(rows) == 5
        assert "timestamp" in meta["columns"]
        assert "hostname" in meta["columns"]
        assert meta["encoding"] is not None
        assert meta["delimiter"] == ","

    def test_parse_csv_columns(self):
        rows, meta = parse_csv_bytes(SAMPLE_CSV)
        assert meta["columns"] == ["timestamp", "hostname", "src_ip", "dst_ip", "process_name", "command_line"]

    def test_parse_csv_row_data(self):
        rows, meta = parse_csv_bytes(SAMPLE_CSV)
        assert rows[0]["hostname"] == "DESKTOP-ABC"
        assert rows[0]["src_ip"] == "192.168.1.100"
        assert rows[2]["process_name"] == "chrome.exe"

    def test_parse_csv_hash_file(self):
        rows, meta = parse_csv_bytes(SAMPLE_HASH_CSV)
        assert len(rows) == 2
        assert "md5" in meta["columns"]
        assert "sha256" in meta["columns"]

    def test_parse_tsv(self):
        tsv_data = make_csv_bytes(
            ["host", "ip", "port"],
            [["server1", "10.0.0.1", "443"], ["server2", "10.0.0.2", "80"]],
            delimiter="\t",
        )
        rows, meta = parse_csv_bytes(tsv_data)
        assert len(rows) == 2

    def test_parse_empty_file(self):
        with pytest.raises(Exception):
            parse_csv_bytes(b"")

    def test_detect_encoding_utf8(self):
        enc = detect_encoding(SAMPLE_CSV)
        assert enc is not None
        assert "ascii" in enc.lower() or "utf" in enc.lower()

    def test_infer_column_types(self):
        types = infer_column_types(
            ["192.168.1.1", "10.0.0.1", "8.8.8.8"],
            "src_ip",
        )
        assert types == "ip"

    def test_infer_column_types_hash(self):
        types = infer_column_types(
            ["d41d8cd98f00b204e9800998ecf8427e"],
            "hash",
        )
        assert types == "hash_md5"


class TestNormalizer:
    """Tests for column normalization."""

    def test_normalize_columns(self):
        mapping = normalize_columns(["SourceAddr", "DestAddr", "ProcessName"])
        assert "SourceAddr" in mapping
        # Should map to canonical names
        assert mapping.get("SourceAddr") in ("src_ip", "source_address", None) or isinstance(mapping.get("SourceAddr"), str)

    def test_normalize_known_columns(self):
        mapping = normalize_columns(["timestamp", "hostname", "src_ip"])
        assert mapping.get("timestamp") == "timestamp"
        assert mapping.get("hostname") == "hostname"
        assert mapping.get("src_ip") == "src_ip"

    def test_detect_ioc_columns(self):
        rows, meta = parse_csv_bytes(SAMPLE_CSV)
        column_mapping = normalize_columns(meta["columns"])
        iocs = detect_ioc_columns(meta["columns"], meta["column_types"], column_mapping)
        # Should detect IP columns
        assert isinstance(iocs, dict)

    def test_detect_time_range(self):
        rows, meta = parse_csv_bytes(SAMPLE_CSV)
        column_mapping = normalize_columns(meta["columns"])
        start, end = detect_time_range(rows, column_mapping)
        # Should detect time range from timestamp column
        if start:
            assert "2025" in start

    def test_normalize_rows(self):
        rows = [{"SourceAddr": "10.0.0.1", "ProcessName": "cmd.exe"}]
        mapping = {"SourceAddr": "src_ip", "ProcessName": "process_name"}
        normalized = normalize_rows(rows, mapping)
        assert len(normalized) == 1
        assert normalized[0].get("src_ip") == "10.0.0.1"
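The services exercised above compose into a small ingest pipeline. A sketch of that flow, using only the functions and argument shapes these tests rely on:

from app.services.csv_parser import parse_csv_bytes
from app.services.normalizer import (
    detect_ioc_columns, detect_time_range, normalize_columns, normalize_rows,
)


def ingest(raw: bytes):
    # Parse: row dicts plus metadata (columns, column_types, encoding, delimiter).
    rows, meta = parse_csv_bytes(raw)
    # Map vendor column names onto canonical ones (src_ip, hostname, ...).
    mapping = normalize_columns(meta["columns"])
    normalized = normalize_rows(rows, mapping)
    # Flag IOC-bearing columns and the dataset's observed time window.
    iocs = detect_ioc_columns(meta["columns"], meta["column_types"], mapping)
    start, end = detect_time_range(rows, mapping)
    return normalized, iocs, (start, end)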
199 backend/tests/test_keywords.py Normal file
@@ -0,0 +1,199 @@
"""Tests for AUP keyword themes, keyword CRUD, and scanner."""

import pytest
import pytest_asyncio
from httpx import AsyncClient


# ── Theme CRUD ────────────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_list_themes_empty(client: AsyncClient):
    """Initially (no seed in tests) the themes list should be empty or seeded."""
    res = await client.get("/api/keywords/themes")
    assert res.status_code == 200
    data = res.json()
    assert "themes" in data
    assert "total" in data


@pytest.mark.asyncio
async def test_create_theme(client: AsyncClient):
    res = await client.post("/api/keywords/themes", json={
        "name": "Test Gambling", "color": "#f44336", "enabled": True,
    })
    assert res.status_code == 201
    data = res.json()
    assert data["name"] == "Test Gambling"
    assert data["color"] == "#f44336"
    assert data["enabled"] is True
    assert data["keyword_count"] == 0
    return data["id"]


@pytest.mark.asyncio
async def test_create_duplicate_theme(client: AsyncClient):
    await client.post("/api/keywords/themes", json={"name": "Dup Theme"})
    res = await client.post("/api/keywords/themes", json={"name": "Dup Theme"})
    assert res.status_code == 409


@pytest.mark.asyncio
async def test_update_theme(client: AsyncClient):
    create = await client.post("/api/keywords/themes", json={"name": "Updatable"})
    tid = create.json()["id"]
    res = await client.put(f"/api/keywords/themes/{tid}", json={
        "name": "Updated Name", "color": "#00ff00", "enabled": False,
    })
    assert res.status_code == 200
    data = res.json()
    assert data["name"] == "Updated Name"
    assert data["color"] == "#00ff00"
    assert data["enabled"] is False


@pytest.mark.asyncio
async def test_delete_theme(client: AsyncClient):
    create = await client.post("/api/keywords/themes", json={"name": "ToDelete"})
    tid = create.json()["id"]
    res = await client.delete(f"/api/keywords/themes/{tid}")
    assert res.status_code == 204

    # Verify gone
    check = await client.get("/api/keywords/themes")
    names = [t["name"] for t in check.json()["themes"]]
    assert "ToDelete" not in names


@pytest.mark.asyncio
async def test_delete_nonexistent_theme(client: AsyncClient):
    res = await client.delete("/api/keywords/themes/nonexistent")
    assert res.status_code == 404


# ── Keyword CRUD ──────────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_add_keyword(client: AsyncClient):
    create = await client.post("/api/keywords/themes", json={"name": "KW Test Theme"})
    tid = create.json()["id"]

    res = await client.post(f"/api/keywords/themes/{tid}/keywords", json={
        "value": "poker", "is_regex": False,
    })
    assert res.status_code == 201
    data = res.json()
    assert data["value"] == "poker"
    assert data["is_regex"] is False
    assert data["theme_id"] == tid


@pytest.mark.asyncio
async def test_add_keywords_bulk(client: AsyncClient):
    create = await client.post("/api/keywords/themes", json={"name": "Bulk KW Theme"})
    tid = create.json()["id"]

    res = await client.post(f"/api/keywords/themes/{tid}/keywords/bulk", json={
        "values": ["steam", "epic games", "discord"],
    })
    assert res.status_code == 201
    data = res.json()
    assert data["added"] == 3
    assert data["theme_id"] == tid

    # Verify via theme list
    themes = await client.get("/api/keywords/themes")
    theme = [t for t in themes.json()["themes"] if t["id"] == tid][0]
    assert theme["keyword_count"] == 3


@pytest.mark.asyncio
async def test_delete_keyword(client: AsyncClient):
    create = await client.post("/api/keywords/themes", json={"name": "Del KW Theme"})
    tid = create.json()["id"]

    kw_res = await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "removeme"})
    kw_id = kw_res.json()["id"]

    res = await client.delete(f"/api/keywords/keywords/{kw_id}")
    assert res.status_code == 204


@pytest.mark.asyncio
async def test_add_keyword_to_nonexistent_theme(client: AsyncClient):
    res = await client.post("/api/keywords/themes/fakeid/keywords", json={"value": "test"})
    assert res.status_code == 404


# ── Scanner ───────────────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_scan_empty(client: AsyncClient):
    """Scan with no data should return zero hits."""
    res = await client.post("/api/keywords/scan", json={})
    assert res.status_code == 200
    data = res.json()
    assert data["total_hits"] == 0
    assert data["hits"] == []


@pytest.mark.asyncio
async def test_scan_with_dataset(client: AsyncClient):
    """Upload a dataset with known keywords, verify scanner finds them."""
    # Create a theme + keyword
    theme_res = await client.post("/api/keywords/themes", json={
        "name": "Scan Test", "color": "#ff0000",
    })
    tid = theme_res.json()["id"]
    await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "chrome.exe"})

    # Upload CSV dataset that contains "chrome.exe"
    from tests.conftest import SAMPLE_CSV
    import io
    files = {"file": ("test_scan.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
    upload = await client.post("/api/datasets/upload", files=files)
    assert upload.status_code == 200
    ds_id = upload.json()["id"]

    # Scan
    res = await client.post("/api/keywords/scan", json={
        "dataset_ids": [ds_id],
        "theme_ids": [tid],
        "scan_hunts": False,
        "scan_annotations": False,
        "scan_messages": False,
    })
    assert res.status_code == 200
    data = res.json()
    assert data["total_hits"] > 0
    # Verify the hit references chrome.exe
    kw_hits = [h for h in data["hits"] if h["keyword"] == "chrome.exe"]
    assert len(kw_hits) > 0


@pytest.mark.asyncio
async def test_quick_scan(client: AsyncClient):
    """Quick scan endpoint should work with a dataset_id parameter."""
    # Create theme + keyword
    theme_res = await client.post("/api/keywords/themes", json={
        "name": "Quick Scan Theme", "color": "#00ff00",
    })
    tid = theme_res.json()["id"]
    await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "powershell"})

    # Upload dataset
    from tests.conftest import SAMPLE_CSV
    import io
    files = {"file": ("quick_scan.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
    upload = await client.post("/api/datasets/upload", files=files)
    ds_id = upload.json()["id"]

    res = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
    assert res.status_code == 200
    data = res.json()
    assert "total_hits" in data
    # powershell should match at least one row
    assert data["total_hits"] > 0
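The same theme, upload, scan sequence can be driven against a running server instead of the test client. A sketch assuming a local deployment on port 8000; the endpoints are exactly those exercised above, and httpx is already a backend dependency:

import io

import httpx


def scan_csv_for_keyword(csv_bytes: bytes, theme_name: str, keyword: str) -> dict:
    with httpx.Client(base_url="http://localhost:8000") as c:
        # Create a theme and attach a single keyword to it.
        tid = c.post("/api/keywords/themes", json={"name": theme_name}).json()["id"]
        c.post(f"/api/keywords/themes/{tid}/keywords", json={"value": keyword})
        # Upload the dataset, then scan only that dataset against that theme.
        files = {"file": ("data.csv", io.BytesIO(csv_bytes), "text/csv")}
        ds_id = c.post("/api/datasets/upload", files=files).json()["id"]
        return c.post("/api/keywords/scan", json={
            "dataset_ids": [ds_id], "theme_ids": [tid],
            "scan_hunts": False, "scan_annotations": False, "scan_messages": False,
        }).json()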
@@ -1,5 +1,3 @@
version: "3.8"

services:
  backend:
    build:
@@ -9,33 +7,29 @@ services:
    ports:
      - "8000:8000"
    environment:
      # Agent provider configuration
      # Set one of these to enable the agent:
      # THREAT_HUNT_AGENT_PROVIDER=local
      # THREAT_HUNT_LOCAL_MODEL_PATH=/models/model.gguf
      #
      # THREAT_HUNT_AGENT_PROVIDER=networked
      # THREAT_HUNT_NETWORKED_ENDPOINT=http://inference-service:5000
      # THREAT_HUNT_NETWORKED_KEY=your-api-key
      #
      # THREAT_HUNT_AGENT_PROVIDER=online
      # THREAT_HUNT_ONLINE_API_KEY=sk-your-openai-key
      # THREAT_HUNT_ONLINE_MODEL=gpt-3.5-turbo
      # ── LLM Cluster (Wile / Roadrunner via Tailscale) ──
      TH_WILE_HOST: "100.110.190.12"
      TH_ROADRUNNER_HOST: "100.110.190.11"
      TH_OLLAMA_PORT: "11434"
      TH_OPEN_WEBUI_URL: "https://ai.guapo613.beer"

      # Auto-detect available provider (tries local -> networked -> online)
      THREAT_HUNT_AGENT_PROVIDER: auto
      # ── Database ──
      TH_DATABASE_URL: "sqlite+aiosqlite:///./threathunt.db"

      # Optional agent settings
      THREAT_HUNT_AGENT_MAX_TOKENS: "1024"
      THREAT_HUNT_AGENT_REASONING: "true"
      THREAT_HUNT_AGENT_HISTORY_LENGTH: "10"
      THREAT_HUNT_AGENT_FILTER_SENSITIVE: "true"
      # ── Auth ──
      TH_JWT_SECRET: "change-me-in-production"

      # ── Enrichment API keys (set your own) ──
      # TH_VIRUSTOTAL_API_KEY: ""
      # TH_ABUSEIPDB_API_KEY: ""
      # TH_SHODAN_API_KEY: ""

      # ── Agent behaviour ──
      TH_AGENT_MAX_TOKENS: "4096"
      TH_AGENT_TEMPERATURE: "0.3"
    volumes:
      # Optional: Mount local models for local provider
      # - ./models:/models:ro
      - ./backend:/app
    depends_on:
      - frontend
      - backend-data:/app/data
    networks:
      - threathunt
    healthcheck:
@@ -52,9 +46,8 @@ services:
    container_name: threathunt-frontend
    ports:
      - "3000:3000"
    environment:
      # API endpoint configuration
      REACT_APP_API_URL: http://localhost:8000
    depends_on:
      - backend
    networks:
      - threathunt
    healthcheck:
@@ -69,8 +62,5 @@ networks:
    driver: bridge

volumes:
  # Optional: Persistent storage for models or data
  # models:
  #   driver: local
  # data:
  #   driver: local
  backend-data:
    driver: local
31 frontend/nginx.conf Normal file
@@ -0,0 +1,31 @@
server {
    listen 3000;
    server_name _;

    root /usr/share/nginx/html;
    index index.html;

    # Allow large CSV uploads (matches backend 500 MB limit)
    client_max_body_size 500M;

    # Proxy API requests to the backend service
    location /api/ {
        proxy_pass http://backend:8000;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        proxy_read_timeout 120s;
    }

    # SPA fallback — serve index.html for all non-file routes
    location / {
        try_files $uri $uri/ /index.html;
    }

    # Cache static assets
    location /static/ {
        expires 1y;
        add_header Cache-Control "public, immutable";
    }
}
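A quick smoke test for this config once the stack is up. A sketch assuming the compose ports above (frontend on 3000, backend proxied under /api/); the script itself is not part of this commit:

import httpx

# An /api/ path on the frontend port should be answered by the proxied backend...
proxied = httpx.get("http://localhost:3000/api/hunts")
assert proxied.status_code == 200

# ...while an arbitrary client-side route should fall through to index.html (SPA fallback).
fallback = httpx.get("http://localhost:3000/some/spa/route")
assert "text/html" in fallback.headers.get("content-type", "")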
1947 frontend/package-lock.json generated
File diff suppressed because it is too large
@@ -3,8 +3,16 @@
  "version": "0.1.0",
  "private": true,
  "dependencies": {
    "@emotion/react": "^11.14.0",
    "@emotion/styled": "^11.14.1",
    "@mui/icons-material": "^7.3.8",
    "@mui/material": "^7.3.8",
    "@mui/x-data-grid": "^8.27.1",
    "@types/react-router-dom": "^5.3.3",
    "notistack": "^3.0.2",
    "react": "^18.2.0",
    "react-dom": "^18.2.0",
    "react-router-dom": "^7.13.0",
    "react-scripts": "5.0.1"
  },
  "scripts": {
@@ -1,17 +1,17 @@
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <meta name="theme-color" content="#1976d2" />
  <meta
    name="description"
    content="ThreatHunt - Analyst-assist threat hunting platform with agent guidance"
  />
  <title>ThreatHunt - Threat Hunting with Agent Assistance</title>
</head>
<body>
  <noscript>You need to enable JavaScript to run this app.</noscript>
  <div id="root"></div>
</body>

<head>
  <meta charset="utf-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <meta name="theme-color" content="#0b0c0d" />
  <meta name="description" content="ThreatHunt - Analyst-assist threat hunting platform with agent guidance" />
  <title>ThreatHunt Command Deck</title>
</head>

<body>
  <noscript>You need to enable JavaScript to run this app.</noscript>
  <div id="root"></div>
</body>

</html>
@@ -1,123 +1,137 @@
/**
 * Main ThreatHunt application entry point.
 * ThreatHunt — MUI-powered analyst-assist platform.
 */

import React, { useState } from "react";
import "./App.css";
import AgentPanel from "./components/AgentPanel";
import React, { useState, useCallback } from 'react';
import { BrowserRouter, Routes, Route, useNavigate, useLocation } from 'react-router-dom';
import { ThemeProvider, CssBaseline, Box, AppBar, Toolbar, Typography, IconButton,
  Drawer, List, ListItemButton, ListItemIcon, ListItemText, Divider, Chip } from '@mui/material';
import MenuIcon from '@mui/icons-material/Menu';
import DashboardIcon from '@mui/icons-material/Dashboard';
import SearchIcon from '@mui/icons-material/Search';
import StorageIcon from '@mui/icons-material/Storage';
import UploadFileIcon from '@mui/icons-material/UploadFile';
import SmartToyIcon from '@mui/icons-material/SmartToy';
import SecurityIcon from '@mui/icons-material/Security';
import BookmarkIcon from '@mui/icons-material/Bookmark';
import ScienceIcon from '@mui/icons-material/Science';
import CompareArrowsIcon from '@mui/icons-material/CompareArrows';
import GppMaybeIcon from '@mui/icons-material/GppMaybe';
import HubIcon from '@mui/icons-material/Hub';
import { SnackbarProvider } from 'notistack';
import theme from './theme';

function App() {
  // Sample state for demonstration
  const [currentDataset] = useState("FileList-2025-12-26");
  const [currentHost] = useState("DESKTOP-ABC123");
  const [currentArtifact] = useState("FileList");
  const [dataDescription] = useState(
    "File listing from system scan showing recent modifications"
  );
import Dashboard from './components/Dashboard';
import HuntManager from './components/HuntManager';
import DatasetViewer from './components/DatasetViewer';
import FileUpload from './components/FileUpload';
import AgentPanel from './components/AgentPanel';
import EnrichmentPanel from './components/EnrichmentPanel';
import AnnotationPanel from './components/AnnotationPanel';
import HypothesisTracker from './components/HypothesisTracker';
import CorrelationView from './components/CorrelationView';
import AUPScanner from './components/AUPScanner';
import NetworkMap from './components/NetworkMap';

  const handleAnalysisAction = (action: string) => {
    console.log("Analysis action triggered:", action);
    // In a real app, this would update the analysis view or apply filters
  };
const DRAWER_WIDTH = 240;

interface NavItem { label: string; path: string; icon: React.ReactNode }

const NAV: NavItem[] = [
  { label: 'Dashboard', path: '/', icon: <DashboardIcon /> },
  { label: 'Hunts', path: '/hunts', icon: <SearchIcon /> },
  { label: 'Datasets', path: '/datasets', icon: <StorageIcon /> },
  { label: 'Upload', path: '/upload', icon: <UploadFileIcon /> },
  { label: 'Agent', path: '/agent', icon: <SmartToyIcon /> },
  { label: 'Enrichment', path: '/enrichment', icon: <SecurityIcon /> },
  { label: 'Annotations', path: '/annotations', icon: <BookmarkIcon /> },
  { label: 'Hypotheses', path: '/hypotheses', icon: <ScienceIcon /> },
  { label: 'Correlation', path: '/correlation', icon: <CompareArrowsIcon /> },
  { label: 'Network Map', path: '/network', icon: <HubIcon /> },
  { label: 'AUP Scanner', path: '/aup', icon: <GppMaybeIcon /> },
];

function Shell() {
  const [open, setOpen] = useState(true);
  const navigate = useNavigate();
  const location = useLocation();

  const toggle = useCallback(() => setOpen(o => !o), []);

  return (
    <div className="app">
      <header className="app-header">
        <h1>ThreatHunt - Analyst-Assist Platform</h1>
        <p className="subtitle">
          Powered by agent guidance for faster threat hunting
        </p>
      </header>
    <Box sx={{ display: 'flex', minHeight: '100vh' }}>
      {/* App bar */}
      <AppBar position="fixed" sx={{ zIndex: t => t.zIndex.drawer + 1 }}>
        <Toolbar variant="dense">
          <IconButton edge="start" color="inherit" onClick={toggle} sx={{ mr: 1 }}>
            <MenuIcon />
          </IconButton>
          <Typography variant="h6" noWrap sx={{ flexGrow: 1 }}>
            ThreatHunt
          </Typography>
          <Chip label="v0.3.0" size="small" color="primary" variant="outlined" />
        </Toolbar>
      </AppBar>

      <main className="app-main">
        <div className="app-content">
          <section className="main-panel">
            <h2>Analysis Dashboard</h2>
            <p className="placeholder-text">
              [Main analysis interface would display here]
            </p>
            <div className="data-view">
              <table className="sample-data">
                <thead>
                  <tr>
                    <th>File</th>
                    <th>Modified</th>
                    <th>Size</th>
                    <th>Hash</th>
                  </tr>
                </thead>
                <tbody>
                  <tr>
                    <td>System32\drivers\etc\hosts</td>
                    <td>2025-12-20 14:32</td>
                    <td>456 B</td>
                    <td>d41d8cd98f00b204...</td>
                  </tr>
                  <tr>
                    <td>Windows\Temp\cache.bin</td>
                    <td>2025-12-26 09:15</td>
                    <td>2.3 MB</td>
                    <td>5d41402abc4b2a76...</td>
                  </tr>
                  <tr>
                    <td>Users\Admin\AppData\Roaming\config.xml</td>
                    <td>2025-12-25 16:45</td>
                    <td>12.4 KB</td>
                    <td>e99a18c428cb38d5...</td>
                  </tr>
                </tbody>
              </table>
            </div>
          </section>
      {/* Sidebar drawer */}
      <Drawer
        variant="persistent"
        open={open}
        sx={{
          width: open ? DRAWER_WIDTH : 0,
          flexShrink: 0,
          '& .MuiDrawer-paper': { width: DRAWER_WIDTH, boxSizing: 'border-box', mt: '48px' },
        }}
      >
        <List dense>
          {NAV.map(item => (
            <ListItemButton
              key={item.path}
              selected={location.pathname === item.path}
              onClick={() => navigate(item.path)}
            >
              <ListItemIcon sx={{ minWidth: 36 }}>{item.icon}</ListItemIcon>
              <ListItemText primary={item.label} />
            </ListItemButton>
          ))}
        </List>
        <Divider />
      </Drawer>

          <aside className="agent-sidebar">
            <AgentPanel
              dataset_name={currentDataset}
              artifact_type={currentArtifact}
              host_identifier={currentHost}
              data_summary={dataDescription}
              onAnalysisAction={handleAnalysisAction}
            />
          </aside>
        </div>
      </main>
      {/* Main content */}
      <Box component="main" sx={{
        flexGrow: 1, p: 2, mt: '48px',
        ml: open ? 0 : `-${DRAWER_WIDTH}px`,
        transition: 'margin 225ms cubic-bezier(0,0,0.2,1)',
      }}>
        <Routes>
          <Route path="/" element={<Dashboard />} />
          <Route path="/hunts" element={<HuntManager />} />
          <Route path="/datasets" element={<DatasetViewer />} />
          <Route path="/upload" element={<FileUpload />} />
          <Route path="/agent" element={<AgentPanel />} />
          <Route path="/enrichment" element={<EnrichmentPanel />} />
          <Route path="/annotations" element={<AnnotationPanel />} />
          <Route path="/hypotheses" element={<HypothesisTracker />} />
          <Route path="/correlation" element={<CorrelationView />} />
          <Route path="/network" element={<NetworkMap />} />
          <Route path="/aup" element={<AUPScanner />} />
        </Routes>
      </Box>
    </Box>
  );
}

      <footer className="app-footer">
        <div className="footer-content">
          <div className="footer-section">
            <h4>About Analyst-Assist Agent</h4>
            <p>
              The agent provides advisory guidance on artifact data, analytical
              pivots, and hypotheses. All decisions remain with the analyst.
            </p>
          </div>
          <div className="footer-section">
            <h4>Capabilities</h4>
            <ul>
              <li>Interpret CSV artifact data</li>
              <li>Suggest analytical directions</li>
              <li>Highlight anomalies</li>
              <li>Propose investigative steps</li>
            </ul>
          </div>
          <div className="footer-section">
            <h4>Governance</h4>
            <ul>
              <li>Read-only guidance</li>
              <li>No tool execution</li>
              <li>No autonomous actions</li>
              <li>Analyst controls decisions</li>
            </ul>
          </div>
        </div>
        <div className="footer-bottom">
          <p>
            © 2025 ThreatHunt. Agent guidance is advisory only. All
            analytical decisions remain with the analyst.
          </p>
        </div>
      </footer>
    </div>
function App() {
  return (
    <ThemeProvider theme={theme}>
      <CssBaseline />
      <SnackbarProvider maxSnack={3} anchorOrigin={{ vertical: 'bottom', horizontal: 'right' }}>
        <BrowserRouter>
          <Shell />
        </BrowserRouter>
      </SnackbarProvider>
    </ThemeProvider>
  );
}
370
frontend/src/api/client.ts
Normal file
370
frontend/src/api/client.ts
Normal file
@@ -0,0 +1,370 @@
|
||||
/* ====================================================================
|
||||
ThreatHunt API Client — mirrors every backend endpoint.
|
||||
All requests go through the CRA proxy (see package.json "proxy").
|
||||
==================================================================== */
|
||||
|
||||
const BASE = ''; // proxied to http://localhost:8000 by CRA
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
let authToken: string | null = localStorage.getItem('th_token');
|
||||
|
||||
export function setToken(t: string | null) {
|
||||
authToken = t;
|
||||
if (t) localStorage.setItem('th_token', t);
|
||||
else localStorage.removeItem('th_token');
|
||||
}
|
||||
export function getToken() { return authToken; }
|
||||
|
||||
async function api<T = any>(
|
||||
path: string,
|
||||
opts: RequestInit = {},
|
||||
): Promise<T> {
|
||||
const headers: Record<string, string> = {
|
||||
...(opts.headers as Record<string, string> || {}),
|
||||
};
|
||||
if (authToken) headers['Authorization'] = `Bearer ${authToken}`;
|
||||
if (!(opts.body instanceof FormData)) headers['Content-Type'] = 'application/json';
|
||||
|
||||
const res = await fetch(`${BASE}${path}`, { ...opts, headers });
|
||||
if (!res.ok) {
|
||||
const body = await res.json().catch(() => ({}));
|
||||
throw new Error(body.detail || `HTTP ${res.status}`);
|
||||
}
|
||||
const ct = res.headers.get('content-type') || '';
|
||||
if (ct.includes('application/json')) return res.json();
|
||||
return res.text() as unknown as T;
|
||||
}
|
||||
|
||||
// ── Auth ─────────────────────────────────────────────────────────────
|
||||
|
||||
export interface UserPayload {
|
||||
id: string; username: string; email: string;
|
||||
display_name: string | null; role: string; is_active: boolean; created_at: string;
|
||||
}
|
||||
export interface AuthPayload {
|
||||
user: UserPayload;
|
||||
tokens: { access_token: string; refresh_token: string; token_type: string };
|
||||
}
|
||||
|
||||
export const auth = {
|
||||
register: (username: string, email: string, password: string, display_name?: string) =>
|
||||
api<AuthPayload>('/api/auth/register', {
|
||||
method: 'POST', body: JSON.stringify({ username, email, password, display_name }),
|
||||
}),
|
||||
login: (username: string, password: string) =>
|
||||
api<AuthPayload>('/api/auth/login', {
|
||||
method: 'POST', body: JSON.stringify({ username, password }),
|
||||
}),
|
||||
refresh: (refresh_token: string) =>
|
||||
api<AuthPayload>('/api/auth/refresh', {
|
||||
method: 'POST', body: JSON.stringify({ refresh_token }),
|
||||
}),
|
||||
me: () => api<UserPayload>('/api/auth/me'),
|
||||
};
|
||||
|
||||
// ── Hunts ────────────────────────────────────────────────────────────
|
||||
|
||||
export interface Hunt {
|
||||
id: string; name: string; description: string | null; status: string;
|
||||
owner_id: string | null; created_at: string; updated_at: string;
|
||||
dataset_count: number; hypothesis_count: number;
|
||||
}
|
||||
|
||||
export const hunts = {
|
||||
list: (skip = 0, limit = 50) =>
|
||||
api<{ hunts: Hunt[]; total: number }>(`/api/hunts?skip=${skip}&limit=${limit}`),
|
||||
get: (id: string) => api<Hunt>(`/api/hunts/${id}`),
|
||||
create: (name: string, description?: string) =>
|
||||
api<Hunt>('/api/hunts', { method: 'POST', body: JSON.stringify({ name, description }) }),
|
||||
update: (id: string, data: Partial<{ name: string; description: string; status: string }>) =>
|
||||
api<Hunt>(`/api/hunts/${id}`, { method: 'PUT', body: JSON.stringify(data) }),
|
||||
delete: (id: string) => api(`/api/hunts/${id}`, { method: 'DELETE' }),
|
||||
};
|
||||
|
||||
// ── Datasets ─────────────────────────────────────────────────────────
|
||||
|
||||
export interface DatasetSummary {
|
||||
id: string; name: string; filename: string; source_tool: string | null;
|
||||
row_count: number; column_schema: Record<string, string> | null;
|
||||
normalized_columns: Record<string, string> | null;
|
||||
ioc_columns: Record<string, string[]> | null;
|
||||
file_size_bytes: number; encoding: string | null; delimiter: string | null;
|
||||
time_range_start: string | null; time_range_end: string | null;
|
||||
hunt_id: string | null; created_at: string;
|
||||
}
|
||||
|
||||
export interface UploadResult {
|
||||
id: string; name: string; row_count: number; columns: string[];
|
||||
column_types: Record<string, string>; normalized_columns: Record<string, string>;
|
||||
ioc_columns: Record<string, string[]>; message: string;
|
||||
}
|
||||
|
||||
export const datasets = {
|
||||
list: (skip = 0, limit = 50, huntId?: string) => {
|
||||
let qs = `/api/datasets?skip=${skip}&limit=${limit}`;
|
||||
if (huntId) qs += `&hunt_id=${encodeURIComponent(huntId)}`;
|
||||
return api<{ datasets: DatasetSummary[]; total: number }>(qs);
|
||||
},
|
||||
get: (id: string) => api<DatasetSummary>(`/api/datasets/${id}`),
|
||||
rows: (id: string, offset = 0, limit = 100) =>
|
||||
api<{ rows: Record<string, any>[]; total: number; offset: number; limit: number }>(
|
||||
`/api/datasets/${id}/rows?offset=${offset}&limit=${limit}`,
|
||||
),
|
||||
upload: (file: File, huntId?: string) => {
|
||||
const fd = new FormData();
|
||||
fd.append('file', file);
|
||||
const qs = huntId ? `?hunt_id=${encodeURIComponent(huntId)}` : '';
|
||||
return api<UploadResult>(`/api/datasets/upload${qs}`, { method: 'POST', body: fd });
|
||||
},
|
||||
/** Upload with real progress percentage via XMLHttpRequest. */
|
||||
uploadWithProgress: (
|
||||
file: File,
|
||||
huntId?: string,
|
||||
onProgress?: (pct: number) => void,
|
||||
): Promise<UploadResult> => {
|
||||
return new Promise((resolve, reject) => {
|
||||
const xhr = new XMLHttpRequest();
|
||||
const fd = new FormData();
|
||||
fd.append('file', file);
|
||||
const qs = huntId ? `?hunt_id=${encodeURIComponent(huntId)}` : '';
|
||||
|
||||
xhr.upload.addEventListener('progress', (e) => {
|
||||
if (e.lengthComputable && onProgress) {
|
||||
onProgress(Math.round((e.loaded / e.total) * 100));
|
||||
}
|
||||
});
|
||||
xhr.addEventListener('load', () => {
|
||||
if (xhr.status >= 200 && xhr.status < 300) {
|
||||
resolve(JSON.parse(xhr.responseText));
|
||||
} else {
|
||||
try {
|
||||
const body = JSON.parse(xhr.responseText);
|
||||
reject(new Error(body.detail || `HTTP ${xhr.status}`));
|
||||
} catch { reject(new Error(`HTTP ${xhr.status}`)); }
|
||||
}
|
||||
});
|
||||
xhr.addEventListener('error', () => reject(new Error('Network error')));
|
||||
xhr.addEventListener('abort', () => reject(new Error('Upload aborted')));
|
||||
|
||||
xhr.open('POST', `${BASE}/api/datasets/upload${qs}`);
|
||||
if (authToken) xhr.setRequestHeader('Authorization', `Bearer ${authToken}`);
|
||||
xhr.send(fd);
|
||||
});
|
||||
},
|
||||
delete: (id: string) => api(`/api/datasets/${id}`, { method: 'DELETE' }),
|
||||
};
|
||||
|
||||
// ── Agent ────────────────────────────────────────────────────────────
|
||||
|
||||
export interface AssistRequest {
|
||||
query: string;
|
||||
dataset_name?: string; artifact_type?: string; host_identifier?: string;
|
||||
data_summary?: string; conversation_history?: { role: string; content: string }[];
|
||||
active_hypotheses?: string[]; annotations_summary?: string;
|
||||
enrichment_summary?: string; mode?: 'quick' | 'deep' | 'debate';
|
||||
model_override?: string; conversation_id?: string; hunt_id?: string;
|
||||
}
|
||||
|
||||
export interface AssistResponse {
|
||||
guidance: string; confidence: number; suggested_pivots: string[];
|
||||
suggested_filters: string[]; caveats: string | null; reasoning: string | null;
|
||||
sans_references: string[]; model_used: string; node_used: string;
|
||||
latency_ms: number; perspectives: Record<string, any>[] | null;
|
||||
conversation_id: string | null;
|
||||
}
|
||||
|
||||
export interface NodeInfo { url: string; available: boolean }
|
||||
export interface HealthInfo {
|
||||
status: string;
|
||||
nodes: { wile: NodeInfo; roadrunner: NodeInfo; cluster: NodeInfo };
|
||||
rag: { available: boolean; url: string; model: string };
|
||||
default_models: Record<string, string>;
|
||||
config: { max_tokens: number; temperature: number };
|
||||
}
|
||||
|
||||
export const agent = {
|
||||
assist: (req: AssistRequest) =>
|
||||
api<AssistResponse>('/api/agent/assist', { method: 'POST', body: JSON.stringify(req) }),
|
||||
health: () => api<HealthInfo>('/api/agent/health'),
|
||||
models: () => api<Record<string, any>>('/api/agent/models'),
|
||||
/** Returns a ReadableStream for SSE streaming */
|
||||
assistStream: async (req: AssistRequest): Promise<Response> => {
|
||||
const headers: Record<string, string> = { 'Content-Type': 'application/json' };
|
||||
if (authToken) headers['Authorization'] = `Bearer ${authToken}`;
|
||||
return fetch(`${BASE}/api/agent/assist/stream`, {
|
||||
method: 'POST', headers, body: JSON.stringify(req),
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
// ── Annotations ──────────────────────────────────────────────────────
|
||||
|
||||
export interface AnnotationData {
|
||||
id: string; row_id: number | null; dataset_id: string | null;
|
||||
author_id: string | null; text: string; severity: string;
|
||||
tag: string | null; highlight_color: string | null;
|
||||
created_at: string; updated_at: string;
|
||||
}
|
||||
|
||||
export const annotations = {
|
||||
list: (params?: { dataset_id?: string; severity?: string; tag?: string; skip?: number; limit?: number }) => {
|
||||
const q = new URLSearchParams();
|
||||
if (params?.dataset_id) q.set('dataset_id', params.dataset_id);
|
||||
if (params?.severity) q.set('severity', params.severity);
|
||||
if (params?.tag) q.set('tag', params.tag);
|
||||
if (params?.skip) q.set('skip', String(params.skip));
|
||||
if (params?.limit) q.set('limit', String(params.limit));
|
||||
return api<{ annotations: AnnotationData[]; total: number }>(`/api/annotations?${q}`);
|
||||
},
|
||||
create: (data: { row_id?: number; dataset_id?: string; text: string; severity?: string; tag?: string; highlight_color?: string }) =>
|
||||
api<AnnotationData>('/api/annotations', { method: 'POST', body: JSON.stringify(data) }),
|
||||
update: (id: string, data: Partial<{ text: string; severity: string; tag: string; highlight_color: string }>) =>
|
||||
api<AnnotationData>(`/api/annotations/${id}`, { method: 'PUT', body: JSON.stringify(data) }),
|
||||
delete: (id: string) => api(`/api/annotations/${id}`, { method: 'DELETE' }),
|
||||
};
|
||||
|
||||
// ── Hypotheses ───────────────────────────────────────────────────────
|
||||
|
||||
export interface HypothesisData {
|
||||
id: string; hunt_id: string | null; title: string; description: string | null;
|
||||
mitre_technique: string | null; status: string;
|
||||
evidence_row_ids: number[] | null; evidence_notes: string | null;
|
||||
created_at: string; updated_at: string;
|
||||
}
|
||||
|
||||
export const hypotheses = {
|
||||
list: (params?: { hunt_id?: string; status?: string; skip?: number; limit?: number }) => {
|
||||
const q = new URLSearchParams();
|
||||
if (params?.hunt_id) q.set('hunt_id', params.hunt_id);
|
||||
if (params?.status) q.set('status', params.status);
|
||||
if (params?.skip) q.set('skip', String(params.skip));
|
||||
if (params?.limit) q.set('limit', String(params.limit));
|
||||
return api<{ hypotheses: HypothesisData[]; total: number }>(`/api/hypotheses?${q}`);
|
||||
},
|
||||
create: (data: { hunt_id?: string; title: string; description?: string; mitre_technique?: string; status?: string }) =>
|
||||
api<HypothesisData>('/api/hypotheses', { method: 'POST', body: JSON.stringify(data) }),
|
||||
update: (id: string, data: Partial<{ title: string; description: string; mitre_technique: string; status: string; evidence_row_ids: number[]; evidence_notes: string }>) =>
|
||||
api<HypothesisData>(`/api/hypotheses/${id}`, { method: 'PUT', body: JSON.stringify(data) }),
|
||||
delete: (id: string) => api(`/api/hypotheses/${id}`, { method: 'DELETE' }),
|
||||
};
|
||||
|
||||
// ── Enrichment ───────────────────────────────────────────────────────
|
||||
|
||||
export interface EnrichmentResult {
|
||||
ioc_value: string; ioc_type: string; source: string; verdict: string;
|
||||
score: number; tags: string[]; country: string; asn: string; org: string;
|
||||
last_seen: string; raw_data: Record<string, any>; error: string;
|
||||
latency_ms: number;
|
||||
}
|
||||
|
||||
export const enrichment = {
|
||||
ioc: (ioc_value: string, ioc_type: string, skip_cache = false) =>
|
||||
api<{ ioc_value: string; ioc_type: string; results: EnrichmentResult[]; overall_verdict: string; overall_score: number }>(
|
||||
'/api/enrichment/ioc', { method: 'POST', body: JSON.stringify({ ioc_value, ioc_type, skip_cache }) },
|
||||
),
|
||||
batch: (iocs: { value: string; type: string }[]) =>
|
||||
    api<{ results: Record<string, EnrichmentResult[]>; total_enriched: number }>(
      '/api/enrichment/batch', { method: 'POST', body: JSON.stringify({ iocs }) },
    ),
  dataset: (datasetId: string) =>
    api<{ dataset_id: string; iocs_found: number; enriched: number; results: Record<string, any> }>(
      `/api/enrichment/dataset/${datasetId}`, { method: 'POST' },
    ),
  status: () => api<Record<string, any>>('/api/enrichment/status'),
};

// ── Correlation ──────────────────────────────────────────────────────

export interface CorrelationResult {
  hunt_ids: string[]; summary: string; total_correlations: number;
  ioc_overlaps: any[]; time_overlaps: any[]; technique_overlaps: any[];
  host_overlaps: any[];
}

export const correlation = {
  analyze: (hunt_ids: string[]) =>
    api<CorrelationResult>('/api/correlation/analyze', {
      method: 'POST', body: JSON.stringify({ hunt_ids }),
    }),
  all: () => api<CorrelationResult>('/api/correlation/all'),
  ioc: (ioc_value: string) =>
    api<{ ioc_value: string; occurrences: any[]; total: number }>(`/api/correlation/ioc/${encodeURIComponent(ioc_value)}`),
};

// ── Reports ──────────────────────────────────────────────────────────

export const reports = {
  json: (huntId: string) =>
    api<Record<string, any>>(`/api/reports/hunt/${huntId}?format=json`),
  html: (huntId: string) =>
    api<string>(`/api/reports/hunt/${huntId}?format=html`),
  csv: (huntId: string) =>
    api<string>(`/api/reports/hunt/${huntId}?format=csv`),
  summary: (huntId: string) =>
    api<Record<string, any>>(`/api/reports/hunt/${huntId}/summary`),
};

// ── Root / misc ──────────────────────────────────────────────────────

export const misc = {
  root: () => api<{ name: string; version: string; status: string }>('/'),
};

// ── AUP Keywords ─────────────────────────────────────────────────────

export interface KeywordOut {
  id: number; theme_id: string; value: string; is_regex: boolean; created_at: string;
}
export interface ThemeOut {
  id: string; name: string; color: string; enabled: boolean; is_builtin: boolean;
  created_at: string; keyword_count: number; keywords: KeywordOut[];
}
export interface ScanHit {
  theme_name: string; theme_color: string; keyword: string;
  source_type: string; source_id: string | number; field: string;
  matched_value: string; row_index: number | null; dataset_name: string | null;
}
export interface ScanResponse {
  total_hits: number; hits: ScanHit[]; themes_scanned: number;
  keywords_scanned: number; rows_scanned: number;
}

export const keywords = {
  // Theme CRUD
  listThemes: () =>
    api<{ themes: ThemeOut[]; total: number }>('/api/keywords/themes'),
  createTheme: (name: string, color?: string, enabled?: boolean) =>
    api<ThemeOut>('/api/keywords/themes', {
      method: 'POST', body: JSON.stringify({ name, color, enabled }),
    }),
  updateTheme: (id: string, data: Partial<{ name: string; color: string; enabled: boolean }>) =>
    api<ThemeOut>(`/api/keywords/themes/${id}`, {
      method: 'PUT', body: JSON.stringify(data),
    }),
  deleteTheme: (id: string) =>
    api(`/api/keywords/themes/${id}`, { method: 'DELETE' }),

  // Keyword CRUD
  addKeyword: (themeId: string, value: string, is_regex = false) =>
    api<KeywordOut>(`/api/keywords/themes/${themeId}/keywords`, {
      method: 'POST', body: JSON.stringify({ value, is_regex }),
    }),
  addKeywordsBulk: (themeId: string, values: string[], is_regex = false) =>
    api<{ added: number; theme_id: string }>(`/api/keywords/themes/${themeId}/keywords/bulk`, {
      method: 'POST', body: JSON.stringify({ values, is_regex }),
    }),
  deleteKeyword: (keywordId: number) =>
    api(`/api/keywords/keywords/${keywordId}`, { method: 'DELETE' }),

  // Scanning
  scan: (opts: {
    dataset_ids?: string[]; theme_ids?: string[];
    scan_hunts?: boolean; scan_annotations?: boolean; scan_messages?: boolean;
  }) =>
    api<ScanResponse>('/api/keywords/scan', {
      method: 'POST', body: JSON.stringify(opts),
    }),
  quickScan: (datasetId: string) =>
    api<ScanResponse>(`/api/keywords/scan/quick?dataset_id=${encodeURIComponent(datasetId)}`),
};
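A usage sketch for the keywords client above (assumed flow, not part of the commit; the dataset id is hypothetical):

// Sketch: create a theme, bulk-add keywords, then quick-scan one dataset.
const theme = await keywords.createTheme('Gambling', '#e91e63');
await keywords.addKeywordsBulk(theme.id, ['poker', 'casino', 'sportsbook']);
const res = await keywords.quickScan('dataset-uuid');  // hypothetical id
console.log(`${res.total_hits} hits across ${res.rows_scanned} rows scanned`);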
431 frontend/src/components/AUPScanner.tsx Normal file
@@ -0,0 +1,431 @@
/**
 * AUPScanner — Acceptable Use Policy keyword scanner.
 *
 * Two-panel layout:
 *   Left  — Theme manager (add/delete themes, expand to see/add keywords)
 *   Right — Scan controls + results DataGrid
 */

import React, { useState, useEffect, useCallback } from 'react';
import {
  Box, Typography, Paper, Stack, Button, Chip, TextField, IconButton,
  Accordion, AccordionSummary, AccordionDetails, Switch, FormControlLabel,
  CircularProgress, Alert, Tooltip, Checkbox, FormGroup, LinearProgress,
  FormControl, InputLabel, Select, MenuItem,
} from '@mui/material';
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
import AddIcon from '@mui/icons-material/Add';
import DeleteIcon from '@mui/icons-material/Delete';
import PlayArrowIcon from '@mui/icons-material/PlayArrow';
import RefreshIcon from '@mui/icons-material/Refresh';
import { DataGrid, type GridColDef } from '@mui/x-data-grid';
import { useSnackbar } from 'notistack';
import {
  keywords, datasets, hunts,
  type Hunt, type ThemeOut, type ScanResponse, type DatasetSummary,
} from '../api/client';

// ── Theme Manager (left panel) ───────────────────────────────────────

interface ThemeManagerProps {
  themes: ThemeOut[];
  onReload: () => void;
}

function ThemeManager({ themes, onReload }: ThemeManagerProps) {
  const { enqueueSnackbar } = useSnackbar();
  const [newThemeName, setNewThemeName] = useState('');
  const [newThemeColor, setNewThemeColor] = useState('#9e9e9e');
  const [newKw, setNewKw] = useState<Record<string, string>>({});

  const addTheme = useCallback(async () => {
    if (!newThemeName.trim()) return;
    try {
      await keywords.createTheme(newThemeName.trim(), newThemeColor);
      enqueueSnackbar(`Theme "${newThemeName}" created`, { variant: 'success' });
      setNewThemeName('');
      setNewThemeColor('#9e9e9e');
      onReload();
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  }, [newThemeName, newThemeColor, enqueueSnackbar, onReload]);

  const deleteTheme = useCallback(async (id: string, name: string) => {
    if (!window.confirm(`Delete theme "${name}" and all its keywords?`)) return;
    try {
      await keywords.deleteTheme(id);
      enqueueSnackbar(`Theme "${name}" deleted`, { variant: 'info' });
      onReload();
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  }, [enqueueSnackbar, onReload]);

  const toggleTheme = useCallback(async (id: string, enabled: boolean) => {
    try {
      await keywords.updateTheme(id, { enabled });
      onReload();
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  }, [enqueueSnackbar, onReload]);

  const addKeyword = useCallback(async (themeId: string) => {
    const val = (newKw[themeId] || '').trim();
    if (!val) return;
    try {
      // Support comma-separated bulk add
      const values = val.split(',').map(v => v.trim()).filter(Boolean);
      if (values.length > 1) {
        await keywords.addKeywordsBulk(themeId, values);
        enqueueSnackbar(`Added ${values.length} keywords`, { variant: 'success' });
      } else {
        await keywords.addKeyword(themeId, values[0]);
        enqueueSnackbar(`Added "${values[0]}"`, { variant: 'success' });
      }
      setNewKw(prev => ({ ...prev, [themeId]: '' }));
      onReload();
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  }, [newKw, enqueueSnackbar, onReload]);

  const deleteKeyword = useCallback(async (kwId: number) => {
    try {
      await keywords.deleteKeyword(kwId);
      onReload();
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  }, [enqueueSnackbar, onReload]);

  return (
    <Paper sx={{ p: 2, height: '100%', overflow: 'auto' }}>
      <Typography variant="h6" gutterBottom>Keyword Themes</Typography>

      {/* Add new theme */}
      <Stack direction="row" spacing={1} sx={{ mb: 2 }} alignItems="center">
        <TextField
          size="small" label="New theme" value={newThemeName}
          onChange={e => setNewThemeName(e.target.value)}
          onKeyDown={e => e.key === 'Enter' && addTheme()}
          sx={{ flexGrow: 1 }}
        />
        <input
          type="color" value={newThemeColor}
          onChange={e => setNewThemeColor(e.target.value)}
          style={{ width: 36, height: 36, border: 'none', cursor: 'pointer', borderRadius: 4 }}
        />
        <IconButton color="primary" onClick={addTheme} size="small"><AddIcon /></IconButton>
      </Stack>

      {/* Theme list */}
      {themes.map(theme => (
        <Accordion key={theme.id} defaultExpanded={false} disableGutters
          sx={{ '&:before': { display: 'none' }, mb: 0.5 }}>
          <AccordionSummary expandIcon={<ExpandMoreIcon />}>
            <Stack direction="row" spacing={1} alignItems="center" sx={{ width: '100%', pr: 1 }}>
              <Chip
                label={theme.name}
                size="small"
                sx={{ bgcolor: theme.color, color: '#fff', fontWeight: 600 }}
              />
              <Typography variant="caption" color="text.secondary" sx={{ flexGrow: 1 }}>
                {theme.keyword_count} keywords
              </Typography>
              <Switch
                size="small" checked={theme.enabled}
                onClick={e => e.stopPropagation()}
                onChange={(_, checked) => toggleTheme(theme.id, checked)}
              />
              <IconButton
                size="small" color="error"
                onClick={e => { e.stopPropagation(); deleteTheme(theme.id, theme.name); }}
              >
                <DeleteIcon fontSize="small" />
              </IconButton>
            </Stack>
          </AccordionSummary>
          <AccordionDetails sx={{ pt: 0 }}>
            {/* Keywords list */}
            <Box sx={{ display: 'flex', flexWrap: 'wrap', gap: 0.5, mb: 1 }}>
              {theme.keywords.map(kw => (
                <Chip
                  key={kw.id}
                  label={kw.value}
                  size="small"
                  variant="outlined"
                  onDelete={() => deleteKeyword(kw.id)}
                  sx={{ borderColor: theme.color }}
                />
              ))}
            </Box>
            {/* Add keyword */}
            <Stack direction="row" spacing={1} alignItems="center">
              <TextField
                size="small" fullWidth
                placeholder="Add keyword (comma-separated for bulk)"
                value={newKw[theme.id] || ''}
                onChange={e => setNewKw(prev => ({ ...prev, [theme.id]: e.target.value }))}
                onKeyDown={e => e.key === 'Enter' && addKeyword(theme.id)}
              />
              <IconButton size="small" color="primary" onClick={() => addKeyword(theme.id)}>
                <AddIcon fontSize="small" />
              </IconButton>
            </Stack>
          </AccordionDetails>
        </Accordion>
      ))}
    </Paper>
  );
}

// ── Scan Controls + Results (right panel) ────────────────────────────

const RESULT_COLUMNS: GridColDef[] = [
  {
    field: 'theme_name', headerName: 'Theme', width: 140,
    renderCell: (params) => (
      <Chip label={params.value} size="small"
        sx={{ bgcolor: params.row.theme_color, color: '#fff', fontWeight: 600 }} />
    ),
  },
  { field: 'keyword', headerName: 'Keyword', width: 140 },
  { field: 'source_type', headerName: 'Source', width: 120 },
  { field: 'dataset_name', headerName: 'Dataset', width: 150 },
  { field: 'field', headerName: 'Field', width: 130 },
  { field: 'matched_value', headerName: 'Matched Value', flex: 1, minWidth: 200 },
  { field: 'row_index', headerName: 'Row #', width: 80, type: 'number' },
];

export default function AUPScanner() {
  const { enqueueSnackbar } = useSnackbar();

  // State
  const [themes, setThemes] = useState<ThemeOut[]>([]);
  const [huntList, setHuntList] = useState<Hunt[]>([]);
  const [selectedHuntId, setSelectedHuntId] = useState('');
  const [dsList, setDsList] = useState<DatasetSummary[]>([]);
  const [loading, setLoading] = useState(false);
  const [scanning, setScanning] = useState(false);
  const [scanResult, setScanResult] = useState<ScanResponse | null>(null);

  // Scan options
  const [selectedDs, setSelectedDs] = useState<Set<string>>(new Set());
  const [selectedThemes, setSelectedThemes] = useState<Set<string>>(new Set());
  const [scanHunts, setScanHunts] = useState(true);
  const [scanAnnotations, setScanAnnotations] = useState(true);
  const [scanMessages, setScanMessages] = useState(true);

  // Load themes + hunts
  const loadData = useCallback(async () => {
    setLoading(true);
    try {
      const [tRes, hRes] = await Promise.all([
        keywords.listThemes(),
        hunts.list(0, 200),
      ]);
      setThemes(tRes.themes);
      setHuntList(hRes.hunts);
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
    setLoading(false);
  }, [enqueueSnackbar]);

  useEffect(() => { loadData(); }, [loadData]);

  // When hunt changes, load its datasets and auto-select all
  useEffect(() => {
    if (!selectedHuntId) { setDsList([]); setSelectedDs(new Set()); return; }
    let cancelled = false;
    datasets.list(0, 500, selectedHuntId).then(res => {
      if (cancelled) return;
      setDsList(res.datasets);
      setSelectedDs(new Set(res.datasets.map(d => d.id)));
    }).catch(() => {});
    return () => { cancelled = true; };
  }, [selectedHuntId]);

  // Toggle helpers
  const toggleThemeSelect = (id: string) => setSelectedThemes(prev => {
    const next = new Set(prev);
    if (next.has(id)) next.delete(id); else next.add(id);
    return next;
  });

  // Run scan
  const runScan = useCallback(async () => {
    setScanning(true);
    setScanResult(null);
    try {
      const res = await keywords.scan({
        dataset_ids: selectedDs.size > 0 ? Array.from(selectedDs) : undefined,
        theme_ids: selectedThemes.size > 0 ? Array.from(selectedThemes) : undefined,
        scan_hunts: scanHunts,
        scan_annotations: scanAnnotations,
        scan_messages: scanMessages,
      });
      setScanResult(res);
      enqueueSnackbar(`Scan complete — ${res.total_hits} hits found`, {
        variant: res.total_hits > 0 ? 'warning' : 'success',
      });
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
    setScanning(false);
  }, [selectedDs, selectedThemes, scanHunts, scanAnnotations, scanMessages, enqueueSnackbar]);

  if (loading) return <Box sx={{ p: 4, textAlign: 'center' }}><CircularProgress /></Box>;

  return (
    <Box>
      <Typography variant="h5" gutterBottom>AUP Keyword Scanner</Typography>

      <Box sx={{ display: 'flex', gap: 2, height: 'calc(100vh - 140px)' }}>
        {/* Left — Theme Manager */}
        <Box sx={{ width: 380, minWidth: 320, flexShrink: 0, overflow: 'auto' }}>
          <ThemeManager themes={themes} onReload={loadData} />
        </Box>

        {/* Right — Controls + Results */}
        <Box sx={{ flexGrow: 1, display: 'flex', flexDirection: 'column', gap: 2, minWidth: 0 }}>
          {/* Scan controls */}
          <Paper sx={{ p: 2 }}>
            <Stack direction="row" spacing={3} alignItems="flex-start" flexWrap="wrap">
              {/* Hunt → Dataset selector */}
              <Box sx={{ minWidth: 220 }}>
                <Typography variant="subtitle2" gutterBottom>Hunt</Typography>
                <FormControl size="small" fullWidth>
                  <InputLabel id="aup-hunt-label">Select hunt</InputLabel>
                  <Select
                    labelId="aup-hunt-label"
                    value={selectedHuntId}
                    label="Select hunt"
                    onChange={e => setSelectedHuntId(e.target.value)}
                  >
                    {huntList.map(h => (
                      <MenuItem key={h.id} value={h.id}>
                        {h.name} ({h.dataset_count} datasets)
                      </MenuItem>
                    ))}
                  </Select>
                </FormControl>
                {selectedHuntId && dsList.length > 0 && (
                  <Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
                    {dsList.length} datasets · {dsList.reduce((sum, d) => sum + d.row_count, 0).toLocaleString()} rows
                  </Typography>
                )}
                {selectedHuntId && dsList.length === 0 && (
                  <Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
                    No datasets in this hunt
                  </Typography>
                )}
                {!selectedHuntId && (
                  <Typography variant="caption" color="text.secondary" sx={{ mt: 0.5, display: 'block' }}>
                    All datasets will be scanned if no hunt is selected
                  </Typography>
                )}
              </Box>

              {/* Theme selector */}
              <Box sx={{ minWidth: 200 }}>
                <Stack direction="row" alignItems="center" justifyContent="space-between">
                  <Typography variant="subtitle2">Themes</Typography>
                  {(() => { const enabled = themes.filter(t => t.enabled); return (
                    <Button size="small" sx={{ textTransform: 'none', minWidth: 0, px: 0.5, fontSize: '0.7rem' }}
                      onClick={() => {
                        if (selectedThemes.size === enabled.length) setSelectedThemes(new Set());
                        else setSelectedThemes(new Set(enabled.map(t => t.id)));
                      }}>
                      {selectedThemes.size === enabled.length && enabled.length > 0 ? 'Clear all' : 'Select all'}
                    </Button>
                  ); })()}
                </Stack>
                <FormGroup sx={{ maxHeight: 120, overflow: 'auto' }}>
                  {themes.filter(t => t.enabled).map(t => (
                    <FormControlLabel key={t.id} control={
                      <Checkbox size="small" checked={selectedThemes.has(t.id)}
                        onChange={() => toggleThemeSelect(t.id)} />
                    } label={
                      <Chip label={t.name} size="small"
                        sx={{ bgcolor: t.color, color: '#fff', fontSize: '0.75rem' }} />
                    } />
                  ))}
                </FormGroup>
                <Typography variant="caption" color="text.secondary">
                  {selectedThemes.size === 0 ? 'All enabled themes' : `${selectedThemes.size} selected`}
                </Typography>
              </Box>

              {/* Extra sources */}
              <Box>
                <Typography variant="subtitle2" gutterBottom>Also scan</Typography>
                <FormGroup>
                  <FormControlLabel control={
                    <Checkbox size="small" checked={scanHunts} onChange={(_, c) => setScanHunts(c)} />
                  } label={<Typography variant="body2">Hunts</Typography>} />
                  <FormControlLabel control={
                    <Checkbox size="small" checked={scanAnnotations} onChange={(_, c) => setScanAnnotations(c)} />
                  } label={<Typography variant="body2">Annotations</Typography>} />
                  <FormControlLabel control={
                    <Checkbox size="small" checked={scanMessages} onChange={(_, c) => setScanMessages(c)} />
                  } label={<Typography variant="body2">Messages</Typography>} />
                </FormGroup>
              </Box>

              {/* Scan button */}
              <Box sx={{ display: 'flex', alignItems: 'center', gap: 1, pt: 2 }}>
                <Button
                  variant="contained" color="warning" size="large"
                  startIcon={scanning ? <CircularProgress size={20} color="inherit" /> : <PlayArrowIcon />}
                  onClick={runScan} disabled={scanning}
                >
                  {scanning ? 'Scanning…' : 'Run Scan'}
                </Button>
                <Tooltip title="Reload themes & datasets">
                  <IconButton onClick={loadData}><RefreshIcon /></IconButton>
                </Tooltip>
              </Box>
            </Stack>
          </Paper>

          {/* Scan progress */}
          {scanning && <LinearProgress color="warning" />}

          {/* Results summary */}
          {scanResult && (
            <Alert severity={scanResult.total_hits > 0 ? 'warning' : 'success'} sx={{ py: 0.5 }}>
              <strong>{scanResult.total_hits}</strong> hits across{' '}
              <strong>{scanResult.rows_scanned}</strong> rows |{' '}
              {scanResult.themes_scanned} themes, {scanResult.keywords_scanned} keywords scanned
            </Alert>
          )}

          {/* Results DataGrid */}
          {scanResult && (
            <Paper sx={{ flexGrow: 1, minHeight: 300 }}>
              <DataGrid
                rows={scanResult.hits.map((h, i) => ({ id: i, ...h }))}
                columns={RESULT_COLUMNS}
                pageSizeOptions={[25, 50, 100]}
                initialState={{ pagination: { paginationModel: { pageSize: 25 } } }}
                density="compact"
                sx={{
                  border: 0,
                  '& .MuiDataGrid-cell': { fontSize: '0.8rem' },
                  '& .MuiDataGrid-columnHeader': { fontWeight: 700 },
                }}
              />
            </Paper>
          )}

          {/* Empty state */}
          {!scanResult && !scanning && (
            <Paper sx={{ p: 4, textAlign: 'center', flexGrow: 1, display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
              <Box>
                <Typography variant="h6" color="text.secondary">No scan results yet</Typography>
                <Typography variant="body2" color="text.secondary">
                  Select datasets and themes, then click "Run Scan" to check for AUP violations.
                </Typography>
              </Box>
            </Paper>
          )}
        </Box>
      </Box>
    </Box>
  );
}
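For reference, a sketch of the payload runScan() builds (shape taken from the keywords.scan parameter type in api/client.ts; the dataset id is hypothetical):

// Empty selections are sent as undefined so the backend falls back to
// "all datasets" / "all enabled themes".
const scanBody = {
  dataset_ids: ['dataset-uuid'],  // hypothetical; undefined when none selected
  theme_ids: undefined,           // undefined → all enabled themes
  scan_hunts: true,
  scan_annotations: true,
  scan_messages: true,
};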
frontend/src/components/AgentPanel.tsx
@@ -1,264 +1,246 @@
/**
 * AgentPanel — analyst-assist chat with quick / deep / debate modes,
 * streaming support, SANS references, and conversation persistence.
 */

import React, { useState, useRef, useEffect, useCallback } from 'react';
import {
  Box, Typography, Paper, TextField, Button, Stack, Chip,
  ToggleButtonGroup, ToggleButton, CircularProgress, Alert,
  Accordion, AccordionSummary, AccordionDetails, Divider, Select,
  MenuItem, FormControl, InputLabel, LinearProgress,
} from '@mui/material';
import SendIcon from '@mui/icons-material/Send';
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
import SchoolIcon from '@mui/icons-material/School';
import PsychologyIcon from '@mui/icons-material/Psychology';
import ForumIcon from '@mui/icons-material/Forum';
import SpeedIcon from '@mui/icons-material/Speed';
import { useSnackbar } from 'notistack';
import {
  agent, datasets, hunts, type AssistRequest, type AssistResponse,
  type DatasetSummary, type Hunt,
} from '../api/client';

interface Message { role: 'user' | 'assistant'; content: string; meta?: AssistResponse }

export default function AgentPanel() {
  const { enqueueSnackbar } = useSnackbar();
  const [messages, setMessages] = useState<Message[]>([]);
  const [query, setQuery] = useState('');
  const [mode, setMode] = useState<'quick' | 'deep' | 'debate'>('quick');
  const [loading, setLoading] = useState(false);
  const [conversationId, setConversationId] = useState<string | null>(null);
  const [datasetList, setDatasets] = useState<DatasetSummary[]>([]);
  const [huntList, setHunts] = useState<Hunt[]>([]);
  const [selectedDataset, setSelectedDataset] = useState('');
  const [selectedHunt, setSelectedHunt] = useState('');
  const bottomRef = useRef<HTMLDivElement>(null);

  useEffect(() => {
    datasets.list(0, 100).then(r => setDatasets(r.datasets)).catch(() => {});
    hunts.list(0, 100).then(r => setHunts(r.hunts)).catch(() => {});
  }, []);

  useEffect(() => { bottomRef.current?.scrollIntoView({ behavior: 'smooth' }); }, [messages]);

  const send = useCallback(async () => {
    if (!query.trim() || loading) return;
    const userMsg: Message = { role: 'user', content: query };
    setMessages(prev => [...prev, userMsg]);
    setQuery('');
    setLoading(true);

    const ds = datasetList.find(d => d.id === selectedDataset);
    const req: AssistRequest = {
      query,
      mode,
      conversation_id: conversationId || undefined,
      hunt_id: selectedHunt || undefined,
      dataset_name: ds?.name,
      data_summary: ds ? `${ds.row_count} rows, columns: ${Object.keys(ds.column_schema || {}).join(', ')}` : undefined,
    };

    try {
      const resp = await agent.assist(req);
      setConversationId(resp.conversation_id || null);
      setMessages(prev => [...prev, { role: 'assistant', content: resp.guidance, meta: resp }]);
    } catch (e: any) {
      enqueueSnackbar(e.message, { variant: 'error' });
      setMessages(prev => [...prev, { role: 'assistant', content: `Error: ${e.message}` }]);
    }
    setLoading(false);
  }, [query, mode, loading, conversationId, selectedDataset, selectedHunt, datasetList, enqueueSnackbar]);

  const handleKeyDown = (e: React.KeyboardEvent) => {
    if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); send(); }
  };

  const newConversation = () => { setMessages([]); setConversationId(null); };

  return (
    <Box sx={{ height: '100%', display: 'flex', flexDirection: 'column' }}>
      <Stack direction="row" alignItems="center" justifyContent="space-between" sx={{ mb: 1 }}>
        <Typography variant="h5">Agent Assist</Typography>
        <Button size="small" onClick={newConversation}>New Conversation</Button>
      </Stack>

      {/* Controls */}
      <Paper sx={{ p: 1.5, mb: 1 }}>
        <Stack direction="row" spacing={1.5} alignItems="center" flexWrap="wrap">
          <ToggleButtonGroup
            size="small" exclusive value={mode}
            onChange={(_, v) => v && setMode(v)}
          >
            <ToggleButton value="quick"><SpeedIcon sx={{ mr: 0.5, fontSize: 18 }} />Quick</ToggleButton>
            <ToggleButton value="deep"><PsychologyIcon sx={{ mr: 0.5, fontSize: 18 }} />Deep</ToggleButton>
            <ToggleButton value="debate"><ForumIcon sx={{ mr: 0.5, fontSize: 18 }} />Debate</ToggleButton>
          </ToggleButtonGroup>

          <FormControl size="small" sx={{ minWidth: 160 }}>
            <InputLabel>Dataset</InputLabel>
            <Select label="Dataset" value={selectedDataset} onChange={e => setSelectedDataset(e.target.value)}>
              <MenuItem value="">None</MenuItem>
              {datasetList.map(d => <MenuItem key={d.id} value={d.id}>{d.name}</MenuItem>)}
            </Select>
          </FormControl>

          <FormControl size="small" sx={{ minWidth: 160 }}>
            <InputLabel>Hunt</InputLabel>
            <Select label="Hunt" value={selectedHunt} onChange={e => setSelectedHunt(e.target.value)}>
              <MenuItem value="">None</MenuItem>
              {huntList.map(h => <MenuItem key={h.id} value={h.id}>{h.name}</MenuItem>)}
            </Select>
          </FormControl>
        </Stack>
      </Paper>

      {/* Messages */}
      <Paper sx={{ flex: 1, overflow: 'auto', p: 2, mb: 1, minHeight: 300 }}>
        {messages.length === 0 && (
          <Box sx={{ textAlign: 'center', mt: 8 }}>
            <PsychologyIcon sx={{ fontSize: 64, color: 'text.secondary', mb: 1 }} />
            <Typography color="text.secondary">
              Ask a question about your threat hunt data.
            </Typography>
            <Typography variant="caption" color="text.secondary">
              The agent provides advisory guidance — all decisions remain with the analyst.
            </Typography>
          </Box>
        )}
        {messages.map((m, i) => (
          <Box key={i} sx={{ mb: 2 }}>
            <Typography variant="caption" color="text.secondary" fontWeight={700}>
              {m.role === 'user' ? 'You' : 'Agent'}
            </Typography>
            <Paper sx={{
              p: 1.5, mt: 0.5,
              bgcolor: m.role === 'user' ? 'primary.dark' : 'background.default',
              borderColor: m.role === 'user' ? 'primary.main' : 'divider',
            }}>
              <Typography variant="body2" sx={{ whiteSpace: 'pre-wrap' }}>{m.content}</Typography>
            </Paper>

            {/* Response metadata */}
            {m.meta && (
              <Box sx={{ mt: 0.5 }}>
                <Stack direction="row" spacing={0.5} flexWrap="wrap" sx={{ mb: 0.5 }}>
                  <Chip label={`${m.meta.confidence * 100}% confidence`} size="small"
                    color={m.meta.confidence >= 0.7 ? 'success' : m.meta.confidence >= 0.4 ? 'warning' : 'error'} variant="outlined" />
                  <Chip label={m.meta.model_used} size="small" variant="outlined" />
                  <Chip label={m.meta.node_used} size="small" variant="outlined" />
                  <Chip label={`${m.meta.latency_ms}ms`} size="small" variant="outlined" />
                </Stack>

                {/* Pivots & Filters */}
                {(m.meta.suggested_pivots.length > 0 || m.meta.suggested_filters.length > 0) && (
                  <Accordion disableGutters sx={{ mt: 0.5 }}>
                    <AccordionSummary expandIcon={<ExpandMoreIcon />}>
                      <Typography variant="caption">Pivots & Filters</Typography>
                    </AccordionSummary>
                    <AccordionDetails>
                      {m.meta.suggested_pivots.length > 0 && (
                        <>
                          <Typography variant="caption" fontWeight={600}>Pivots</Typography>
                          <Stack direction="row" spacing={0.5} flexWrap="wrap" sx={{ mb: 1 }}>
                            {m.meta.suggested_pivots.map((p, j) => <Chip key={j} label={p} size="small" color="info" />)}
                          </Stack>
                        </>
                      )}
                      {m.meta.suggested_filters.length > 0 && (
                        <>
                          <Typography variant="caption" fontWeight={600}>Filters</Typography>
                          <Stack direction="row" spacing={0.5} flexWrap="wrap">
                            {m.meta.suggested_filters.map((f, j) => <Chip key={j} label={f} size="small" color="secondary" />)}
                          </Stack>
                        </>
                      )}
                    </AccordionDetails>
                  </Accordion>
                )}

                {/* SANS references */}
                {m.meta.sans_references.length > 0 && (
                  <Accordion disableGutters sx={{ mt: 0.5 }}>
                    <AccordionSummary expandIcon={<ExpandMoreIcon />}>
                      <Stack direction="row" alignItems="center" spacing={0.5}>
                        <SchoolIcon sx={{ fontSize: 16 }} />
                        <Typography variant="caption">SANS References ({m.meta.sans_references.length})</Typography>
                      </Stack>
                    </AccordionSummary>
                    <AccordionDetails>
                      {m.meta.sans_references.map((r, j) => (
                        <Typography key={j} variant="body2" sx={{ mb: 0.5 }}>• {r}</Typography>
                      ))}
                    </AccordionDetails>
                  </Accordion>
                )}

                {/* Debate perspectives */}
                {m.meta.perspectives && m.meta.perspectives.length > 0 && (
                  <Accordion disableGutters sx={{ mt: 0.5 }}>
                    <AccordionSummary expandIcon={<ExpandMoreIcon />}>
                      <Typography variant="caption">Debate Perspectives ({m.meta.perspectives.length})</Typography>
                    </AccordionSummary>
                    <AccordionDetails>
                      {m.meta.perspectives.map((p: any, j: number) => (
                        <Box key={j} sx={{ mb: 1 }}>
                          <Chip label={p.role || `Perspective ${j + 1}`} size="small" color="primary" sx={{ mb: 0.5 }} />
                          <Typography variant="body2">{p.argument || p.content || JSON.stringify(p)}</Typography>
                          <Divider sx={{ mt: 1 }} />
                        </Box>
                      ))}
                    </AccordionDetails>
                  </Accordion>
                )}

                {/* Caveats */}
                {m.meta.caveats && (
                  <Alert severity="warning" sx={{ mt: 0.5, py: 0 }}>
                    <Typography variant="caption">{m.meta.caveats}</Typography>
                  </Alert>
                )}
              </Box>
            )}
          </Box>
        ))}
        {loading && <LinearProgress sx={{ mb: 1 }} />}
        <div ref={bottomRef} />
      </Paper>

      {/* Input */}
      <Stack direction="row" spacing={1}>
        <TextField
          fullWidth size="small" multiline maxRows={4}
          placeholder="Ask the agent..."
          value={query} onChange={e => setQuery(e.target.value)}
          onKeyDown={handleKeyDown}
          disabled={loading}
        />
        <Button variant="contained" onClick={send} disabled={loading || !query.trim()}>
          {loading ? <CircularProgress size={20} /> : <SendIcon />}
        </Button>
      </Stack>
    </Box>
  );
}
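A minimal sketch (assumed, not in the commit) of driving the same endpoint the panel uses, via the agent.assist client imported above:

// Pass the returned conversation_id back in to continue the thread.
async function quickAsk(question: string, conversationId?: string) {
  const resp = await agent.assist({
    query: question, mode: 'quick',
    conversation_id: conversationId,
  });
  console.log(resp.guidance, resp.model_used, `${resp.latency_ms}ms`);
  return resp.conversation_id;
}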
185 frontend/src/components/AnnotationPanel.tsx Normal file
@@ -0,0 +1,185 @@
/**
 * AnnotationPanel — create / list / filter annotations on dataset rows.
 */

import React, { useEffect, useState, useCallback } from 'react';
import {
  Box, Typography, Paper, Stack, Chip, Button, TextField,
  Select, MenuItem, FormControl, InputLabel, CircularProgress,
  Table, TableBody, TableCell, TableContainer, TableHead, TableRow,
  Dialog, DialogTitle, DialogContent, DialogActions, IconButton,
} from '@mui/material';
import AddIcon from '@mui/icons-material/Add';
import DeleteIcon from '@mui/icons-material/Delete';
import { useSnackbar } from 'notistack';
import { annotations, datasets, type AnnotationData, type DatasetSummary } from '../api/client';

const SEVERITIES = ['info', 'low', 'medium', 'high', 'critical'];
const TAGS = ['suspicious', 'benign', 'needs-review'];
const SEV_COLORS: Record<string, 'default' | 'info' | 'success' | 'warning' | 'error'> = {
  info: 'info', low: 'success', medium: 'warning', high: 'error', critical: 'error',
};

export default function AnnotationPanel() {
  const { enqueueSnackbar } = useSnackbar();
  const [list, setList] = useState<AnnotationData[]>([]);
  const [total, setTotal] = useState(0);
  const [loading, setLoading] = useState(true);
  const [filterSeverity, setFilterSeverity] = useState('');
  const [filterTag, setFilterTag] = useState('');
  const [filterDataset, setFilterDataset] = useState('');
  const [datasetList, setDatasetList] = useState<DatasetSummary[]>([]);
  const [dlgOpen, setDlgOpen] = useState(false);
  const [form, setForm] = useState({ text: '', severity: 'info', tag: '', dataset_id: '', row_id: '' });

  useEffect(() => {
    datasets.list(0, 200).then(r => setDatasetList(r.datasets)).catch(() => {});
  }, []);

  const load = useCallback(async () => {
    setLoading(true);
    try {
      const r = await annotations.list({
        severity: filterSeverity || undefined,
        tag: filterTag || undefined,
        dataset_id: filterDataset || undefined,
        limit: 100,
      });
      setList(r.annotations); setTotal(r.total);
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
    setLoading(false);
  }, [filterSeverity, filterTag, filterDataset, enqueueSnackbar]);

  useEffect(() => { load(); }, [load]);

  const handleCreate = async () => {
    try {
      await annotations.create({
        text: form.text,
        severity: form.severity,
        tag: form.tag || undefined,
        dataset_id: form.dataset_id || undefined,
        row_id: form.row_id ? parseInt(form.row_id, 10) : undefined,
      });
      enqueueSnackbar('Annotation created', { variant: 'success' });
      setDlgOpen(false); load();
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  };

  const handleDelete = async (id: string) => {
    try {
      await annotations.delete(id);
      enqueueSnackbar('Deleted', { variant: 'info' }); load();
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  };

  if (loading) return <Box sx={{ p: 4 }}><CircularProgress /></Box>;

  return (
    <Box>
      <Stack direction="row" alignItems="center" justifyContent="space-between" sx={{ mb: 2 }}>
        <Typography variant="h5">Annotations ({total})</Typography>
        <Button variant="contained" startIcon={<AddIcon />}
          onClick={() => { setForm({ text: '', severity: 'info', tag: '', dataset_id: '', row_id: '' }); setDlgOpen(true); }}>
          New
        </Button>
      </Stack>

      {/* Filters */}
      <Paper sx={{ p: 1.5, mb: 2 }}>
        <Stack direction="row" spacing={1.5} flexWrap="wrap">
          <FormControl size="small" sx={{ minWidth: 120 }}>
            <InputLabel>Severity</InputLabel>
            <Select label="Severity" value={filterSeverity} onChange={e => setFilterSeverity(e.target.value)}>
              <MenuItem value="">All</MenuItem>
              {SEVERITIES.map(s => <MenuItem key={s} value={s}>{s}</MenuItem>)}
            </Select>
          </FormControl>
          <FormControl size="small" sx={{ minWidth: 120 }}>
            <InputLabel>Tag</InputLabel>
            <Select label="Tag" value={filterTag} onChange={e => setFilterTag(e.target.value)}>
              <MenuItem value="">All</MenuItem>
              {TAGS.map(t => <MenuItem key={t} value={t}>{t}</MenuItem>)}
            </Select>
          </FormControl>
          <FormControl size="small" sx={{ minWidth: 180 }}>
            <InputLabel>Dataset</InputLabel>
            <Select label="Dataset" value={filterDataset} onChange={e => setFilterDataset(e.target.value)}>
              <MenuItem value="">All</MenuItem>
              {datasetList.map(d => <MenuItem key={d.id} value={d.id}>{d.name}</MenuItem>)}
            </Select>
          </FormControl>
        </Stack>
      </Paper>

      <TableContainer component={Paper}>
        <Table size="small">
          <TableHead>
            <TableRow>
              <TableCell>Severity</TableCell>
              <TableCell>Tag</TableCell>
              <TableCell>Text</TableCell>
              <TableCell>Row</TableCell>
              <TableCell>Created</TableCell>
              <TableCell align="right">Actions</TableCell>
            </TableRow>
          </TableHead>
          <TableBody>
            {list.map(a => (
              <TableRow key={a.id} hover>
                <TableCell>
                  <Chip label={a.severity} size="small" color={SEV_COLORS[a.severity] || 'default'} variant="outlined" />
                </TableCell>
                <TableCell>{a.tag || '—'}</TableCell>
                <TableCell><Typography variant="body2" sx={{ maxWidth: 400, overflow: 'hidden', textOverflow: 'ellipsis', whiteSpace: 'nowrap' }}>{a.text}</Typography></TableCell>
                <TableCell>{a.row_id ?? '—'}</TableCell>
                <TableCell><Typography variant="caption">{new Date(a.created_at).toLocaleString()}</Typography></TableCell>
                <TableCell align="right">
                  <IconButton size="small" color="error" onClick={() => handleDelete(a.id)}><DeleteIcon fontSize="small" /></IconButton>
                </TableCell>
              </TableRow>
            ))}
            {list.length === 0 && (
              <TableRow><TableCell colSpan={6} align="center"><Typography color="text.secondary">No annotations</Typography></TableCell></TableRow>
            )}
          </TableBody>
        </Table>
      </TableContainer>

      {/* Create dialog */}
      <Dialog open={dlgOpen} onClose={() => setDlgOpen(false)} maxWidth="sm" fullWidth>
        <DialogTitle>New Annotation</DialogTitle>
        <DialogContent>
          <Stack spacing={2} sx={{ mt: 1 }}>
            <TextField label="Text" fullWidth multiline rows={3} value={form.text} onChange={e => setForm(f => ({ ...f, text: e.target.value }))} />
            <FormControl fullWidth>
              <InputLabel>Severity</InputLabel>
              <Select label="Severity" value={form.severity} onChange={e => setForm(f => ({ ...f, severity: e.target.value }))}>
                {SEVERITIES.map(s => <MenuItem key={s} value={s}>{s}</MenuItem>)}
              </Select>
            </FormControl>
            <FormControl fullWidth>
              <InputLabel>Tag</InputLabel>
              <Select label="Tag" value={form.tag} onChange={e => setForm(f => ({ ...f, tag: e.target.value }))}>
                <MenuItem value="">None</MenuItem>
                {TAGS.map(t => <MenuItem key={t} value={t}>{t}</MenuItem>)}
              </Select>
            </FormControl>
            <FormControl fullWidth>
              <InputLabel>Dataset</InputLabel>
              <Select label="Dataset" value={form.dataset_id} onChange={e => setForm(f => ({ ...f, dataset_id: e.target.value }))}>
                <MenuItem value="">None</MenuItem>
                {datasetList.map(d => <MenuItem key={d.id} value={d.id}>{d.name}</MenuItem>)}
              </Select>
            </FormControl>
            <TextField label="Row Index" type="number" value={form.row_id} onChange={e => setForm(f => ({ ...f, row_id: e.target.value }))} />
          </Stack>
        </DialogContent>
        <DialogActions>
          <Button onClick={() => setDlgOpen(false)}>Cancel</Button>
          <Button variant="contained" onClick={handleCreate} disabled={!form.text.trim()}>Create</Button>
        </DialogActions>
      </Dialog>
    </Box>
  );
}
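The equivalent direct call handleCreate() issues above, as a sketch (assumed values; the dataset id is hypothetical):

// Flag a specific dataset row for review via the annotations client.
await annotations.create({
  text: 'Outbound beacon every 60s to a rare domain',
  severity: 'high',
  tag: 'suspicious',
  dataset_id: 'dataset-uuid',  // hypothetical id
  row_id: 1337,
});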
178 frontend/src/components/CorrelationView.tsx Normal file
@@ -0,0 +1,178 @@
/**
 * CorrelationView — cross-hunt correlation analysis with IOC, time,
 * technique, and host overlap visualisation.
 */

import React, { useEffect, useState, useCallback } from 'react';
import {
  Box, Typography, Paper, Stack, Chip, Button, CircularProgress,
  Alert, Table, TableBody, TableCell, TableContainer, TableHead,
  TableRow, TextField,
} from '@mui/material';
import CompareArrowsIcon from '@mui/icons-material/CompareArrows';
import SearchIcon from '@mui/icons-material/Search';
import { useSnackbar } from 'notistack';
import { correlation, hunts, type Hunt, type CorrelationResult } from '../api/client';

export default function CorrelationView() {
  const { enqueueSnackbar } = useSnackbar();
  const [huntList, setHuntList] = useState<Hunt[]>([]);
  const [selectedIds, setSelectedIds] = useState<string[]>([]);
  const [result, setResult] = useState<CorrelationResult | null>(null);
  const [loading, setLoading] = useState(false);
  const [iocSearch, setIocSearch] = useState('');
  const [iocHits, setIocHits] = useState<any[] | null>(null);

  useEffect(() => {
    hunts.list(0, 100).then(r => setHuntList(r.hunts)).catch(() => {});
  }, []);

  const toggleHunt = (id: string) => {
    setSelectedIds(prev =>
      prev.includes(id) ? prev.filter(x => x !== id) : [...prev, id],
    );
  };

  const runAnalysis = useCallback(async () => {
    setLoading(true);
    try {
      const r = selectedIds.length >= 2
        ? await correlation.analyze(selectedIds)
        : await correlation.all();
      setResult(r);
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
    setLoading(false);
  }, [selectedIds, enqueueSnackbar]);

  const searchIoc = useCallback(async () => {
    if (!iocSearch.trim()) return;
    try {
      const r = await correlation.ioc(iocSearch.trim());
      setIocHits(r.occurrences);
      if (r.occurrences.length === 0) enqueueSnackbar('No occurrences found', { variant: 'info' });
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  }, [iocSearch, enqueueSnackbar]);

  return (
    <Box>
      <Typography variant="h5" gutterBottom>Cross-Hunt Correlation</Typography>

      {/* Hunt selector */}
      <Paper sx={{ p: 2, mb: 2 }}>
        <Typography variant="subtitle2" sx={{ mb: 1 }}>Select hunts to correlate (min 2, or leave empty for all):</Typography>
        <Stack direction="row" spacing={0.5} flexWrap="wrap" sx={{ mb: 1.5 }}>
          {huntList.map(h => (
            <Chip
              key={h.id} label={h.name} size="small"
              color={selectedIds.includes(h.id) ? 'primary' : 'default'}
              onClick={() => toggleHunt(h.id)}
              variant={selectedIds.includes(h.id) ? 'filled' : 'outlined'}
              sx={{ mb: 0.5 }}
            />
          ))}
        </Stack>
        <Button
          variant="contained" startIcon={loading ? <CircularProgress size={16} /> : <CompareArrowsIcon />}
          onClick={runAnalysis} disabled={loading}
        >
          {selectedIds.length >= 2 ? `Correlate ${selectedIds.length} Hunts` : 'Correlate All'}
        </Button>
      </Paper>

      {/* IOC Search */}
      <Paper sx={{ p: 2, mb: 2 }}>
        <Typography variant="subtitle2" sx={{ mb: 1 }}>Search IOC across all hunts:</Typography>
        <Stack direction="row" spacing={1}>
          <TextField size="small" fullWidth placeholder="e.g. 192.168.1.100" value={iocSearch}
            onChange={e => setIocSearch(e.target.value)} onKeyDown={e => e.key === 'Enter' && searchIoc()} />
          <Button variant="outlined" startIcon={<SearchIcon />} onClick={searchIoc}>Search</Button>
        </Stack>
        {iocHits && iocHits.length > 0 && (
          <Box sx={{ mt: 1.5 }}>
            <Typography variant="body2" fontWeight={600}>Found in {iocHits.length} location(s):</Typography>
            {iocHits.map((hit: any, i: number) => (
              <Chip key={i} label={`${hit.hunt_name || hit.hunt_id} / ${hit.dataset_name || hit.dataset_id}`}
                size="small" sx={{ mr: 0.5, mt: 0.5 }} />
            ))}
          </Box>
        )}
      </Paper>

      {/* Results */}
      {result && (
        <Box>
          <Alert severity="info" sx={{ mb: 2 }}>
            {result.summary} — {result.total_correlations} total correlation(s) across {result.hunt_ids.length} hunts
          </Alert>

          {/* IOC overlaps */}
          {result.ioc_overlaps.length > 0 && (
            <Paper sx={{ p: 2, mb: 2 }}>
              <Typography variant="h6" gutterBottom>IOC Overlaps ({result.ioc_overlaps.length})</Typography>
              <TableContainer>
                <Table size="small">
                  <TableHead>
                    <TableRow>
                      <TableCell>IOC</TableCell>
                      <TableCell>Type</TableCell>
                      <TableCell>Shared Hunts</TableCell>
                    </TableRow>
                  </TableHead>
                  <TableBody>
                    {result.ioc_overlaps.map((o: any, i: number) => (
                      <TableRow key={i}>
                        <TableCell><Typography variant="body2" fontFamily="monospace">{o.ioc_value}</Typography></TableCell>
                        <TableCell><Chip label={o.ioc_type || 'unknown'} size="small" /></TableCell>
                        <TableCell>
                          {(o.hunt_ids || []).map((hid: string, j: number) => (
                            <Chip key={j} label={huntList.find(h => h.id === hid)?.name || hid} size="small" sx={{ mr: 0.5 }} />
                          ))}
                        </TableCell>
                      </TableRow>
                    ))}
                  </TableBody>
                </Table>
              </TableContainer>
            </Paper>
          )}

          {/* Technique overlaps */}
          {result.technique_overlaps.length > 0 && (
            <Paper sx={{ p: 2, mb: 2 }}>
              <Typography variant="h6" gutterBottom>MITRE Technique Overlaps</Typography>
              <Stack direction="row" spacing={0.5} flexWrap="wrap">
                {result.technique_overlaps.map((t: any, i: number) => (
                  <Chip key={i} label={t.technique || t.mitre_technique} color="secondary" size="small" />
                ))}
              </Stack>
            </Paper>
          )}

          {/* Time overlaps */}
          {result.time_overlaps.length > 0 && (
            <Paper sx={{ p: 2, mb: 2 }}>
              <Typography variant="h6" gutterBottom>Time Overlaps</Typography>
              {result.time_overlaps.map((t: any, i: number) => (
                <Typography key={i} variant="body2" sx={{ mb: 0.5 }}>
                  {t.hunt_a || 'Hunt A'} ↔ {t.hunt_b || 'Hunt B'}: {t.overlap_start} — {t.overlap_end}
                </Typography>
              ))}
            </Paper>
          )}

          {/* Host overlaps */}
          {result.host_overlaps.length > 0 && (
            <Paper sx={{ p: 2 }}>
              <Typography variant="h6" gutterBottom>Host Overlaps</Typography>
              <Stack direction="row" spacing={0.5} flexWrap="wrap">
                {result.host_overlaps.map((h: any, i: number) => (
                  <Chip key={i} label={typeof h === 'string' ? h : h.hostname || JSON.stringify(h)} size="small" variant="outlined" />
                ))}
              </Stack>
            </Paper>
          )}
        </Box>
      )}
    </Box>
  );
}
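A sketch chaining the two correlation calls this view exposes (hunt ids are hypothetical): analyze a pair of hunts, then pivot on the first shared IOC.

const r = await correlation.analyze(['hunt-a-id', 'hunt-b-id']);
if (r.ioc_overlaps.length > 0) {
  // Look up every occurrence of the shared IOC across all hunts.
  const hit = await correlation.ioc(r.ioc_overlaps[0].ioc_value);
  console.log(`${hit.ioc_value} seen in ${hit.total} location(s)`);
}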
152 frontend/src/components/Dashboard.tsx Normal file
@@ -0,0 +1,152 @@
/**
 * Dashboard — overview cards with hunt stats, node health, recent activity.
 */

import React, { useEffect, useState } from 'react';
import {
  Box, Grid, Paper, Typography, Chip, CircularProgress,
  Stack, Alert,
} from '@mui/material';
import StorageIcon from '@mui/icons-material/Storage';
import SearchIcon from '@mui/icons-material/Search';
import SecurityIcon from '@mui/icons-material/Security';
import ScienceIcon from '@mui/icons-material/Science';
import CheckCircleIcon from '@mui/icons-material/CheckCircle';
import ErrorIcon from '@mui/icons-material/Error';
import { hunts, datasets, hypotheses, agent, misc, type Hunt, type DatasetSummary, type HealthInfo } from '../api/client';

function StatCard({ title, value, icon, color }: { title: string; value: string | number; icon: React.ReactNode; color: string }) {
  return (
    <Paper sx={{ p: 2.5 }}>
      <Stack direction="row" alignItems="center" spacing={2}>
        <Box sx={{ color, fontSize: 40, display: 'flex' }}>{icon}</Box>
        <Box>
          <Typography variant="h4">{value}</Typography>
          <Typography variant="body2" color="text.secondary">{title}</Typography>
        </Box>
      </Stack>
    </Paper>
  );
}

function NodeStatus({ label, available }: { label: string; available: boolean }) {
  return (
    <Stack direction="row" alignItems="center" spacing={1}>
      {available
        ? <CheckCircleIcon sx={{ color: 'success.main', fontSize: 20 }} />
        : <ErrorIcon sx={{ color: 'error.main', fontSize: 20 }} />
      }
      <Typography variant="body2">{label}</Typography>
      <Chip label={available ? 'Online' : 'Offline'} size="small"
        color={available ? 'success' : 'error'} variant="outlined" />
    </Stack>
  );
}

export default function Dashboard() {
  const [loading, setLoading] = useState(true);
  const [health, setHealth] = useState<HealthInfo | null>(null);
  const [huntList, setHunts] = useState<Hunt[]>([]);
  const [datasetList, setDatasets] = useState<DatasetSummary[]>([]);
  const [hypoCount, setHypoCount] = useState(0);
  const [apiInfo, setApiInfo] = useState<{ name: string; version: string; status: string } | null>(null);
  const [error, setError] = useState('');

  useEffect(() => {
    (async () => {
      try {
        const [h, ht, ds, hy, info] = await Promise.all([
          agent.health().catch(() => null),
          hunts.list(0, 100).catch(() => ({ hunts: [], total: 0 })),
          datasets.list(0, 100).catch(() => ({ datasets: [], total: 0 })),
          hypotheses.list({ limit: 1 }).catch(() => ({ hypotheses: [], total: 0 })),
          misc.root().catch(() => null),
        ]);
        setHealth(h);
        setHunts(ht.hunts);
        setDatasets(ds.datasets);
        setHypoCount(hy.total);
        setApiInfo(info);
      } catch (e: any) {
        setError(e.message);
      } finally {
        setLoading(false);
      }
    })();
  }, []);

  if (loading) return <Box sx={{ p: 4 }}><CircularProgress /></Box>;
  if (error) return <Alert severity="error">{error}</Alert>;

  const activeHunts = huntList.filter(h => h.status === 'active').length;
  const totalRows = datasetList.reduce((s, d) => s + d.row_count, 0);

  return (
    <Box>
      <Typography variant="h5" gutterBottom>Dashboard</Typography>

      {/* Stat cards */}
      <Grid container spacing={2} sx={{ mb: 3 }}>
        <Grid size={{ xs: 12, sm: 6, md: 3 }}>
          <StatCard title="Active Hunts" value={activeHunts} icon={<SearchIcon fontSize="inherit" />} color="#60a5fa" />
        </Grid>
        <Grid size={{ xs: 12, sm: 6, md: 3 }}>
          <StatCard title="Datasets" value={datasetList.length} icon={<StorageIcon fontSize="inherit" />} color="#f472b6" />
        </Grid>
        <Grid size={{ xs: 12, sm: 6, md: 3 }}>
          <StatCard title="Total Rows" value={totalRows.toLocaleString()} icon={<SecurityIcon fontSize="inherit" />} color="#10b981" />
        </Grid>
        <Grid size={{ xs: 12, sm: 6, md: 3 }}>
          <StatCard title="Hypotheses" value={hypoCount} icon={<ScienceIcon fontSize="inherit" />} color="#f59e0b" />
        </Grid>
      </Grid>

      {/* Node health + API info */}
      <Grid container spacing={2}>
        <Grid size={{ xs: 12, md: 6 }}>
          <Paper sx={{ p: 2.5 }}>
            <Typography variant="h6" gutterBottom>LLM Cluster Health</Typography>
            <Stack spacing={1.5}>
              <NodeStatus label="Wile (100.110.190.12)" available={health?.nodes?.wile?.available ?? false} />
              <NodeStatus label="Roadrunner (100.110.190.11)" available={health?.nodes?.roadrunner?.available ?? false} />
              <NodeStatus label="SANS RAG (Open WebUI)" available={health?.rag?.available ?? false} />
            </Stack>
          </Paper>
        </Grid>
        <Grid size={{ xs: 12, md: 6 }}>
          <Paper sx={{ p: 2.5 }}>
            <Typography variant="h6" gutterBottom>API Status</Typography>
            <Stack spacing={1}>
              <Typography variant="body2" color="text.secondary">
                {apiInfo ? `${apiInfo.name} — ${apiInfo.version}` : 'Unreachable'}
              </Typography>
              <Typography variant="body2" color="text.secondary">
                Status: {apiInfo?.status ?? 'unknown'}
              </Typography>
            </Stack>
          </Paper>
        </Grid>
      </Grid>

      {/* Recent hunts */}
      {huntList.length > 0 && (
        <Paper sx={{ p: 2.5, mt: 2 }}>
          <Typography variant="h6" gutterBottom>Recent Hunts</Typography>
          <Stack spacing={1}>
            {huntList.slice(0, 5).map(h => (
              <Stack key={h.id} direction="row" alignItems="center" spacing={1}>
                <Chip label={h.status} size="small"
                  color={h.status === 'active' ? 'success' : h.status === 'closed' ? 'default' : 'warning'}
                  variant="outlined" />
                <Typography variant="body2" sx={{ fontWeight: 600 }}>{h.name}</Typography>
                <Typography variant="caption" color="text.secondary">
                  {h.dataset_count} datasets · {h.hypothesis_count} hypotheses
                </Typography>
              </Stack>
            ))}
          </Stack>
        </Paper>
      )}
    </Box>
  );
}
|
||||
200
frontend/src/components/DatasetViewer.tsx
Normal file
@@ -0,0 +1,200 @@
/**
 * DatasetViewer — list datasets, browse rows with MUI DataGrid.
 */

import React, { useEffect, useState, useCallback } from 'react';
import {
  Box, Typography, Paper, Stack, Chip, CircularProgress,
  Alert, Button, IconButton, Select, MenuItem, FormControl,
  InputLabel,
} from '@mui/material';
import { DataGrid, type GridColDef, type GridPaginationModel } from '@mui/x-data-grid';
import DeleteIcon from '@mui/icons-material/Delete';
import RefreshIcon from '@mui/icons-material/Refresh';
import { useSnackbar } from 'notistack';
import { datasets, enrichment, type DatasetSummary } from '../api/client';

export default function DatasetViewer() {
  const { enqueueSnackbar } = useSnackbar();
  const [list, setList] = useState<DatasetSummary[]>([]);
  const [loading, setLoading] = useState(true);
  const [selected, setSelected] = useState<DatasetSummary | null>(null);
  const [rows, setRows] = useState<Record<string, any>[]>([]);
  const [rowTotal, setRowTotal] = useState(0);
  const [paginationModel, setPaginationModel] = useState<GridPaginationModel>({ page: 0, pageSize: 50 });
  const [rowLoading, setRowLoading] = useState(false);
  const [enriching, setEnriching] = useState(false);

  const loadList = useCallback(async () => {
    setLoading(true);
    try {
      const r = await datasets.list(0, 200);
      setList(r.datasets);
      if (r.datasets.length > 0 && !selected) setSelected(r.datasets[0]);
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
    setLoading(false);
  }, [enqueueSnackbar, selected]);

  const loadRows = useCallback(async () => {
    if (!selected) return;
    setRowLoading(true);
    try {
      const r = await datasets.rows(selected.id, paginationModel.page * paginationModel.pageSize, paginationModel.pageSize);
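      // Server rows carry no natural primary key, so synthesize a stable
      // per-page id for DataGrid's getRowId.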
      setRows(r.rows.map((rw, i) => ({ __id: `${paginationModel.page}-${i}`, ...rw })));
      setRowTotal(r.total);
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
    setRowLoading(false);
  }, [selected, paginationModel, enqueueSnackbar]);

  useEffect(() => { loadList(); }, [loadList]);
  useEffect(() => { loadRows(); }, [loadRows]);

  const handleDelete = async (id: string) => {
    if (!window.confirm('Delete this dataset?')) return;
    try {
      await datasets.delete(id);
      enqueueSnackbar('Dataset deleted', { variant: 'info' });
      if (selected?.id === id) setSelected(null);
      loadList();
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  };

  const handleEnrich = async () => {
    if (!selected) return;
    setEnriching(true);
    try {
      const r = await enrichment.dataset(selected.id);
      enqueueSnackbar(`Enriched ${r.enriched} IOCs from ${r.iocs_found} found`, { variant: 'success' });
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
    setEnriching(false);
  };

  // IOC type → colour mapping (matches NetworkMap)
  const IOC_COLORS: Record<string, { bg: string; text: string; header: string }> = {
    ip:          { bg: 'rgba(59,130,246,0.08)', text: '#3b82f6', header: 'rgba(59,130,246,0.18)' },
    hostname:    { bg: 'rgba(34,197,94,0.08)',  text: '#22c55e', header: 'rgba(34,197,94,0.18)' },
    domain:      { bg: 'rgba(234,179,8,0.08)',  text: '#eab308', header: 'rgba(234,179,8,0.18)' },
    url:         { bg: 'rgba(139,92,246,0.08)', text: '#8b5cf6', header: 'rgba(139,92,246,0.18)' },
    hash_md5:    { bg: 'rgba(244,63,94,0.08)',  text: '#f43f5e', header: 'rgba(244,63,94,0.18)' },
    hash_sha1:   { bg: 'rgba(244,63,94,0.08)',  text: '#f43f5e', header: 'rgba(244,63,94,0.18)' },
    hash_sha256: { bg: 'rgba(244,63,94,0.08)',  text: '#f43f5e', header: 'rgba(244,63,94,0.18)' },
  };
  const DEFAULT_IOC_STYLE = { bg: 'rgba(251,191,36,0.08)', text: '#fbbf24', header: 'rgba(251,191,36,0.18)' };

  // Resolve IOC type for a column (first type in the array)
  const iocMap = selected?.ioc_columns ?? {};
  const iocTypeFor = (col: string): string | null => {
    const types = iocMap[col];
    if (!types || types.length === 0) return null;
    return Array.isArray(types) ? types[0] : (types as any);
  };
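  // e.g. with ioc_columns = { src_ip: ['ip'], md5: ['hash_md5'] } (shape
  // inferred from the usage here; the column names are illustrative only),
  // iocTypeFor('src_ip') → 'ip'.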

  // Build DataGrid columns from the first row, highlighting IOC columns
  const columns: GridColDef[] = rows.length > 0
    ? Object.keys(rows[0]).filter(k => k !== '__id').map(k => {
        const iocType = iocTypeFor(k);
        const style = iocType ? (IOC_COLORS[iocType] || DEFAULT_IOC_STYLE) : null;
        return {
          field: k,
          headerName: iocType ? `${k} ◆ ${iocType.toUpperCase()}` : k,
          flex: 1,
          minWidth: 120,
          ...(style ? {
            headerClassName: `ioc-header-${iocType}`,
            cellClassName: `ioc-cell-${iocType}`,
          } : {}),
        } as GridColDef;
      })
    : [];

  if (loading) return <Box sx={{ p: 4 }}><CircularProgress /></Box>;

  return (
    <Box>
      <Stack direction="row" alignItems="center" justifyContent="space-between" sx={{ mb: 2 }}>
        <Typography variant="h5">Datasets ({list.length})</Typography>
        <Stack direction="row" spacing={1}>
          <Button size="small" startIcon={<RefreshIcon />} onClick={loadList}>Refresh</Button>
          {selected && (
            <Button size="small" variant="outlined" onClick={handleEnrich} disabled={enriching}>
              {enriching ? 'Enriching...' : 'Auto-Enrich IOCs'}
            </Button>
          )}
        </Stack>
      </Stack>

      {/* Dataset selector */}
      {list.length > 0 && (
        <Paper sx={{ p: 2, mb: 2 }}>
          <Stack direction="row" spacing={2} alignItems="center" flexWrap="wrap">
            <FormControl size="small" sx={{ minWidth: 240 }}>
              <InputLabel>Dataset</InputLabel>
              <Select
                label="Dataset"
                value={selected?.id || ''}
                onChange={e => setSelected(list.find(d => d.id === e.target.value) || null)}
              >
                {list.map(d => (
                  <MenuItem key={d.id} value={d.id}>{d.name} ({d.row_count} rows)</MenuItem>
                ))}
              </Select>
            </FormControl>
            {selected && (
              <>
                <Chip label={`${selected.row_count} rows`} size="small" />
                <Chip label={selected.encoding || 'utf-8'} size="small" variant="outlined" />
                {selected.source_tool && <Chip label={selected.source_tool} size="small" color="info" variant="outlined" />}
                {selected.ioc_columns && Object.keys(selected.ioc_columns).length > 0 && (
                  <Chip label={`${Object.keys(selected.ioc_columns).length} IOC columns`} size="small" color="warning" variant="outlined" />
                )}
                <IconButton size="small" color="error" onClick={() => handleDelete(selected.id)}>
                  <DeleteIcon fontSize="small" />
                </IconButton>
              </>
            )}
          </Stack>
          {selected?.time_range_start && (
            <Typography variant="caption" color="text.secondary" sx={{ mt: 1, display: 'block' }}>
              Time range: {selected.time_range_start} — {selected.time_range_end}
            </Typography>
          )}
        </Paper>
      )}

      {/* Data grid */}
      {selected ? (
        <Paper sx={{ height: 520 }}>
          <DataGrid
            rows={rows}
            columns={columns}
            getRowId={r => r.__id}
            rowCount={rowTotal}
            loading={rowLoading}
            paginationMode="server"
            paginationModel={paginationModel}
            onPaginationModelChange={setPaginationModel}
            pageSizeOptions={[25, 50, 100]}
            density="compact"
            sx={{
              border: 'none',
              '& .MuiDataGrid-cell': { fontSize: '0.8rem' },
              '& .MuiDataGrid-columnHeader': { fontWeight: 700 },
              // IOC column highlights
              ...Object.fromEntries(
                Object.entries(IOC_COLORS).flatMap(([type, c]) => [
                  [`& .ioc-header-${type}`, { backgroundColor: c.header, '& .MuiDataGrid-columnHeaderTitle': { color: c.text, fontWeight: 800 } }],
                  [`& .ioc-cell-${type}`, { backgroundColor: c.bg, borderLeft: `2px solid ${c.text}` }],
                ]),
              ),
              // Default IOC fallback
              '& [class*="ioc-header-"]': { backgroundColor: DEFAULT_IOC_STYLE.header },
              '& [class*="ioc-cell-"]': { backgroundColor: DEFAULT_IOC_STYLE.bg },
            }}
          />
        </Paper>
      ) : (
        <Alert severity="info">Upload a CSV to get started.</Alert>
      )}
    </Box>
  );
}
119
frontend/src/components/EnrichmentPanel.tsx
Normal file
@@ -0,0 +1,119 @@
/**
 * EnrichmentPanel — manual IOC lookup + batch enrichment results.
 */

import React, { useState, useCallback } from 'react';
import {
  Box, Typography, Paper, TextField, Stack, Button, Chip,
  Select, MenuItem, FormControl, InputLabel, CircularProgress,
  Table, TableBody, TableCell, TableContainer, TableHead, TableRow,
  Alert,
} from '@mui/material';
import SearchIcon from '@mui/icons-material/Search';
import { useSnackbar } from 'notistack';
import { enrichment, type EnrichmentResult } from '../api/client';

const IOC_TYPES = ['ip', 'domain', 'hash_md5', 'hash_sha1', 'hash_sha256', 'url'];

const VERDICT_COLORS: Record<string, 'error' | 'warning' | 'success' | 'default' | 'info'> = {
  malicious: 'error', suspicious: 'warning', clean: 'success', unknown: 'default', error: 'info',
};

export default function EnrichmentPanel() {
  const { enqueueSnackbar } = useSnackbar();
  const [iocValue, setIocValue] = useState('');
  const [iocType, setIocType] = useState('ip');
  const [loading, setLoading] = useState(false);
  const [results, setResults] = useState<EnrichmentResult[]>([]);
  const [overallVerdict, setOverallVerdict] = useState('');
  const [overallScore, setOverallScore] = useState(0);

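  // Assumed response shape for enrichment.ioc, inferred from the usage below:
  // { results: EnrichmentResult[], overall_verdict: string, overall_score: number }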
  const lookup = useCallback(async () => {
    if (!iocValue.trim()) return;
    setLoading(true);
    try {
      const r = await enrichment.ioc(iocValue.trim(), iocType);
      setResults(r.results);
      setOverallVerdict(r.overall_verdict);
      setOverallScore(r.overall_score);
    } catch (e: any) {
      enqueueSnackbar(e.message, { variant: 'error' });
    }
    setLoading(false);
  }, [iocValue, iocType, enqueueSnackbar]);

  return (
    <Box>
      <Typography variant="h5" gutterBottom>IOC Enrichment</Typography>

      {/* Lookup form */}
      <Paper sx={{ p: 2, mb: 2 }}>
        <Stack direction="row" spacing={1.5} alignItems="center">
          <FormControl size="small" sx={{ minWidth: 140 }}>
            <InputLabel>Type</InputLabel>
            <Select label="Type" value={iocType} onChange={e => setIocType(e.target.value)}>
              {IOC_TYPES.map(t => <MenuItem key={t} value={t}>{t}</MenuItem>)}
            </Select>
          </FormControl>
          <TextField
            size="small" fullWidth label="IOC Value"
            placeholder="e.g. 1.2.3.4 or evil.com"
            value={iocValue}
            onChange={e => setIocValue(e.target.value)}
            onKeyDown={e => e.key === 'Enter' && lookup()}
          />
          <Button variant="contained" startIcon={loading ? <CircularProgress size={16} /> : <SearchIcon />}
            onClick={lookup} disabled={loading || !iocValue.trim()}>
            Lookup
          </Button>
        </Stack>
      </Paper>

      {/* Overall verdict */}
      {overallVerdict && (
        <Alert severity={overallVerdict === 'malicious' ? 'error' : overallVerdict === 'suspicious' ? 'warning' : 'info'} sx={{ mb: 2 }}>
          Overall verdict: <strong>{overallVerdict}</strong> — Score: {overallScore.toFixed(1)}
        </Alert>
      )}

      {/* Results table */}
      {results.length > 0 && (
        <TableContainer component={Paper}>
          <Table size="small">
            <TableHead>
              <TableRow>
                <TableCell>Source</TableCell>
                <TableCell>Verdict</TableCell>
                <TableCell>Score</TableCell>
                <TableCell>Country</TableCell>
                <TableCell>Org</TableCell>
                <TableCell>Tags</TableCell>
                <TableCell>Latency</TableCell>
              </TableRow>
            </TableHead>
            <TableBody>
              {results.map((r, i) => (
                <TableRow key={i}>
                  <TableCell>{r.source}</TableCell>
                  <TableCell>
                    <Chip label={r.verdict} size="small"
                      color={VERDICT_COLORS[r.verdict] || 'default'} variant="outlined" />
                  </TableCell>
                  <TableCell>{r.score.toFixed(1)}</TableCell>
                  <TableCell>{r.country || '—'}</TableCell>
                  <TableCell>{r.org || '—'}</TableCell>
                  <TableCell>
                    <Stack direction="row" spacing={0.5} flexWrap="wrap">
                      {r.tags.slice(0, 5).map((t, j) => <Chip key={j} label={t} size="small" />)}
                    </Stack>
                  </TableCell>
                  <TableCell>{r.latency_ms}ms</TableCell>
                </TableRow>
              ))}
            </TableBody>
          </Table>
        </TableContainer>
      )}
    </Box>
  );
}
223
frontend/src/components/FileUpload.tsx
Normal file
@@ -0,0 +1,223 @@
/**
 * FileUpload — multi-file drag-and-drop CSV upload with per-file progress bars.
 */

import React, { useState, useCallback, useRef } from 'react';
import {
  Box, Typography, Paper, Stack, Chip, LinearProgress,
  Select, MenuItem, FormControl, InputLabel, IconButton, Tooltip,
} from '@mui/material';
import CloudUploadIcon from '@mui/icons-material/CloudUpload';
import CheckCircleIcon from '@mui/icons-material/CheckCircle';
import ErrorIcon from '@mui/icons-material/Error';
import ClearIcon from '@mui/icons-material/Clear';
import { useSnackbar } from 'notistack';
import { datasets, hunts, type UploadResult, type Hunt } from '../api/client';

interface FileJob {
  file: File;
  status: 'queued' | 'uploading' | 'done' | 'error';
  progress: number; // 0–100
  result?: UploadResult;
  error?: string;
}

export default function FileUpload() {
  const { enqueueSnackbar } = useSnackbar();
  const [dragOver, setDragOver] = useState(false);
  const [jobs, setJobs] = useState<FileJob[]>([]);
  const [huntList, setHuntList] = useState<Hunt[]>([]);
  const [huntId, setHuntId] = useState('');
  const fileRef = useRef<HTMLInputElement>(null);
  const busyRef = useRef(false);

  React.useEffect(() => {
    hunts.list(0, 100).then(r => setHuntList(r.hunts)).catch(() => {});
  }, []);

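  // Uploads run one at a time rather than in parallel: concurrent large
  // uploads would contend for bandwidth and make every progress bar crawl.
  // busyRef is a ref (not state) so re-entry is blocked without a re-render.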
  // Process the queue sequentially
  const processQueue = useCallback(async (queue: FileJob[]) => {
    if (busyRef.current) return;
    busyRef.current = true;

    for (let i = 0; i < queue.length; i++) {
      if (queue[i].status !== 'queued') continue;

      // Mark uploading
      setJobs(prev => prev.map((j, idx) =>
        idx === i ? { ...j, status: 'uploading' as const, progress: 0 } : j
      ));

      try {
        const result = await datasets.uploadWithProgress(
          queue[i].file,
          huntId || undefined,
          (pct) => {
            setJobs(prev => prev.map((j, idx) =>
              idx === i ? { ...j, progress: pct } : j
            ));
          },
        );
        setJobs(prev => prev.map((j, idx) =>
          idx === i ? { ...j, status: 'done' as const, progress: 100, result } : j
        ));
        enqueueSnackbar(
          `${queue[i].file.name}: ${result.row_count} rows, ${result.columns.length} columns`,
          { variant: 'success' },
        );
      } catch (e: any) {
        setJobs(prev => prev.map((j, idx) =>
          idx === i ? { ...j, status: 'error' as const, error: e.message } : j
        ));
        enqueueSnackbar(`${queue[i].file.name}: ${e.message}`, { variant: 'error' });
      }
    }
    busyRef.current = false;
  }, [huntId, enqueueSnackbar]);

  const enqueueFiles = useCallback((files: FileList | File[]) => {
    const newJobs: FileJob[] = Array.from(files).map(file => ({
      file, status: 'queued' as const, progress: 0,
    }));
    setJobs(prev => {
      const merged = [...prev, ...newJobs];
      // kick off processing with the full merged list
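      // (deferred via setTimeout so the async work runs outside this state
      // updater, which must stay side-effect-free)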
      setTimeout(() => processQueue(merged), 0);
      return merged;
    });
  }, [processQueue]);

  const onDrop = useCallback((e: React.DragEvent) => {
    e.preventDefault(); setDragOver(false);
    if (e.dataTransfer.files.length > 0) enqueueFiles(e.dataTransfer.files);
  }, [enqueueFiles]);

  const onFileChange = useCallback((e: React.ChangeEvent<HTMLInputElement>) => {
    if (e.target.files && e.target.files.length > 0) {
      enqueueFiles(e.target.files);
      e.target.value = ''; // reset so same file can be re-selected
    }
  }, [enqueueFiles]);

  const clearCompleted = useCallback(() => {
    setJobs(prev => prev.filter(j => j.status === 'queued' || j.status === 'uploading'));
  }, []);

  const overallDone = jobs.filter(j => j.status === 'done').length;
  const overallErr = jobs.filter(j => j.status === 'error').length;
  const overallTotal = jobs.length;

  return (
    <Box>
      <Typography variant="h5" gutterBottom>Upload Datasets</Typography>

      {/* Hunt association */}
      <Paper sx={{ p: 2, mb: 2 }}>
        <FormControl size="small" sx={{ minWidth: 300 }}>
          <InputLabel>Associate with Hunt (optional)</InputLabel>
          <Select label="Associate with Hunt (optional)" value={huntId}
            onChange={e => setHuntId(e.target.value)}>
            <MenuItem value="">None</MenuItem>
            {huntList.map(h => <MenuItem key={h.id} value={h.id}>{h.name}</MenuItem>)}
          </Select>
        </FormControl>
      </Paper>

      {/* Drop zone */}
      <Paper
        sx={{
          p: 6, textAlign: 'center', cursor: 'pointer',
          border: '2px dashed',
          borderColor: dragOver ? 'primary.main' : 'divider',
          bgcolor: dragOver ? 'action.hover' : 'background.paper',
          transition: 'all 0.2s',
        }}
        onDragOver={e => { e.preventDefault(); setDragOver(true); }}
        onDragLeave={() => setDragOver(false)}
        onDrop={onDrop}
        onClick={() => fileRef.current?.click()}
      >
        <input ref={fileRef} type="file" accept=".csv,.tsv,.txt" hidden multiple onChange={onFileChange} />
        <CloudUploadIcon sx={{ fontSize: 64, color: 'text.secondary', mb: 1 }} />
        <Typography variant="h6" color="text.secondary">
          Drag & drop CSV / TSV files here
        </Typography>
        <Typography variant="body2" color="text.secondary">
          or click to browse — multiple files supported — max 100 MB each
        </Typography>
      </Paper>

      {/* Overall progress summary */}
      {overallTotal > 0 && (
        <Stack direction="row" alignItems="center" spacing={1} sx={{ mt: 2 }}>
          <Typography variant="body2" color="text.secondary">
            {overallDone + overallErr} / {overallTotal} files processed
            {overallErr > 0 && ` (${overallErr} failed)`}
          </Typography>
          <Box sx={{ flexGrow: 1 }} />
          {overallDone + overallErr === overallTotal && overallTotal > 0 && (
            <Tooltip title="Clear completed">
              <IconButton size="small" onClick={clearCompleted}><ClearIcon fontSize="small" /></IconButton>
            </Tooltip>
          )}
        </Stack>
      )}

      {/* Per-file progress list */}
      {jobs.map((job, i) => (
        <Paper key={`${job.file.name}-${i}`} sx={{ p: 2, mt: 1 }}>
          <Stack direction="row" alignItems="center" spacing={1.5}>
            {job.status === 'done' && <CheckCircleIcon color="success" fontSize="small" />}
            {job.status === 'error' && <ErrorIcon color="error" fontSize="small" />}
            {(job.status === 'queued' || job.status === 'uploading') && (
              <Box sx={{ width: 20, height: 20 }} />
            )}
            <Box sx={{ minWidth: 0, flexGrow: 1 }}>
              <Stack direction="row" alignItems="center" spacing={1}>
                <Typography variant="body2" noWrap sx={{ fontWeight: 600 }}>
                  {job.file.name}
                </Typography>
                <Typography variant="caption" color="text.secondary">
                  ({(job.file.size / 1024 / 1024).toFixed(1)} MB)
                </Typography>
                {job.status === 'queued' && (
                  <Chip label="Queued" size="small" variant="outlined" />
                )}
              </Stack>

              {/* Progress bar */}
              {job.status === 'uploading' && (
                <Box sx={{ display: 'flex', alignItems: 'center', gap: 1, mt: 0.5 }}>
                  <LinearProgress
                    variant="determinate" value={job.progress}
                    sx={{ flexGrow: 1, height: 8, borderRadius: 4 }}
                  />
                  <Typography variant="caption" sx={{ minWidth: 36 }}>
                    {job.progress}%
                  </Typography>
                </Box>
              )}

              {/* Error */}
              {job.status === 'error' && (
                <Typography variant="caption" color="error">{job.error}</Typography>
              )}

              {/* Success details */}
              {job.status === 'done' && job.result && (
                <Stack direction="row" spacing={0.5} flexWrap="wrap" sx={{ mt: 0.5 }}>
                  <Chip label={`${job.result.row_count} rows`} size="small" color="primary" />
                  <Chip label={`${job.result.columns.length} cols`} size="small" />
                  {Object.keys(job.result.ioc_columns).length > 0 && (
                    <Chip label={`${Object.keys(job.result.ioc_columns).length} IOC cols`}
                      size="small" color="warning" />
                  )}
                </Stack>
              )}
            </Box>
          </Stack>
        </Paper>
      ))}
    </Box>
  );
}
156
frontend/src/components/HuntManager.tsx
Normal file
@@ -0,0 +1,156 @@
/**
 * HuntManager — create, list, and manage hunts.
 */

import React, { useEffect, useState, useCallback } from 'react';
import {
  Box, Typography, Paper, Button, TextField, Dialog, DialogTitle,
  DialogContent, DialogActions, Chip, Stack, IconButton, Table,
  TableBody, TableCell, TableContainer, TableHead, TableRow,
  CircularProgress, Select, MenuItem, FormControl, InputLabel,
} from '@mui/material';
import AddIcon from '@mui/icons-material/Add';
import EditIcon from '@mui/icons-material/Edit';
import DeleteIcon from '@mui/icons-material/Delete';
import DownloadIcon from '@mui/icons-material/Download';
import { useSnackbar } from 'notistack';
import { hunts, reports, type Hunt } from '../api/client';

const STATUS_COLORS: Record<string, 'success' | 'default' | 'warning'> = {
  active: 'success', closed: 'default', archived: 'warning',
};

export default function HuntManager() {
  const { enqueueSnackbar } = useSnackbar();
  const [list, setList] = useState<Hunt[]>([]);
  const [total, setTotal] = useState(0);
  const [loading, setLoading] = useState(true);
  const [dlgOpen, setDlgOpen] = useState(false);
  const [editHunt, setEditHunt] = useState<Hunt | null>(null);
  const [form, setForm] = useState({ name: '', description: '', status: 'active' });

  const load = useCallback(async () => {
    setLoading(true);
    try {
      const r = await hunts.list(0, 100);
      setList(r.hunts); setTotal(r.total);
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
    setLoading(false);
  }, [enqueueSnackbar]);

  useEffect(() => { load(); }, [load]);

  const openCreate = () => { setEditHunt(null); setForm({ name: '', description: '', status: 'active' }); setDlgOpen(true); };
  const openEdit = (h: Hunt) => { setEditHunt(h); setForm({ name: h.name, description: h.description || '', status: h.status }); setDlgOpen(true); };

  const handleSave = async () => {
    try {
      if (editHunt) {
        await hunts.update(editHunt.id, form);
        enqueueSnackbar('Hunt updated', { variant: 'success' });
      } else {
        await hunts.create(form.name, form.description || undefined);
        enqueueSnackbar('Hunt created', { variant: 'success' });
      }
      setDlgOpen(false); load();
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  };

  const handleDelete = async (id: string) => {
    if (!window.confirm('Delete this hunt?')) return;
    try {
      await hunts.delete(id);
      enqueueSnackbar('Hunt deleted', { variant: 'info' }); load();
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  };

  const handleExport = async (id: string, fmt: 'json' | 'html' | 'csv') => {
    try {
      const data = fmt === 'json' ? JSON.stringify(await reports.json(id), null, 2)
        : fmt === 'html' ? await reports.html(id)
        : await reports.csv(id);
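      // Client-side download: wrap the report in a Blob, point a temporary
      // object URL at it, and click a detached <a>; revoking the URL
      // afterwards releases the memory.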
      const blob = new Blob([data], { type: fmt === 'json' ? 'application/json' : fmt === 'html' ? 'text/html' : 'text/csv' });
      const url = URL.createObjectURL(blob);
      const a = document.createElement('a'); a.href = url; a.download = `hunt_${id}.${fmt}`; a.click();
      URL.revokeObjectURL(url);
      enqueueSnackbar(`Report exported as ${fmt.toUpperCase()}`, { variant: 'success' });
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  };

  if (loading) return <Box sx={{ p: 4 }}><CircularProgress /></Box>;

  return (
    <Box>
      <Stack direction="row" alignItems="center" justifyContent="space-between" sx={{ mb: 2 }}>
        <Typography variant="h5">Hunts ({total})</Typography>
        <Button variant="contained" startIcon={<AddIcon />} onClick={openCreate}>New Hunt</Button>
      </Stack>

      <TableContainer component={Paper}>
        <Table size="small">
          <TableHead>
            <TableRow>
              <TableCell>Name</TableCell>
              <TableCell>Status</TableCell>
              <TableCell>Datasets</TableCell>
              <TableCell>Hypotheses</TableCell>
              <TableCell>Created</TableCell>
              <TableCell align="right">Actions</TableCell>
            </TableRow>
          </TableHead>
          <TableBody>
            {list.map(h => (
              <TableRow key={h.id} hover>
                <TableCell>
                  <Typography variant="body2" fontWeight={600}>{h.name}</Typography>
                  {h.description && <Typography variant="caption" color="text.secondary">{h.description}</Typography>}
                </TableCell>
                <TableCell>
                  <Chip label={h.status} size="small" color={STATUS_COLORS[h.status] || 'default'} variant="outlined" />
                </TableCell>
                <TableCell>{h.dataset_count}</TableCell>
                <TableCell>{h.hypothesis_count}</TableCell>
                <TableCell><Typography variant="caption">{new Date(h.created_at).toLocaleDateString()}</Typography></TableCell>
                <TableCell align="right">
                  <IconButton size="small" onClick={() => openEdit(h)}><EditIcon fontSize="small" /></IconButton>
                  <IconButton size="small" onClick={() => handleExport(h.id, 'html')} title="Export HTML"><DownloadIcon fontSize="small" /></IconButton>
                  <IconButton size="small" color="error" onClick={() => handleDelete(h.id)}><DeleteIcon fontSize="small" /></IconButton>
                </TableCell>
              </TableRow>
            ))}
            {list.length === 0 && (
              <TableRow><TableCell colSpan={6} align="center"><Typography color="text.secondary">No hunts yet</Typography></TableCell></TableRow>
            )}
          </TableBody>
        </Table>
      </TableContainer>

      {/* Create / Edit dialog */}
      <Dialog open={dlgOpen} onClose={() => setDlgOpen(false)} maxWidth="sm" fullWidth>
        <DialogTitle>{editHunt ? 'Edit Hunt' : 'New Hunt'}</DialogTitle>
        <DialogContent>
          <Stack spacing={2} sx={{ mt: 1 }}>
            <TextField label="Name" fullWidth value={form.name} onChange={e => setForm(f => ({ ...f, name: e.target.value }))} />
            <TextField label="Description" fullWidth multiline rows={3} value={form.description} onChange={e => setForm(f => ({ ...f, description: e.target.value }))} />
            {editHunt && (
              <FormControl fullWidth>
                <InputLabel>Status</InputLabel>
                <Select label="Status" value={form.status} onChange={e => setForm(f => ({ ...f, status: e.target.value }))}>
                  <MenuItem value="active">Active</MenuItem>
                  <MenuItem value="closed">Closed</MenuItem>
                  <MenuItem value="archived">Archived</MenuItem>
                </Select>
              </FormControl>
            )}
          </Stack>
        </DialogContent>
        <DialogActions>
          <Button onClick={() => setDlgOpen(false)}>Cancel</Button>
          <Button variant="contained" onClick={handleSave} disabled={!form.name.trim()}>
            {editHunt ? 'Save' : 'Create'}
          </Button>
        </DialogActions>
      </Dialog>
    </Box>
  );
}
202
frontend/src/components/HypothesisTracker.tsx
Normal file
@@ -0,0 +1,202 @@
/**
 * HypothesisTracker — create, track status, link MITRE techniques.
 */

import React, { useEffect, useState, useCallback } from 'react';
import {
  Box, Typography, Paper, Stack, Chip, Button, TextField,
  Select, MenuItem, FormControl, InputLabel, CircularProgress,
  IconButton, Dialog, DialogTitle, DialogContent, DialogActions,
  Card, CardContent, CardActions, Grid,
} from '@mui/material';
import AddIcon from '@mui/icons-material/Add';
import EditIcon from '@mui/icons-material/Edit';
import DeleteIcon from '@mui/icons-material/Delete';
import { useSnackbar } from 'notistack';
import { hypotheses, hunts, type HypothesisData, type Hunt } from '../api/client';

const STATUSES = ['draft', 'active', 'confirmed', 'rejected'];
const STATUS_COLORS: Record<string, 'default' | 'info' | 'success' | 'error'> = {
  draft: 'default', active: 'info', confirmed: 'success', rejected: 'error',
};

export default function HypothesisTracker() {
  const { enqueueSnackbar } = useSnackbar();
  const [list, setList] = useState<HypothesisData[]>([]);
  const [total, setTotal] = useState(0);
  const [loading, setLoading] = useState(true);
  const [huntList, setHuntList] = useState<Hunt[]>([]);
  const [filterHunt, setFilterHunt] = useState('');
  const [filterStatus, setFilterStatus] = useState('');
  const [dlgOpen, setDlgOpen] = useState(false);
  const [editItem, setEditItem] = useState<HypothesisData | null>(null);
  const [form, setForm] = useState({
    title: '', description: '', mitre_technique: '', status: 'draft',
    hunt_id: '', evidence_notes: '',
  });

  useEffect(() => {
    hunts.list(0, 100).then(r => setHuntList(r.hunts)).catch(() => {});
  }, []);

  const load = useCallback(async () => {
    setLoading(true);
    try {
      const r = await hypotheses.list({
        hunt_id: filterHunt || undefined,
        status: filterStatus || undefined,
        limit: 100,
      });
      setList(r.hypotheses); setTotal(r.total);
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
    setLoading(false);
  }, [filterHunt, filterStatus, enqueueSnackbar]);

  useEffect(() => { load(); }, [load]);

  const openCreate = () => {
    setEditItem(null);
    setForm({ title: '', description: '', mitre_technique: '', status: 'draft', hunt_id: '', evidence_notes: '' });
    setDlgOpen(true);
  };

  const openEdit = (h: HypothesisData) => {
    setEditItem(h);
    setForm({
      title: h.title, description: h.description || '', mitre_technique: h.mitre_technique || '',
      status: h.status, hunt_id: h.hunt_id || '', evidence_notes: h.evidence_notes || '',
    });
    setDlgOpen(true);
  };

  const handleSave = async () => {
    try {
      if (editItem) {
        await hypotheses.update(editItem.id, {
          title: form.title, description: form.description || undefined,
          mitre_technique: form.mitre_technique || undefined, status: form.status,
          evidence_notes: form.evidence_notes || undefined,
        });
        enqueueSnackbar('Hypothesis updated', { variant: 'success' });
      } else {
        await hypotheses.create({
          title: form.title, description: form.description || undefined,
          mitre_technique: form.mitre_technique || undefined,
          hunt_id: form.hunt_id || undefined, status: form.status,
        });
        enqueueSnackbar('Hypothesis created', { variant: 'success' });
      }
      setDlgOpen(false); load();
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  };

  const handleDelete = async (id: string) => {
    if (!window.confirm('Delete this hypothesis?')) return;
    try {
      await hypotheses.delete(id);
      enqueueSnackbar('Deleted', { variant: 'info' }); load();
    } catch (e: any) { enqueueSnackbar(e.message, { variant: 'error' }); }
  };

  if (loading) return <Box sx={{ p: 4 }}><CircularProgress /></Box>;

  return (
    <Box>
      <Stack direction="row" alignItems="center" justifyContent="space-between" sx={{ mb: 2 }}>
        <Typography variant="h5">Hypotheses ({total})</Typography>
        <Button variant="contained" startIcon={<AddIcon />} onClick={openCreate}>New Hypothesis</Button>
      </Stack>

      <Paper sx={{ p: 1.5, mb: 2 }}>
        <Stack direction="row" spacing={1.5}>
          <FormControl size="small" sx={{ minWidth: 150 }}>
            <InputLabel>Hunt</InputLabel>
            <Select label="Hunt" value={filterHunt} onChange={e => setFilterHunt(e.target.value)}>
              <MenuItem value="">All</MenuItem>
              {huntList.map(h => <MenuItem key={h.id} value={h.id}>{h.name}</MenuItem>)}
            </Select>
          </FormControl>
          <FormControl size="small" sx={{ minWidth: 120 }}>
            <InputLabel>Status</InputLabel>
            <Select label="Status" value={filterStatus} onChange={e => setFilterStatus(e.target.value)}>
              <MenuItem value="">All</MenuItem>
              {STATUSES.map(s => <MenuItem key={s} value={s}>{s}</MenuItem>)}
            </Select>
          </FormControl>
        </Stack>
      </Paper>

      <Grid container spacing={2}>
        {list.map(h => (
          <Grid size={{ xs: 12, sm: 6, md: 4 }} key={h.id}>
            <Card variant="outlined">
              <CardContent>
                <Stack direction="row" alignItems="center" spacing={1} sx={{ mb: 1 }}>
                  <Chip label={h.status} size="small" color={STATUS_COLORS[h.status] || 'default'} />
                  {h.mitre_technique && <Chip label={h.mitre_technique} size="small" variant="outlined" color="info" />}
                </Stack>
                <Typography variant="subtitle1" fontWeight={600}>{h.title}</Typography>
                {h.description && <Typography variant="body2" color="text.secondary" sx={{ mt: 0.5 }}>{h.description}</Typography>}
                {h.evidence_notes && (
                  <Typography variant="caption" color="text.secondary" sx={{ mt: 1, display: 'block' }}>
                    Evidence: {h.evidence_notes}
                  </Typography>
                )}
              </CardContent>
              <CardActions>
                <IconButton size="small" onClick={() => openEdit(h)}><EditIcon fontSize="small" /></IconButton>
                <IconButton size="small" color="error" onClick={() => handleDelete(h.id)}><DeleteIcon fontSize="small" /></IconButton>
              </CardActions>
            </Card>
          </Grid>
        ))}
        {list.length === 0 && (
          <Grid size={12}>
            <Paper sx={{ p: 4, textAlign: 'center' }}>
              <Typography color="text.secondary" gutterBottom>No hypotheses yet. Create one to track your investigation.</Typography>
              <Typography variant="body2" color="text.secondary">
                Hypotheses let you document what you think is happening (e.g. "Attacker used T1059.001 PowerShell
                to exfiltrate data"), link them to a hunt and MITRE ATT&CK technique, then update their status
                as evidence confirms or rejects them.
              </Typography>
            </Paper>
          </Grid>
        )}
      </Grid>

      {/* Dialog */}
      <Dialog open={dlgOpen} onClose={() => setDlgOpen(false)} maxWidth="sm" fullWidth>
        <DialogTitle>{editItem ? 'Edit Hypothesis' : 'New Hypothesis'}</DialogTitle>
        <DialogContent>
          <Stack spacing={2} sx={{ mt: 1 }}>
            <TextField label="Title" fullWidth value={form.title} onChange={e => setForm(f => ({ ...f, title: e.target.value }))} />
            <TextField label="Description" fullWidth multiline rows={3} value={form.description} onChange={e => setForm(f => ({ ...f, description: e.target.value }))} />
            <TextField label="MITRE Technique" fullWidth placeholder="e.g. T1059.001" value={form.mitre_technique} onChange={e => setForm(f => ({ ...f, mitre_technique: e.target.value }))} />
            <FormControl fullWidth>
              <InputLabel>Status</InputLabel>
              <Select label="Status" value={form.status} onChange={e => setForm(f => ({ ...f, status: e.target.value }))}>
                {STATUSES.map(s => <MenuItem key={s} value={s}>{s}</MenuItem>)}
              </Select>
            </FormControl>
            {!editItem && (
              <FormControl fullWidth>
                <InputLabel>Hunt</InputLabel>
                <Select label="Hunt" value={form.hunt_id} onChange={e => setForm(f => ({ ...f, hunt_id: e.target.value }))}>
                  <MenuItem value="">None</MenuItem>
                  {huntList.map(h => <MenuItem key={h.id} value={h.id}>{h.name}</MenuItem>)}
                </Select>
              </FormControl>
            )}
            <TextField label="Evidence Notes" fullWidth multiline rows={2} value={form.evidence_notes} onChange={e => setForm(f => ({ ...f, evidence_notes: e.target.value }))} />
          </Stack>
        </DialogContent>
        <DialogActions>
          <Button onClick={() => setDlgOpen(false)}>Cancel</Button>
          <Button variant="contained" onClick={handleSave} disabled={!form.title.trim()}>
            {editItem ? 'Save' : 'Create'}
          </Button>
        </DialogActions>
      </Dialog>
    </Box>
  );
}
777
frontend/src/components/NetworkMap.tsx
Normal file
@@ -0,0 +1,777 @@
/**
 * NetworkMap — interactive hunt-scoped force-directed network graph.
 *
 * • Select a hunt → loads only that hunt's datasets
 * • Nodes = unique IPs / hostnames / domains pulled from IOC columns
 * • Edges = "seen together in the same row" co-occurrence
 * • Click a node → popover showing hostname, IP, OS, dataset sources, connections
 * • Responsive canvas with ResizeObserver
 * • Zero extra npm dependencies
 */

import React, { useEffect, useState, useRef, useCallback, useMemo } from 'react';
import {
  Box, Typography, Paper, Stack, Alert, Chip, Button, TextField,
  LinearProgress, FormControl, InputLabel, Select, MenuItem,
  Popover, Divider, IconButton,
} from '@mui/material';
import RefreshIcon from '@mui/icons-material/Refresh';
import CloseIcon from '@mui/icons-material/Close';
import ZoomInIcon from '@mui/icons-material/ZoomIn';
import ZoomOutIcon from '@mui/icons-material/ZoomOut';
import CenterFocusStrongIcon from '@mui/icons-material/CenterFocusStrong';
import { datasets, hunts, type Hunt, type DatasetSummary } from '../api/client';

// ── Graph primitives ─────────────────────────────────────────────────

type NodeType = 'ip' | 'hostname' | 'domain' | 'url';

interface NodeMeta {
  hostnames: Set<string>;
  ips: Set<string>;
  os: Set<string>;
  datasets: Set<string>;
  type: NodeType;
}

interface GNode {
  id: string; label: string; x: number; y: number;
  vx: number; vy: number; radius: number; color: string; count: number;
  meta: { hostnames: string[]; ips: string[]; os: string[]; datasets: string[]; type: NodeType };
}
interface GEdge { source: string; target: string; weight: number }
interface Graph { nodes: GNode[]; edges: GEdge[] }

const TYPE_COLORS: Record<NodeType, string> = {
  ip: '#3b82f6', hostname: '#22c55e', domain: '#eab308', url: '#8b5cf6',
};

// ── Helpers: find context columns from dataset schema ────────────────

/** Best-effort detection of hostname, IP, and OS columns from raw column names + normalized mapping. */
function findContextColumns(ds: DatasetSummary) {
  const norm = ds.normalized_columns || {};
  const schema = ds.column_schema || {};
  const rawCols = Object.keys(schema).length > 0 ? Object.keys(schema) : Object.keys(norm);

  const hostCols: string[] = [];
  const ipCols: string[] = [];
  const osCols: string[] = [];

  for (const raw of rawCols) {
    const canonical = norm[raw] || '';
    const lower = raw.toLowerCase();
    // Hostname columns
    if (canonical === 'hostname' || /^(hostname|host|fqdn|computer_?name|system_?name|machinename)$/i.test(lower)) {
      hostCols.push(raw);
    }
    // IP columns
    if (['src_ip', 'dst_ip', 'ip_address'].includes(canonical) || /^(ip|ip_?address|src_?ip|dst_?ip|source_?ip|dest_?ip)$/i.test(lower)) {
      ipCols.push(raw);
    }
    // OS columns (best-effort — raw name scan + normalized canonical)
    if (canonical === 'os' || /^(os|operating_?system|os_?version|os_?name|platform|os_?type)$/i.test(lower)) {
      osCols.push(raw);
    }
  }
  return { hostCols, ipCols, osCols };
}

function cleanVal(v: any): string {
  const s = (v ?? '').toString().trim();
  return (s && s !== '-' && s !== '0.0.0.0' && s !== '::') ? s : '';
}

// ── Build graph with per-node metadata ───────────────────────────────

interface RowBatch {
  rows: Record<string, any>[];
  iocColumns: Record<string, any>;
  dsName: string;
  ds: DatasetSummary;
}

function buildGraph(allBatches: RowBatch[], canvasW: number, canvasH: number): Graph {
  const countMap = new Map<string, number>();
  const edgeMap = new Map<string, number>();
  const metaMap = new Map<string, NodeMeta>();

  const getOrCreateMeta = (id: string, type: NodeType): NodeMeta => {
    let m = metaMap.get(id);
    if (!m) { m = { hostnames: new Set(), ips: new Set(), os: new Set(), datasets: new Set(), type }; metaMap.set(id, m); }
    return m;
  };

  for (const { rows, iocColumns, dsName, ds } of allBatches) {
    // IOC columns that produce graph nodes
    const iocEntries = Object.entries(iocColumns).filter(([, t]) => {
      const typ = Array.isArray(t) ? t[0] : t;
      return typ === 'ip' || typ === 'hostname' || typ === 'domain' || typ === 'url';
    }).map(([col, t]) => {
      const typ = (Array.isArray(t) ? t[0] : t) as NodeType;
      return { col, typ };
    });

    if (iocEntries.length === 0) continue;

    // Context columns for enrichment
    const ctx = findContextColumns(ds);

    for (const row of rows) {
      // Collect IOC values for this row (nodes + edges)
      const vals: { v: string; typ: NodeType }[] = [];
      for (const { col, typ } of iocEntries) {
        const v = cleanVal(row[col]);
        if (v) vals.push({ v, typ });
      }
      const unique = [...new Map(vals.map(x => [x.v, x])).values()];

      // Count occurrences
      for (const { v } of unique) countMap.set(v, (countMap.get(v) ?? 0) + 1);

      // Create edges (co-occurrence)
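      // e.g. a row { src_ip: '10.0.0.5', dst_ip: '8.8.8.8', domain: 'evil.example' }
      // (hypothetical values) yields the three pairwise edges; endpoints are
      // sorted into the key so "a,b" and "b,a" accumulate into one weight.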
      for (let i = 0; i < unique.length; i++) {
        for (let j = i + 1; j < unique.length; j++) {
          const key = [unique[i].v, unique[j].v].sort().join('||');
          edgeMap.set(key, (edgeMap.get(key) ?? 0) + 1);
        }
      }

      // Extract context values from this row
      const rowHosts = ctx.hostCols.map(c => cleanVal(row[c])).filter(Boolean);
      const rowIps = ctx.ipCols.map(c => cleanVal(row[c])).filter(Boolean);
      const rowOs = ctx.osCols.map(c => cleanVal(row[c])).filter(Boolean);

      // Attach context to each node in this row
      for (const { v, typ } of unique) {
        const meta = getOrCreateMeta(v, typ);
        meta.datasets.add(dsName);
        for (const h of rowHosts) meta.hostnames.add(h);
        for (const ip of rowIps) meta.ips.add(ip);
        for (const o of rowOs) meta.os.add(o);
      }
    }
  }

  const nodes: GNode[] = [...countMap.entries()].map(([id, count]) => {
    const raw = metaMap.get(id);
    const type: NodeType = raw?.type || 'ip';
    return {
      id, label: id, count,
      x: canvasW / 2 + (Math.random() - 0.5) * canvasW * 0.75,
      y: canvasH / 2 + (Math.random() - 0.5) * canvasH * 0.65,
      vx: 0, vy: 0,
      radius: Math.max(5, Math.min(18, 4 + Math.sqrt(count) * 1.6)),
      color: TYPE_COLORS[type],
      meta: {
        hostnames: [...(raw?.hostnames ?? [])],
        ips: [...(raw?.ips ?? [])],
        os: [...(raw?.os ?? [])],
        datasets: [...(raw?.datasets ?? [])],
        type,
      },
    };
  });

  const edges: GEdge[] = [...edgeMap.entries()].map(([key, weight]) => {
    const [source, target] = key.split('||');
    return { source, target, weight };
  });

  return { nodes, edges };
}

// ── Force simulation ─────────────────────────────────────────────────

function simulate(graph: Graph, cx: number, cy: number, steps = 120) {
  const { nodes, edges } = graph;
  const nodeMap = new Map(nodes.map(n => [n.id, n]));
  const k = 80;
  const repulsion = 6000;
  const damping = 0.85;

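  // Three forces per step: O(n²) pairwise repulsion (falling off as 1/d²),
  // spring attraction along edges toward rest length k, and a weak pull
  // toward the canvas centre; damping bleeds off velocity each step so the
  // layout settles instead of oscillating.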
for (let step = 0; step < steps; step++) {
|
||||
for (let i = 0; i < nodes.length; i++) {
|
||||
for (let j = i + 1; j < nodes.length; j++) {
|
||||
const a = nodes[i], b = nodes[j];
|
||||
const dx = b.x - a.x, dy = b.y - a.y;
|
||||
const dist = Math.max(1, Math.sqrt(dx * dx + dy * dy));
|
||||
const force = repulsion / (dist * dist);
|
||||
const fx = (dx / dist) * force, fy = (dy / dist) * force;
|
||||
a.vx -= fx; a.vy -= fy;
|
||||
b.vx += fx; b.vy += fy;
|
||||
}
|
||||
}
|
||||
for (const e of edges) {
|
||||
const a = nodeMap.get(e.source), b = nodeMap.get(e.target);
|
||||
if (!a || !b) continue;
|
||||
const dx = b.x - a.x, dy = b.y - a.y;
|
||||
const dist = Math.max(1, Math.sqrt(dx * dx + dy * dy));
|
||||
const force = (dist - k) * 0.05;
|
||||
const fx = (dx / dist) * force, fy = (dy / dist) * force;
|
||||
a.vx += fx; a.vy += fy;
|
||||
b.vx -= fx; b.vy -= fy;
|
||||
}
|
||||
for (const n of nodes) {
|
||||
n.vx += (cx - n.x) * 0.001;
|
||||
n.vy += (cy - n.y) * 0.001;
|
||||
n.vx *= damping; n.vy *= damping;
|
||||
n.x += n.vx; n.y += n.vy;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Viewport (zoom / pan) ────────────────────────────────────────────
|
||||
|
||||
interface Viewport { x: number; y: number; scale: number }
|
||||
|
||||
const MIN_ZOOM = 0.1;
|
||||
const MAX_ZOOM = 8;
|
||||
|
||||
// ── Canvas renderer ──────────────────────────────────────────────────
|
||||
|
||||
function drawGraph(
|
||||
ctx: CanvasRenderingContext2D, graph: Graph,
|
||||
hovered: string | null, selected: string | null, search: string,
|
||||
vp: Viewport,
|
||||
) {
|
||||
const { nodes, edges } = graph;
|
||||
const nodeMap = new Map(nodes.map(n => [n.id, n]));
|
||||
const matchSet = new Set<string>();
|
||||
if (search) {
|
||||
const lc = search.toLowerCase();
|
||||
for (const n of nodes) if (n.label.toLowerCase().includes(lc)) matchSet.add(n.id);
|
||||
}
|
||||
|
||||
ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);
|
||||
ctx.save();
|
||||
ctx.translate(vp.x, vp.y);
|
||||
ctx.scale(vp.scale, vp.scale);
|
||||
|
||||
// Edges
|
||||
for (const e of edges) {
|
||||
const a = nodeMap.get(e.source), b = nodeMap.get(e.target);
|
||||
if (!a || !b) continue;
|
||||
const isActive = (hovered && (e.source === hovered || e.target === hovered))
|
||||
|| (selected && (e.source === selected || e.target === selected));
|
||||
ctx.beginPath();
|
||||
ctx.strokeStyle = isActive ? 'rgba(96,165,250,0.7)' : 'rgba(100,116,139,0.25)';
|
||||
ctx.lineWidth = Math.min(4, 0.5 + e.weight * 0.3) / vp.scale;
|
||||
ctx.moveTo(a.x, a.y); ctx.lineTo(b.x, b.y); ctx.stroke();
|
||||
}
|
||||
|
||||
// Nodes
|
||||
for (const n of nodes) {
|
||||
const highlighted = hovered === n.id || selected === n.id || (search && matchSet.has(n.id));
|
||||
ctx.beginPath();
|
||||
ctx.arc(n.x, n.y, n.radius, 0, Math.PI * 2);
|
||||
ctx.fillStyle = highlighted ? '#fff' : n.color;
|
||||
ctx.globalAlpha = (search && !matchSet.has(n.id)) ? 0.15 : 1;
|
||||
ctx.fill();
|
||||
ctx.globalAlpha = 1;
|
||||
if (highlighted) { ctx.strokeStyle = n.color; ctx.lineWidth = 2.5 / vp.scale; ctx.stroke(); }
|
||||
}
|
||||
|
||||
// Labels — show more labels when zoomed in
|
||||
const labelThreshold = Math.max(1, Math.round(3 / vp.scale));
|
||||
const fontSize = Math.max(8, Math.round(11 / vp.scale));
|
||||
ctx.font = `${fontSize}px Inter, sans-serif`;
|
||||
ctx.textAlign = 'center';
|
||||
for (const n of nodes) {
|
||||
const show = hovered === n.id || selected === n.id
|
||||
|| (search && matchSet.has(n.id)) || n.count >= labelThreshold;
|
||||
if (!show) continue;
|
||||
ctx.fillStyle = (search && !matchSet.has(n.id)) ? 'rgba(241,245,249,0.15)' : '#f1f5f9';
|
||||
ctx.fillText(n.label, n.x, n.y - n.radius - 5);
|
||||
}
|
||||
|
||||
ctx.restore();
|
||||
}
|
||||
|
||||
// ── Hit-test helper (viewport-aware) ─────────────────────────────────
|
||||
|
||||
function screenToWorld(
|
||||
canvas: HTMLCanvasElement, clientX: number, clientY: number, vp: Viewport,
|
||||
): { wx: number; wy: number } {
|
||||
const rect = canvas.getBoundingClientRect();
|
||||
const cssToCanvas_x = canvas.width / rect.width;
|
||||
const cssToCanvas_y = canvas.height / rect.height;
|
||||
const cx = (clientX - rect.left) * cssToCanvas_x;
|
||||
const cy = (clientY - rect.top) * cssToCanvas_y;
|
||||
return { wx: (cx - vp.x) / vp.scale, wy: (cy - vp.y) / vp.scale };
|
||||
}
|
||||
|
||||
function hitTest(
|
||||
graph: Graph, canvas: HTMLCanvasElement, clientX: number, clientY: number,
|
||||
vp: Viewport,
|
||||
): GNode | null {
|
||||
const { wx, wy } = screenToWorld(canvas, clientX, clientY, vp);
|
||||
for (const n of graph.nodes) {
|
||||
const dx = n.x - wx, dy = n.y - wy;
|
||||
if (dx * dx + dy * dy < (n.radius + 4) ** 2) return n;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// ── Component ────────────────────────────────────────────────────────
|
||||
|
||||
export default function NetworkMap() {
|
||||
// Hunt selector
|
||||
const [huntList, setHuntList] = useState<Hunt[]>([]);
|
||||
const [selectedHuntId, setSelectedHuntId] = useState('');
|
||||
|
||||
// Graph state
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [progress, setProgress] = useState('');
|
||||
const [error, setError] = useState('');
|
||||
const [graph, setGraph] = useState<Graph | null>(null);
|
||||
const [hovered, setHovered] = useState<string | null>(null);
|
||||
const [selectedNode, setSelectedNode] = useState<GNode | null>(null);
|
||||
const [search, setSearch] = useState('');
|
||||
const [dsCount, setDsCount] = useState(0);
|
||||
const [totalRows, setTotalRows] = useState(0);
|
||||
|
||||
// Node type filters
|
||||
const [visibleTypes, setVisibleTypes] = useState<Set<NodeType>>(
|
||||
new Set<NodeType>(['ip', 'hostname', 'domain', 'url']),
|
||||
);
|
||||
|
||||
// Canvas sizing
|
||||
const canvasRef = useRef<HTMLCanvasElement>(null);
|
||||
const wrapperRef = useRef<HTMLDivElement>(null);
|
||||
const [canvasSize, setCanvasSize] = useState({ w: 900, h: 600 });
|
||||
|
||||
// Viewport (zoom / pan)
|
||||
const vpRef = useRef<Viewport>({ x: 0, y: 0, scale: 1 });
|
||||
const [vpScale, setVpScale] = useState(1); // for UI display only
|
||||
const isPanning = useRef(false);
|
||||
const panStart = useRef({ x: 0, y: 0 });
|
||||
|
||||
// Popover anchor
|
||||
const [popoverAnchor, setPopoverAnchor] = useState<{ top: number; left: number } | null>(null);
|
||||
|
||||
// ── Load hunts on mount ────────────────────────────────────────────
|
||||
useEffect(() => {
|
||||
hunts.list(0, 200).then(r => setHuntList(r.hunts)).catch(() => {});
|
||||
}, []);
|
||||
|
||||
// ── Resize observer ────────────────────────────────────────────────
|
||||
useEffect(() => {
|
||||
const el = wrapperRef.current;
|
||||
if (!el) return;
|
||||
const ro = new ResizeObserver(entries => {
|
||||
for (const entry of entries) {
|
||||
const w = Math.round(entry.contentRect.width);
|
||||
if (w > 100) setCanvasSize({ w, h: Math.max(450, Math.round(w * 0.55)) });
|
||||
}
|
||||
});
|
||||
ro.observe(el);
|
||||
return () => ro.disconnect();
|
||||
}, []);
|
||||
|
||||
// ── Load graph for selected hunt ──────────────────────────────────
|
||||
const loadGraph = useCallback(async (huntId: string) => {
|
||||
if (!huntId) return;
|
||||
setLoading(true); setError(''); setGraph(null);
|
||||
setSelectedNode(null); setPopoverAnchor(null);
|
||||
try {
|
||||
setProgress('Fetching datasets for hunt…');
|
||||
const dsRes = await datasets.list(0, 500, huntId);
|
||||
const dsList = dsRes.datasets;
|
||||
setDsCount(dsList.length);
|
||||
|
||||
if (dsList.length === 0) {
|
||||
setError('This hunt has no datasets. Upload CSV files to this hunt first.');
|
||||
setLoading(false); setProgress('');
|
||||
return;
|
||||
}
|
||||
|
||||
const allBatches: RowBatch[] = [];
|
||||
let rowTotal = 0;
|
||||
|
||||
for (let i = 0; i < dsList.length; i++) {
|
||||
const ds = dsList[i];
|
||||
setProgress(`Loading ${ds.name} (${i + 1}/${dsList.length})…`);
|
||||
try {
|
||||
const detail = await datasets.get(ds.id);
|
||||
const ioc = detail.ioc_columns || {};
|
||||
const hasIoc = Object.values(ioc).some(t => {
|
||||
const typ = Array.isArray(t) ? t[0] : t;
|
||||
return typ === 'ip' || typ === 'hostname' || typ === 'domain' || typ === 'url';
|
||||
});
|
||||
if (hasIoc) {
|
||||
const r = await datasets.rows(ds.id, 0, 5000);
|
||||
allBatches.push({ rows: r.rows, iocColumns: ioc, dsName: ds.name, ds: detail });
|
||||
rowTotal += r.rows.length;
|
||||
}
|
||||
} catch { /* skip failed datasets */ }
|
||||
}
|
||||
|
||||
setTotalRows(rowTotal);
|
||||
|
||||
if (allBatches.length === 0) {
|
||||
setError('No datasets in this hunt contain IP/hostname/domain IOC columns.');
|
||||
setLoading(false); setProgress('');
|
||||
return;
|
||||
}
|
||||
|
||||
setProgress('Building graph…');
|
||||
      const g = buildGraph(allBatches, canvasSize.w, canvasSize.h);
      if (g.nodes.length === 0) {
        setError('No network nodes found in the data.');
      } else {
        simulate(g, canvasSize.w / 2, canvasSize.h / 2);
        setGraph(g);
      }
    } catch (e: any) { setError(e.message); }
    setLoading(false); setProgress('');
  }, [canvasSize]);

  // When hunt changes, load graph
  useEffect(() => {
    if (selectedHuntId) loadGraph(selectedHuntId);
  }, [selectedHuntId, loadGraph]);

  // Reset viewport when graph changes
  useEffect(() => {
    vpRef.current = { x: 0, y: 0, scale: 1 };
    setVpScale(1);
  }, [graph]);

  // Filtered graph — only visible node types + edges between them
  const filteredGraph = useMemo<Graph | null>(() => {
    if (!graph) return null;
    const nodes = graph.nodes.filter(n => visibleTypes.has(n.meta.type));
    const nodeIds = new Set(nodes.map(n => n.id));
    const edges = graph.edges.filter(e => nodeIds.has(e.source) && nodeIds.has(e.target));
    return { nodes, edges };
  }, [graph, visibleTypes]);

  // Toggle a node type filter
  const toggleType = useCallback((t: NodeType) => {
    setVisibleTypes(prev => {
      const next = new Set(prev);
      if (next.has(t)) {
        // Don't allow all to be hidden
        if (next.size > 1) next.delete(t);
      } else {
        next.add(t);
      }
      return next;
    });
  }, []);

  // Redraw helper — uses filteredGraph
  const redraw = useCallback(() => {
    if (!filteredGraph || !canvasRef.current) return;
    const ctx = canvasRef.current.getContext('2d');
    if (ctx) drawGraph(ctx, filteredGraph, hovered, selectedNode?.id ?? null, search, vpRef.current);
  }, [filteredGraph, hovered, selectedNode, search]);

  // Redraw on every render-affecting state change
  useEffect(() => { redraw(); }, [redraw]);

  // ── Mouse wheel → zoom ─────────────────────────────────────────────
  useEffect(() => {
    const canvas = canvasRef.current;
    if (!canvas) return;
    const onWheel = (e: WheelEvent) => {
      e.preventDefault();
      const vp = vpRef.current;
      const rect = canvas.getBoundingClientRect();
      const cssToCanvasX = canvas.width / rect.width;
      const cssToCanvasY = canvas.height / rect.height;
      // Mouse position in canvas pixel coords
      const mx = (e.clientX - rect.left) * cssToCanvasX;
      const my = (e.clientY - rect.top) * cssToCanvasY;

      const zoomFactor = e.deltaY < 0 ? 1.12 : 1 / 1.12;
      const newScale = Math.min(MAX_ZOOM, Math.max(MIN_ZOOM, vp.scale * zoomFactor));
      // Zoom toward cursor: adjust offset so world-point under cursor stays fixed
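      // Derivation: the world point under the cursor is w = (mx - vp.x) / vp.scale.
      // Requiring mx = vp.x' + w * newScale gives vp.x' = mx - (mx - vp.x) * (newScale / vp.scale).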
      vp.x = mx - (mx - vp.x) * (newScale / vp.scale);
      vp.y = my - (my - vp.y) * (newScale / vp.scale);
      vp.scale = newScale;
      setVpScale(newScale);
      // Immediate redraw (bypass React state for smoothness)
      const ctx = canvas.getContext('2d');
      if (ctx && filteredGraph) drawGraph(ctx, filteredGraph, hovered, selectedNode?.id ?? null, search, vp);
    };
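    // Attached manually with { passive: false }: React registers its synthetic
    // wheel listener passively, so e.preventDefault() in an onWheel prop could
    // not stop the page from scrolling while zooming.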
    canvas.addEventListener('wheel', onWheel, { passive: false });
    return () => canvas.removeEventListener('wheel', onWheel);
  }, [filteredGraph, hovered, selectedNode, search]);

  // ── Mouse drag → pan ───────────────────────────────────────────────
  const onMouseDown = useCallback((e: React.MouseEvent<HTMLCanvasElement>) => {
    if (!filteredGraph || !canvasRef.current) return;
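    // hitTest (helper defined earlier in this file) presumably maps the client
    // coords back to world space via the inverse viewport transform before
    // radius-testing each node.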
    const node = hitTest(filteredGraph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
    if (!node) {
      isPanning.current = true;
      panStart.current = { x: e.clientX, y: e.clientY };
    }
  }, [filteredGraph]);

  const onMouseMove = useCallback((e: React.MouseEvent<HTMLCanvasElement>) => {
    if (!filteredGraph || !canvasRef.current) return;
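    // Pan deltas arrive in CSS px; convert them to canvas px because the
    // canvas is CSS-stretched to 100% width over a fixed-size pixel buffer.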
    if (isPanning.current) {
      const vp = vpRef.current;
      const rect = canvasRef.current.getBoundingClientRect();
      const cssToCanvasX = canvasRef.current.width / rect.width;
      const cssToCanvasY = canvasRef.current.height / rect.height;
      vp.x += (e.clientX - panStart.current.x) * cssToCanvasX;
      vp.y += (e.clientY - panStart.current.y) * cssToCanvasY;
      panStart.current = { x: e.clientX, y: e.clientY };
      redraw();
      return;
    }

    const node = hitTest(filteredGraph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
    setHovered(node?.id ?? null);
  }, [filteredGraph, redraw]);

  const onMouseUp = useCallback(() => {
    isPanning.current = false;
  }, []);

  // ── Mouse click → select node + show popover ─────────────────────
  const onClick = useCallback((e: React.MouseEvent<HTMLCanvasElement>) => {
    if (!filteredGraph || !canvasRef.current) return;
    const node = hitTest(filteredGraph, canvasRef.current, e.clientX, e.clientY, vpRef.current);
    if (node) {
      setSelectedNode(node);
      setPopoverAnchor({ top: e.clientY, left: e.clientX });
    } else {
      setSelectedNode(null);
      setPopoverAnchor(null);
    }
  }, [filteredGraph]);

  const closePopover = () => { setSelectedNode(null); setPopoverAnchor(null); };

  // ── Zoom controls ──────────────────────────────────────────────────
  const zoomBy = useCallback((factor: number) => {
    const vp = vpRef.current;
    const cw = canvasSize.w, ch = canvasSize.h;
    const newScale = Math.min(MAX_ZOOM, Math.max(MIN_ZOOM, vp.scale * factor));
    // Zoom toward canvas center
    vp.x = cw / 2 - (cw / 2 - vp.x) * (newScale / vp.scale);
    vp.y = ch / 2 - (ch / 2 - vp.y) * (newScale / vp.scale);
    vp.scale = newScale;
    setVpScale(newScale);
    redraw();
  }, [canvasSize, redraw]);

  const resetView = useCallback(() => {
    vpRef.current = { x: 0, y: 0, scale: 1 };
    setVpScale(1);
    redraw();
  }, [redraw]);

  // Count connections for selected node
  const connectionCount = selectedNode && filteredGraph
    ? filteredGraph.edges.filter(e => e.source === selectedNode.id || e.target === selectedNode.id).length
    : 0;

  // ── Render ─────────────────────────────────────────────────────────
  return (
    <Box>
      {/* Header row */}
      <Stack direction="row" alignItems="center" justifyContent="space-between" sx={{ mb: 2 }} flexWrap="wrap" gap={1}>
        <Typography variant="h5">Network Map</Typography>
        <Stack direction="row" spacing={1} alignItems="center" flexWrap="wrap">
          <FormControl size="small" sx={{ minWidth: 220 }}>
            <InputLabel id="hunt-selector-label">Hunt</InputLabel>
            <Select
              labelId="hunt-selector-label"
              value={selectedHuntId}
              label="Hunt"
              onChange={e => setSelectedHuntId(e.target.value)}
            >
              {huntList.map(h => (
                <MenuItem key={h.id} value={h.id}>
                  {h.name} ({h.dataset_count} datasets)
                </MenuItem>
              ))}
            </Select>
          </FormControl>
          <TextField size="small" placeholder="Search node…" value={search}
            onChange={e => setSearch(e.target.value)} sx={{ width: 200 }} />
          <Button variant="outlined" startIcon={<RefreshIcon />}
            onClick={() => loadGraph(selectedHuntId)}
            disabled={loading || !selectedHuntId} size="small">
            Refresh
          </Button>
        </Stack>
      </Stack>

      {/* Loading indicator */}
      {loading && (
        <Paper sx={{ p: 2, mb: 2 }}>
          <Typography variant="body2" color="text.secondary" gutterBottom>{progress}</Typography>
          <LinearProgress />
        </Paper>
      )}

      {error && <Alert severity="warning" sx={{ mb: 2 }}>{error}</Alert>}

      {/* Legend — clickable type filters */}
      {graph && filteredGraph && (
        <Stack direction="row" spacing={1} sx={{ mb: 1 }} flexWrap="wrap" gap={0.5} alignItems="center">
          <Chip label={`${dsCount} datasets`} size="small" variant="outlined" />
          <Chip label={`${totalRows.toLocaleString()} rows`} size="small" variant="outlined" />
          <Chip label={`${filteredGraph.nodes.length} nodes`} size="small" color="primary" variant="outlined" />
          <Chip label={`${filteredGraph.edges.length} edges`} size="small" color="secondary" variant="outlined" />
          <Divider orientation="vertical" flexItem />
          {([['ip', 'IP'], ['hostname', 'Host'], ['domain', 'Domain'], ['url', 'URL']] as [NodeType, string][]).map(([type, label]) => {
            const active = visibleTypes.has(type);
            const count = graph.nodes.filter(n => n.meta.type === type).length;
            return (
              <Chip
                key={type}
                label={`${label} (${count})`}
                size="small"
                onClick={() => toggleType(type)}
                sx={{
                  bgcolor: active ? TYPE_COLORS[type] : 'transparent',
                  color: active ? '#fff' : TYPE_COLORS[type],
                  border: `2px solid ${TYPE_COLORS[type]}`,
                  fontWeight: 600,
                  cursor: 'pointer',
                  opacity: active ? 1 : 0.5,
                  transition: 'all 0.15s ease',
                  '&:hover': { opacity: 1 },
                }}
              />
            );
          })}
        </Stack>
      )}
      {/* Canvas */}
      {filteredGraph && (
        <Paper ref={wrapperRef} sx={{ p: 1, position: 'relative', backgroundColor: '#0f172a' }}>
          <canvas
            ref={canvasRef}
            width={canvasSize.w} height={canvasSize.h}
            style={{
              width: '100%', height: canvasSize.h,
              cursor: isPanning.current ? 'grabbing' : hovered ? 'pointer' : 'grab',
            }}
            onMouseDown={onMouseDown}
            onMouseMove={onMouseMove}
            onMouseUp={onMouseUp}
            onMouseLeave={() => { isPanning.current = false; setHovered(null); }}
            onClick={onClick}
          />
          {/* Zoom controls overlay */}
          <Stack
            direction="column" spacing={0.5}
            sx={{ position: 'absolute', top: 12, right: 12, zIndex: 2 }}
          >
            <IconButton size="small" onClick={() => zoomBy(1.3)}
              sx={{ bgcolor: 'rgba(30,41,59,0.85)', color: '#f1f5f9', '&:hover': { bgcolor: 'rgba(51,65,85,0.95)' } }}
              aria-label="Zoom in"><ZoomInIcon fontSize="small" /></IconButton>
            <IconButton size="small" onClick={() => zoomBy(1 / 1.3)}
              sx={{ bgcolor: 'rgba(30,41,59,0.85)', color: '#f1f5f9', '&:hover': { bgcolor: 'rgba(51,65,85,0.95)' } }}
              aria-label="Zoom out"><ZoomOutIcon fontSize="small" /></IconButton>
            <IconButton size="small" onClick={resetView}
              sx={{ bgcolor: 'rgba(30,41,59,0.85)', color: '#f1f5f9', '&:hover': { bgcolor: 'rgba(51,65,85,0.95)' } }}
              aria-label="Reset view"><CenterFocusStrongIcon fontSize="small" /></IconButton>
            <Chip label={`${Math.round(vpScale * 100)}%`} size="small"
              sx={{ bgcolor: 'rgba(30,41,59,0.85)', color: '#94a3b8', fontSize: 11, height: 22 }} />
          </Stack>
        </Paper>
      )}
      {/* Node detail popover */}
      <Popover
        open={Boolean(selectedNode && popoverAnchor)}
        anchorReference="anchorPosition"
        anchorPosition={popoverAnchor ?? undefined}
        onClose={closePopover}
        transformOrigin={{ vertical: 'top', horizontal: 'left' }}
        slotProps={{ paper: { sx: { p: 2, minWidth: 280, maxWidth: 400 } } }}
      >
        {selectedNode && (
          <Box>
            <Stack direction="row" alignItems="center" justifyContent="space-between" sx={{ mb: 1 }}>
              <Stack direction="row" alignItems="center" spacing={1}>
                <Typography variant="subtitle1" fontWeight={700}>{selectedNode.label}</Typography>
                <Chip label={selectedNode.meta.type.toUpperCase()} size="small"
                  sx={{ bgcolor: TYPE_COLORS[selectedNode.meta.type], color: '#fff', fontWeight: 600, fontSize: 11 }} />
              </Stack>
              <IconButton size="small" onClick={closePopover} aria-label="close"><CloseIcon fontSize="small" /></IconButton>
            </Stack>
            <Divider sx={{ mb: 1.5 }} />

            {/* Hostnames */}
            <Typography variant="caption" color="text.secondary" fontWeight={600}>Hostname</Typography>
            <Typography variant="body2" sx={{ mb: 1 }}>
              {selectedNode.meta.hostnames.length > 0
                ? selectedNode.meta.hostnames.join(', ')
                : <em>Unknown</em>}
            </Typography>

            {/* IPs */}
            <Typography variant="caption" color="text.secondary" fontWeight={600}>IP Address</Typography>
            <Typography variant="body2" sx={{ mb: 1, fontFamily: 'monospace' }}>
              {selectedNode.meta.ips.length > 0
                ? selectedNode.meta.ips.join(', ')
                : (selectedNode.meta.type === 'ip' ? selectedNode.label : <em>Unknown</em>)}
            </Typography>

            {/* OS */}
            <Typography variant="caption" color="text.secondary" fontWeight={600}>Operating System</Typography>
            <Typography variant="body2" sx={{ mb: 1 }}>
              {selectedNode.meta.os.length > 0
                ? selectedNode.meta.os.join(', ')
                : <em>Unknown</em>}
            </Typography>

            <Divider sx={{ my: 1 }} />

            {/* Stats */}
            <Stack direction="row" spacing={1} flexWrap="wrap" gap={0.5}>
              <Chip label={`${selectedNode.count} occurrences`} size="small" variant="outlined" />
              <Chip label={`${connectionCount} connections`} size="small" variant="outlined" />
            </Stack>

            {/* Datasets */}
            {selectedNode.meta.datasets.length > 0 && (
              <Box sx={{ mt: 1.5 }}>
                <Typography variant="caption" color="text.secondary" fontWeight={600}>Seen in datasets</Typography>
                <Stack direction="row" spacing={0.5} flexWrap="wrap" gap={0.5} sx={{ mt: 0.5 }}>
                  {selectedNode.meta.datasets.map(d => (
                    <Chip key={d} label={d} size="small" variant="outlined" />
                  ))}
                </Stack>
              </Box>
            )}
          </Box>
        )}
      </Popover>

      {/* Empty states */}
      {!selectedHuntId && !loading && (
        <Paper ref={wrapperRef} sx={{ p: 6, textAlign: 'center' }}>
          <Typography variant="h6" color="text.secondary" gutterBottom>
            Select a hunt to visualize its network
          </Typography>
          <Typography variant="body2" color="text.secondary">
            Choose a hunt from the dropdown above. The map will display IP addresses,
            hostnames, and domains found across the hunt's datasets, with connections
            showing co-occurrence in the same log rows.
          </Typography>
        </Paper>
      )}

      {selectedHuntId && !graph && !loading && !error && (
        <Paper sx={{ p: 6, textAlign: 'center' }}>
          <Typography color="text.secondary">
            No network data to display. Upload datasets with IP/hostname columns to this hunt.
          </Typography>
        </Paper>
      )}
    </Box>
  );
}
86  frontend/src/theme.ts  (new file)
@@ -0,0 +1,86 @@
import { createTheme } from '@mui/material/styles';

const theme = createTheme({
  palette: {
    mode: 'dark',
    primary: {
      main: '#60a5fa', // blue-400
      light: '#93c5fd',
      dark: '#2563eb',
    },
    secondary: {
      main: '#f472b6', // pink-400
      light: '#f9a8d4',
      dark: '#db2777',
    },
    error: {
      main: '#ef4444',
    },
    warning: {
      main: '#f59e0b',
    },
    success: {
      main: '#10b981',
    },
    info: {
      main: '#06b6d4',
    },
    background: {
      default: '#0f172a', // slate-900
      paper: '#1e293b', // slate-800
    },
    text: {
      primary: '#f1f5f9', // slate-100
      secondary: '#94a3b8', // slate-400
    },
    divider: '#334155', // slate-700
  },
  typography: {
    fontFamily: '"Inter", "Roboto", "Helvetica Neue", Arial, sans-serif',
    h4: { fontWeight: 700 },
    h5: { fontWeight: 600 },
    h6: { fontWeight: 600 },
  },
  shape: {
    borderRadius: 8,
  },
  components: {
    MuiPaper: {
      defaultProps: { elevation: 0 },
      styleOverrides: {
        root: {
          backgroundImage: 'none',
          border: '1px solid',
          borderColor: '#334155',
        },
      },
    },
    MuiButton: {
      defaultProps: { disableElevation: true },
      styleOverrides: {
        root: { textTransform: 'none', fontWeight: 600 },
      },
    },
    MuiChip: {
      styleOverrides: {
        root: { fontWeight: 500 },
      },
    },
    MuiDrawer: {
      styleOverrides: {
        paper: { borderRight: '1px solid #334155' },
      },
    },
    MuiAppBar: {
      styleOverrides: {
        root: {
          backgroundImage: 'none',
          backgroundColor: '#1e293b',
          borderBottom: '1px solid #334155',
        },
      },
    },
  },
});

export default theme;
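// Usage sketch (not part of this commit; assumes a typical MUI entry point):
// the theme is presumably applied at the app root, e.g.
//
//   import { ThemeProvider, CssBaseline } from '@mui/material';
//   import theme from './theme';
//
//   <ThemeProvider theme={theme}>
//     <CssBaseline />
//     <App />
//   </ThemeProvider>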