mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 14:00:20 -05:00
feat: interactive network map, IOC highlighting, AUP hunt selector, type filters
- NetworkMap: hunt-scoped force-directed graph with click-to-inspect popover - NetworkMap: zoom/pan (wheel, drag, buttons), viewport transform - NetworkMap: clickable IP/Host/Domain/URL legend chips to filter node types - NetworkMap: brighter colors, 20% smaller nodes - DatasetViewer: IOC columns highlighted with colored headers + cell tinting - AUPScanner: hunt dropdown replacing dataset checkboxes, auto-select all - Rename 'Social Media (Personal)' theme to 'Social Media' with DB migration - Fix /api/hunts timeout: Dataset.rows lazy='noload' (was selectin cascade) - Add OS column mapping to normalizer - Full backend services, DB models, alembic migrations, new routes - New components: Dashboard, HuntManager, FileUpload, NetworkMap, etc. - Docker Compose deployment with nginx reverse proxy
This commit is contained in:
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests package
|
||||
108
backend/tests/conftest.py
Normal file
108
backend/tests/conftest.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Shared pytest fixtures for ThreatHunt tests.
|
||||
|
||||
Provides:
|
||||
- Async test database (in-memory SQLite)
|
||||
- Test client (httpx AsyncClient on the FastAPI app)
|
||||
- Factory functions for creating test hunts, datasets, etc.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
# Force test database settings *before* the app modules are imported below —
# presumably app.db.engine reads TH_DATABASE_URL at import time (verify), so
# these overrides must already be in place.
os.environ["TH_DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
os.environ["TH_JWT_SECRET"] = "test-secret-key-for-tests"

# Intentionally imported *after* the environment overrides above, not at the
# top of the file, so the app binds to the in-memory test database.
from app.db.engine import Base, get_db
from app.main import app
|
||||
|
||||
|
||||
# ── Database fixtures ─────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(scope="session")
def event_loop():
    """Provide a session-wide event loop for pytest-asyncio.

    NOTE: this must be a plain ``pytest`` fixture. ``pytest_asyncio.fixture``
    is intended for *async* fixtures; pytest-asyncio resolves its
    ``event_loop`` override as a regular synchronous fixture, so decorating
    this sync generator with ``pytest_asyncio.fixture`` is incorrect.
    """
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
async def test_engine():
    """Create the session-wide in-memory test database engine.

    Uses ``StaticPool`` so every connection handed out by the pool is the
    *same* underlying SQLite connection: with a plain ``:memory:`` URL each
    pooled connection otherwise gets its own empty database, and the tables
    created below would be invisible to the test sessions.
    """
    from sqlalchemy.pool import StaticPool  # local import: only needed here

    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        echo=False,
        poolclass=StaticPool,
    )
    # Create the full schema once for the whole test session.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
    """Yield a fresh AsyncSession per test; pending changes are rolled back."""
    factory = sessionmaker(
        test_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with factory() as session:
        yield session
        # Undo anything the test left uncommitted.
        await session.rollback()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def client(db_session) -> AsyncGenerator[AsyncClient, None]:
    """Create an async test client with the DB dependency overridden.

    The override is cleared in a ``finally`` block so a failing test can
    never leak the test-database override into subsequent tests (the
    original cleared it after the ``async with`` and skipped cleanup on
    error).
    """

    async def _override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = _override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
    finally:
        app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ── Factory helpers ───────────────────────────────────────────────────
|
||||
|
||||
def make_csv_bytes(
    columns: list[str],
    rows: list[list[str]],
    delimiter: str = ",",
) -> bytes:
    """Build CSV content (header line + data rows) as UTF-8 bytes.

    Cells are passed through ``str`` so non-string values are accepted.
    """
    header = delimiter.join(columns)
    body = [delimiter.join(str(cell) for cell in row) for row in rows]
    return "\n".join([header, *body]).encode("utf-8")
|
||||
|
||||
|
||||
# Five rows of synthetic endpoint telemetry across three hosts; referenced by
# the upload, row-retrieval, and keyword-scan tests (rows contain
# "chrome.exe" and "powershell", which the scanner tests match on).
SAMPLE_CSV = make_csv_bytes(
    ["timestamp", "hostname", "src_ip", "dst_ip", "process_name", "command_line"],
    [
        ["2025-01-15T10:30:00Z", "DESKTOP-ABC", "192.168.1.100", "10.0.0.50", "cmd.exe", "cmd /c whoami"],
        ["2025-01-15T10:31:00Z", "DESKTOP-ABC", "192.168.1.100", "10.0.0.51", "powershell.exe", "powershell -enc SGVsbG8="],
        ["2025-01-15T10:32:00Z", "DESKTOP-XYZ", "192.168.1.101", "8.8.8.8", "chrome.exe", "chrome.exe --no-sandbox"],
        ["2025-01-15T10:33:00Z", "DESKTOP-ABC", "192.168.1.100", "203.0.113.5", "svchost.exe", "svchost.exe -k netsvcs"],
        ["2025-01-15T10:34:00Z", "SERVER-DC01", "10.0.0.1", "10.0.0.50", "lsass.exe", "lsass.exe"],
    ],
)
|
||||
|
||||
# Two file-hash records with valid-length MD5 / SHA-256 hex digests; used by
# the hash-column parsing and type-inference tests.
SAMPLE_HASH_CSV = make_csv_bytes(
    ["filename", "md5", "sha256", "size"],
    [
        ["malware.exe", "d41d8cd98f00b204e9800998ecf8427e", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", "1024"],
        ["benign.dll", "098f6bcd4621d373cade4e832627b4f6", "5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8", "2048"],
    ],
)
|
||||
117
backend/tests/test_agents.py
Normal file
117
backend/tests/test_agents.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Tests for model registry and task router."""
|
||||
|
||||
import pytest
|
||||
from app.agents.registry import (
|
||||
ModelRegistry, ModelEntry, Capability, Tier, Node,
|
||||
registry, ROADRUNNER_MODELS, WILE_MODELS,
|
||||
)
|
||||
from app.agents.router import TaskRouter, TaskType, task_router
|
||||
|
||||
|
||||
class TestModelRegistry:
    """Exercise lookup and serialization on the global model registry."""

    def test_registry_has_models(self):
        # The registry and both per-node model lists are non-empty.
        assert registry.models
        assert ROADRUNNER_MODELS
        assert WILE_MODELS

    def test_find_by_capability(self):
        matches = registry.find(capability=Capability.CHAT)
        assert matches
        # Every returned entry must actually advertise the capability.
        assert all(Capability.CHAT in m.capabilities for m in matches)

    def test_find_code_models(self):
        assert registry.find(capability=Capability.CODE)

    def test_find_vision_models(self):
        assert registry.find(capability=Capability.VISION)

    def test_find_embedding_models(self):
        assert registry.find(capability=Capability.EMBEDDING)

    def test_find_by_node(self):
        # Both nodes host at least one model.
        assert registry.find(node=Node.WILE)
        assert registry.find(node=Node.ROADRUNNER)

    def test_find_heavy_models(self):
        heavy_entries = registry.find(tier=Tier.HEAVY)
        assert heavy_entries
        assert all(entry.tier == Tier.HEAVY for entry in heavy_entries)

    def test_get_best(self):
        entry = registry.get_best(Capability.CHAT, prefer_tier=Tier.FAST)
        assert entry is not None
        assert Capability.CHAT in entry.capabilities

    def test_get_best_vision_on_roadrunner(self):
        entry = registry.get_best(Capability.VISION, prefer_node=Node.ROADRUNNER)
        assert entry is not None
        assert Capability.VISION in entry.capabilities

    def test_to_dict(self):
        serialized = registry.to_dict()
        assert isinstance(serialized, list)
        assert serialized
        first = serialized[0]
        assert "name" in first
        assert "capabilities" in first
|
||||
|
||||
|
||||
class TestTaskRouter:
    """Exercise routing decisions and free-text task classification."""

    def test_route_quick_chat(self):
        outcome = task_router.route(TaskType.QUICK_CHAT)
        assert outcome.model
        assert outcome.node

    def test_route_deep_analysis(self):
        outcome = task_router.route(TaskType.DEEP_ANALYSIS)
        assert outcome.model
        # The decision echoes back the requested task type.
        assert outcome.task_type == TaskType.DEEP_ANALYSIS

    def test_route_code_analysis(self):
        outcome = task_router.route(TaskType.CODE_ANALYSIS)
        assert outcome.model
        lowered = outcome.model.lower()
        assert "coder" in lowered or "code" in lowered

    def test_route_vision(self):
        outcome = task_router.route(TaskType.VISION)
        assert outcome.model
        assert outcome.node == Node.ROADRUNNER

    def test_route_with_model_override(self):
        outcome = task_router.route(TaskType.QUICK_CHAT, model_override="llama3.1:latest")
        assert outcome.model == "llama3.1:latest"

    def test_route_unknown_model_to_cluster(self):
        # Unknown overrides fall through to the cluster / openwebui provider.
        outcome = task_router.route(TaskType.QUICK_CHAT, model_override="nonexistent-model:99b")
        assert outcome.node == Node.CLUSTER
        assert outcome.provider_type == "openwebui"

    def test_classify_code_task(self):
        for prompt in ("deobfuscate this powershell script", "decode this base64 payload"):
            assert task_router.classify_task(prompt) == TaskType.CODE_ANALYSIS

    def test_classify_deep_task(self):
        prompt = "detailed forensic analysis of this process tree"
        assert task_router.classify_task(prompt) == TaskType.DEEP_ANALYSIS

    def test_classify_vision_task(self):
        assert task_router.classify_task("analyze this screenshot", has_image=True) == TaskType.VISION

    def test_classify_quick_task(self):
        assert task_router.classify_task("what does this process do?") == TaskType.QUICK_CHAT

    def test_debate_model_overrides(self):
        debate_roles = (
            TaskType.DEBATE_PLANNER,
            TaskType.DEBATE_CRITIC,
            TaskType.DEBATE_PRAGMATIST,
            TaskType.DEBATE_JUDGE,
        )
        for role in debate_roles:
            outcome = task_router.route(role)
            assert outcome.model
            assert outcome.task_type == role
|
||||
189
backend/tests/test_api.py
Normal file
189
backend/tests/test_api.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Tests for API endpoints — datasets, hunts, annotations."""
|
||||
|
||||
import io
|
||||
import pytest
|
||||
from tests.conftest import SAMPLE_CSV
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestHealthEndpoints:
    """Smoke-test the root and OpenAPI schema endpoints."""

    async def test_root(self, client):
        res = await client.get("/")
        assert res.status_code == 200
        body = res.json()
        assert body["service"] == "ThreatHunt API"
        assert body["status"] == "running"

    async def test_openapi_docs(self, client):
        res = await client.get("/openapi.json")
        assert res.status_code == 200
        paths = res.json()["paths"]
        # Core routes must be registered on the app.
        for route in ("/api/agent/assist", "/api/datasets/upload", "/api/hunts"):
            assert route in paths
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestHuntEndpoints:
    """Test hunt CRUD operations."""

    async def test_create_hunt(self, client):
        res = await client.post(
            "/api/hunts",
            json={"name": "Test Hunt", "description": "Testing hunt creation"},
        )
        assert res.status_code == 200
        body = res.json()
        assert body["name"] == "Test Hunt"
        # New hunts start in the "active" state with a generated id.
        assert body["status"] == "active"
        assert body["id"]

    async def test_list_hunts(self, client):
        # Seed two hunts so the listing has something to count.
        for name in ("Hunt 1", "Hunt 2"):
            await client.post("/api/hunts", json={"name": name})

        res = await client.get("/api/hunts")
        assert res.status_code == 200
        assert res.json()["total"] >= 2

    async def test_get_hunt(self, client):
        created = await client.post("/api/hunts", json={"name": "Specific Hunt"})
        hunt_id = created.json()["id"]

        res = await client.get(f"/api/hunts/{hunt_id}")
        assert res.status_code == 200
        assert res.json()["name"] == "Specific Hunt"

    async def test_update_hunt(self, client):
        created = await client.post("/api/hunts", json={"name": "Original"})
        hunt_id = created.json()["id"]

        res = await client.put(
            f"/api/hunts/{hunt_id}",
            json={"name": "Updated", "status": "closed"},
        )
        assert res.status_code == 200
        body = res.json()
        assert body["name"] == "Updated"
        assert body["status"] == "closed"

    async def test_get_nonexistent_hunt(self, client):
        res = await client.get("/api/hunts/nonexistent-id")
        assert res.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestDatasetEndpoints:
    """Test dataset upload and retrieval."""

    @staticmethod
    def _csv_upload(filename="test.csv"):
        # Build a fresh multipart payload around SAMPLE_CSV for each request.
        return {"file": (filename, io.BytesIO(SAMPLE_CSV), "text/csv")}

    async def test_upload_csv(self, client):
        res = await client.post(
            "/api/datasets/upload",
            files=self._csv_upload(),
            params={"name": "Test Dataset"},
        )
        assert res.status_code == 200
        body = res.json()
        assert body["name"] == "Test Dataset"
        assert body["row_count"] == 5
        assert "timestamp" in body["columns"]

    async def test_upload_invalid_extension(self, client):
        bad = {"file": ("bad.exe", io.BytesIO(b"not csv"), "application/octet-stream")}
        res = await client.post("/api/datasets/upload", files=bad)
        assert res.status_code == 400

    async def test_upload_empty_file(self, client):
        empty = {"file": ("empty.csv", io.BytesIO(b""), "text/csv")}
        res = await client.post("/api/datasets/upload", files=empty)
        assert res.status_code == 400

    async def test_list_datasets(self, client):
        # Seed one dataset so the listing is non-empty.
        await client.post("/api/datasets/upload", files=self._csv_upload(), params={"name": "DS1"})

        res = await client.get("/api/datasets")
        assert res.status_code == 200
        assert res.json()["total"] >= 1

    async def test_get_dataset_rows(self, client):
        uploaded = await client.post(
            "/api/datasets/upload",
            files=self._csv_upload(),
            params={"name": "RowTest"},
        )
        ds_id = uploaded.json()["id"]

        res = await client.get(f"/api/datasets/{ds_id}/rows")
        assert res.status_code == 200
        body = res.json()
        assert body["total"] == 5
        assert len(body["rows"]) == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestAnnotationEndpoints:
    """Test annotation CRUD."""

    async def test_create_annotation(self, client):
        res = await client.post(
            "/api/annotations",
            json={"text": "Suspicious process detected", "severity": "high", "tag": "suspicious"},
        )
        assert res.status_code == 200
        body = res.json()
        assert body["text"] == "Suspicious process detected"
        assert body["severity"] == "high"

    async def test_list_annotations(self, client):
        # Seed two annotations with different severities.
        for text, severity in (("Ann 1", "info"), ("Ann 2", "critical")):
            await client.post("/api/annotations", json={"text": text, "severity": severity})

        res = await client.get("/api/annotations")
        assert res.status_code == 200
        assert res.json()["total"] >= 2

    async def test_filter_annotations_by_severity(self, client):
        await client.post("/api/annotations", json={"text": "Critical finding", "severity": "critical"})

        res = await client.get("/api/annotations", params={"severity": "critical"})
        assert res.status_code == 200
        # Every returned annotation must match the severity filter.
        assert all(a["severity"] == "critical" for a in res.json()["annotations"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestHypothesisEndpoints:
    """Test hypothesis CRUD."""

    async def test_create_hypothesis(self, client):
        res = await client.post(
            "/api/hypotheses",
            json={
                "title": "Living off the Land",
                "description": "Attacker using LOLBins for execution",
                "mitre_technique": "T1059",
                "status": "active",
            },
        )
        assert res.status_code == 200
        body = res.json()
        assert body["title"] == "Living off the Land"
        assert body["mitre_technique"] == "T1059"

    async def test_update_hypothesis_status(self, client):
        created = await client.post(
            "/api/hypotheses",
            json={"title": "Test Hyp", "status": "draft"},
        )
        hyp_id = created.json()["id"]

        res = await client.put(
            f"/api/hypotheses/{hyp_id}",
            json={"status": "confirmed", "evidence_notes": "Confirmed via process tree analysis"},
        )
        assert res.status_code == 200
        assert res.json()["status"] == "confirmed"
|
||||
104
backend/tests/test_csv_parser.py
Normal file
104
backend/tests/test_csv_parser.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Tests for CSV parser and normalizer services."""
|
||||
|
||||
import pytest
|
||||
from app.services.csv_parser import parse_csv_bytes, detect_encoding, detect_delimiter, infer_column_types
|
||||
from app.services.normalizer import normalize_columns, normalize_rows, detect_ioc_columns, detect_time_range
|
||||
from tests.conftest import SAMPLE_CSV, SAMPLE_HASH_CSV, make_csv_bytes
|
||||
|
||||
|
||||
class TestCSVParser:
    """Tests for CSV parsing."""

    def test_parse_csv_basic(self):
        parsed_rows, meta = parse_csv_bytes(SAMPLE_CSV)
        assert len(parsed_rows) == 5
        assert "timestamp" in meta["columns"]
        assert "hostname" in meta["columns"]
        assert meta["encoding"] is not None
        assert meta["delimiter"] == ","

    def test_parse_csv_columns(self):
        _, meta = parse_csv_bytes(SAMPLE_CSV)
        expected = ["timestamp", "hostname", "src_ip", "dst_ip", "process_name", "command_line"]
        assert meta["columns"] == expected

    def test_parse_csv_row_data(self):
        parsed_rows, _ = parse_csv_bytes(SAMPLE_CSV)
        first = parsed_rows[0]
        assert first["hostname"] == "DESKTOP-ABC"
        assert first["src_ip"] == "192.168.1.100"
        assert parsed_rows[2]["process_name"] == "chrome.exe"

    def test_parse_csv_hash_file(self):
        parsed_rows, meta = parse_csv_bytes(SAMPLE_HASH_CSV)
        assert len(parsed_rows) == 2
        assert "md5" in meta["columns"]
        assert "sha256" in meta["columns"]

    def test_parse_tsv(self):
        tsv = make_csv_bytes(
            ["host", "ip", "port"],
            [["server1", "10.0.0.1", "443"], ["server2", "10.0.0.2", "80"]],
            delimiter="\t",
        )
        parsed_rows, _ = parse_csv_bytes(tsv)
        assert len(parsed_rows) == 2

    def test_parse_empty_file(self):
        # Empty input must raise rather than return silently.
        with pytest.raises(Exception):
            parse_csv_bytes(b"")

    def test_detect_encoding_utf8(self):
        detected = detect_encoding(SAMPLE_CSV)
        assert detected is not None
        # Pure-ASCII content may be reported as either ascii or a utf variant.
        assert "ascii" in detected.lower() or "utf" in detected.lower()

    def test_infer_column_types(self):
        inferred = infer_column_types(["192.168.1.1", "10.0.0.1", "8.8.8.8"], "src_ip")
        assert inferred == "ip"

    def test_infer_column_types_hash(self):
        inferred = infer_column_types(["d41d8cd98f00b204e9800998ecf8427e"], "hash")
        assert inferred == "hash_md5"
|
||||
|
||||
|
||||
class TestNormalizer:
    """Tests for column normalization."""

    def test_normalize_columns(self):
        mapping = normalize_columns(["SourceAddr", "DestAddr", "ProcessName"])
        assert "SourceAddr" in mapping
        mapped = mapping.get("SourceAddr")
        # Should map to a canonical name (loose check: a known name, None, or
        # at minimum some string fallback).
        assert mapped in ("src_ip", "source_address", None) or isinstance(mapped, str)

    def test_normalize_known_columns(self):
        mapping = normalize_columns(["timestamp", "hostname", "src_ip"])
        # Already-canonical names must map to themselves.
        for column in ("timestamp", "hostname", "src_ip"):
            assert mapping.get(column) == column

    def test_detect_ioc_columns(self):
        _, meta = parse_csv_bytes(SAMPLE_CSV)
        mapping = normalize_columns(meta["columns"])
        iocs = detect_ioc_columns(meta["columns"], meta["column_types"], mapping)
        # IP columns exist in SAMPLE_CSV; at minimum the result is a dict.
        assert isinstance(iocs, dict)

    def test_detect_time_range(self):
        parsed_rows, meta = parse_csv_bytes(SAMPLE_CSV)
        mapping = normalize_columns(meta["columns"])
        start, _end = detect_time_range(parsed_rows, mapping)
        # When a range is detected it must come from the 2025 timestamps.
        if start:
            assert "2025" in start

    def test_normalize_rows(self):
        raw = [{"SourceAddr": "10.0.0.1", "ProcessName": "cmd.exe"}]
        mapping = {"SourceAddr": "src_ip", "ProcessName": "process_name"}
        result = normalize_rows(raw, mapping)
        assert len(result) == 1
        assert result[0].get("src_ip") == "10.0.0.1"
|
||||
199
backend/tests/test_keywords.py
Normal file
199
backend/tests/test_keywords.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Tests for AUP keyword themes, keyword CRUD, and scanner."""
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
# ── Theme CRUD ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_themes_empty(client: AsyncClient):
    """The themes listing always returns the envelope keys, seeded or not."""
    resp = await client.get("/api/keywords/themes")
    assert resp.status_code == 200
    payload = resp.json()
    assert "themes" in payload
    assert "total" in payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_theme(client: AsyncClient):
    """Creating a theme returns 201 with the submitted fields echoed back."""
    res = await client.post("/api/keywords/themes", json={
        "name": "Test Gambling", "color": "#f44336", "enabled": True,
    })
    assert res.status_code == 201
    data = res.json()
    assert data["name"] == "Test Gambling"
    assert data["color"] == "#f44336"
    assert data["enabled"] is True
    # A fresh theme starts with no keywords attached.
    assert data["keyword_count"] == 0
    # No `return` here: pytest emits PytestReturnNotNoneWarning for tests
    # that return a value, and nothing consumed the old `return data["id"]`.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_duplicate_theme(client: AsyncClient):
    """A second theme with the same name is rejected with 409 Conflict."""
    payload = {"name": "Dup Theme"}
    await client.post("/api/keywords/themes", json=payload)
    resp = await client.post("/api/keywords/themes", json=payload)
    assert resp.status_code == 409
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_theme(client: AsyncClient):
    """PUT updates name, color, and the enabled flag of an existing theme."""
    created = await client.post("/api/keywords/themes", json={"name": "Updatable"})
    theme_id = created.json()["id"]

    resp = await client.put(
        f"/api/keywords/themes/{theme_id}",
        json={"name": "Updated Name", "color": "#00ff00", "enabled": False},
    )
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["name"] == "Updated Name"
    assert payload["color"] == "#00ff00"
    assert payload["enabled"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_theme(client: AsyncClient):
    """DELETE removes the theme and it disappears from the listing."""
    created = await client.post("/api/keywords/themes", json={"name": "ToDelete"})
    theme_id = created.json()["id"]

    resp = await client.delete(f"/api/keywords/themes/{theme_id}")
    assert resp.status_code == 204

    # The deleted theme must no longer appear in the listing.
    listing = await client.get("/api/keywords/themes")
    assert all(t["name"] != "ToDelete" for t in listing.json()["themes"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_nonexistent_theme(client: AsyncClient):
    """Deleting an unknown theme id yields 404."""
    resp = await client.delete("/api/keywords/themes/nonexistent")
    assert resp.status_code == 404
|
||||
|
||||
|
||||
# ── Keyword CRUD ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_keyword(client: AsyncClient):
    """Adding a keyword returns 201 with the keyword bound to its theme."""
    created = await client.post("/api/keywords/themes", json={"name": "KW Test Theme"})
    theme_id = created.json()["id"]

    resp = await client.post(
        f"/api/keywords/themes/{theme_id}/keywords",
        json={"value": "poker", "is_regex": False},
    )
    assert resp.status_code == 201
    payload = resp.json()
    assert payload["value"] == "poker"
    assert payload["is_regex"] is False
    assert payload["theme_id"] == theme_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_keywords_bulk(client: AsyncClient):
    """Bulk-adding keywords reports the added count and updates keyword_count."""
    created = await client.post("/api/keywords/themes", json={"name": "Bulk KW Theme"})
    theme_id = created.json()["id"]

    resp = await client.post(
        f"/api/keywords/themes/{theme_id}/keywords/bulk",
        json={"values": ["steam", "epic games", "discord"]},
    )
    assert resp.status_code == 201
    payload = resp.json()
    assert payload["added"] == 3
    assert payload["theme_id"] == theme_id

    # Cross-check the count through the theme listing.
    listing = await client.get("/api/keywords/themes")
    matching = [t for t in listing.json()["themes"] if t["id"] == theme_id][0]
    assert matching["keyword_count"] == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_keyword(client: AsyncClient):
    """A keyword can be deleted by its own id."""
    created = await client.post("/api/keywords/themes", json={"name": "Del KW Theme"})
    theme_id = created.json()["id"]

    added = await client.post(
        f"/api/keywords/themes/{theme_id}/keywords", json={"value": "removeme"}
    )
    keyword_id = added.json()["id"]

    resp = await client.delete(f"/api/keywords/keywords/{keyword_id}")
    assert resp.status_code == 204
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_keyword_to_nonexistent_theme(client: AsyncClient):
    """Adding a keyword to an unknown theme id yields 404."""
    resp = await client.post("/api/keywords/themes/fakeid/keywords", json={"value": "test"})
    assert resp.status_code == 404
|
||||
|
||||
|
||||
# ── Scanner ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_scan_empty(client: AsyncClient):
    """Scanning with no data sources selected returns zero hits."""
    resp = await client.post("/api/keywords/scan", json={})
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["total_hits"] == 0
    assert payload["hits"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_scan_with_dataset(client: AsyncClient):
    """Upload a dataset containing a known keyword; the scanner must flag it."""
    import io

    from tests.conftest import SAMPLE_CSV

    # Theme with one keyword that appears in SAMPLE_CSV.
    theme_resp = await client.post(
        "/api/keywords/themes", json={"name": "Scan Test", "color": "#ff0000"}
    )
    theme_id = theme_resp.json()["id"]
    await client.post(
        f"/api/keywords/themes/{theme_id}/keywords", json={"value": "chrome.exe"}
    )

    # Dataset whose rows mention "chrome.exe".
    upload = await client.post(
        "/api/datasets/upload",
        files={"file": ("test_scan.csv", io.BytesIO(SAMPLE_CSV), "text/csv")},
    )
    assert upload.status_code == 200
    dataset_id = upload.json()["id"]

    # Scan only this dataset with only this theme; all other sources off.
    resp = await client.post("/api/keywords/scan", json={
        "dataset_ids": [dataset_id],
        "theme_ids": [theme_id],
        "scan_hunts": False,
        "scan_annotations": False,
        "scan_messages": False,
    })
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["total_hits"] > 0
    # At least one hit must reference the seeded keyword.
    assert any(hit["keyword"] == "chrome.exe" for hit in payload["hits"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_quick_scan(client: AsyncClient):
    """The quick-scan endpoint accepts a dataset_id query parameter."""
    import io

    from tests.conftest import SAMPLE_CSV

    # Theme + keyword that SAMPLE_CSV is known to contain.
    theme_resp = await client.post(
        "/api/keywords/themes", json={"name": "Quick Scan Theme", "color": "#00ff00"}
    )
    theme_id = theme_resp.json()["id"]
    await client.post(
        f"/api/keywords/themes/{theme_id}/keywords", json={"value": "powershell"}
    )

    # Upload the dataset to scan.
    upload = await client.post(
        "/api/datasets/upload",
        files={"file": ("quick_scan.csv", io.BytesIO(SAMPLE_CSV), "text/csv")},
    )
    dataset_id = upload.json()["id"]

    resp = await client.get(f"/api/keywords/scan/quick?dataset_id={dataset_id}")
    assert resp.status_code == 200
    payload = resp.json()
    assert "total_hits" in payload
    # "powershell" appears in SAMPLE_CSV command lines, so at least one hit.
    assert payload["total_hits"] > 0
|
||||
Reference in New Issue
Block a user