feat: Add Playbook Manager, Saved Searches, and Timeline View components

- Implemented PlaybookManager for creating and managing investigation playbooks with templates.
- Added SavedSearches component for managing bookmarked queries and recurring scans.
- Introduced TimelineView for visualizing forensic event timelines with zoomable charts.
- Enhanced backend processing with auto-queued jobs for dataset uploads and improved database concurrency.
- Updated frontend components for better user experience and performance optimizations.
- Documented changes in update log for future reference.
2026-02-23 14:23:07 -05:00
parent 37a9584d0c
commit 5a2ad8ec1c
110 changed files with 10537 additions and 1185 deletions

View File

@@ -0,0 +1,78 @@
"""add playbooks, playbook_steps, saved_searches tables
Revision ID: b2c3d4e5f6a7
Revises: a1b2c3d4e5f6
Create Date: 2026-02-21 10:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "b2c3d4e5f6a7"
down_revision: Union[str, Sequence[str], None] = "a1b2c3d4e5f6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add display_name to users table
with op.batch_alter_table("users") as batch_op:
batch_op.add_column(sa.Column("display_name", sa.String(128), nullable=True))
# Create playbooks table
op.create_table(
"playbooks",
sa.Column("id", sa.String(32), primary_key=True),
sa.Column("name", sa.String(256), nullable=False, index=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
sa.Column("is_template", sa.Boolean(), server_default="0"),
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
sa.Column("status", sa.String(20), server_default="active"),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
# Create playbook_steps table
op.create_table(
"playbook_steps",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("playbook_id", sa.String(32), sa.ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False),
sa.Column("order_index", sa.Integer(), nullable=False),
sa.Column("title", sa.String(256), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("step_type", sa.String(32), server_default="manual"),
sa.Column("target_route", sa.String(256), nullable=True),
sa.Column("is_completed", sa.Boolean(), server_default="0"),
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("notes", sa.Text(), nullable=True),
)
op.create_index("ix_playbook_steps_playbook", "playbook_steps", ["playbook_id"])
# Create saved_searches table
op.create_table(
"saved_searches",
sa.Column("id", sa.String(32), primary_key=True),
sa.Column("name", sa.String(256), nullable=False, index=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("search_type", sa.String(32), nullable=False),
sa.Column("query_params", sa.JSON(), nullable=False),
sa.Column("threshold", sa.Float(), nullable=True),
sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("last_result_count", sa.Integer(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
op.create_index("ix_saved_searches_type", "saved_searches", ["search_type"])
def downgrade() -> None:
op.drop_table("saved_searches")
op.drop_table("playbook_steps")
op.drop_table("playbooks")
with op.batch_alter_table("users") as batch_op:
batch_op.drop_column("display_name")
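For reference, a minimal sketch of applying or rolling back this revision through Alembic's command API; the "alembic.ini" path is an assumption about the repo layout, not part of this commit.

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # hypothetical config location
command.upgrade(cfg, "b2c3d4e5f6a7")    # apply this revision
command.downgrade(cfg, "a1b2c3d4e5f6")  # roll back to the prior revision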

View File

@@ -0,0 +1,48 @@
"""add processing_tasks table
Revision ID: c3d4e5f6a7b8
Revises: b2c3d4e5f6a7
Create Date: 2026-02-22 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
revision: str = "c3d4e5f6a7b8"
down_revision: Union[str, Sequence[str], None] = "b2c3d4e5f6a7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"processing_tasks",
sa.Column("id", sa.String(32), primary_key=True),
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True),
sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True),
sa.Column("job_id", sa.String(64), nullable=True),
sa.Column("stage", sa.String(64), nullable=False),
sa.Column("status", sa.String(20), nullable=False, server_default="queued"),
sa.Column("progress", sa.Float(), nullable=False, server_default="0.0"),
sa.Column("message", sa.Text(), nullable=True),
sa.Column("error", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
)
op.create_index("ix_processing_tasks_hunt_stage", "processing_tasks", ["hunt_id", "stage"])
op.create_index("ix_processing_tasks_dataset_stage", "processing_tasks", ["dataset_id", "stage"])
op.create_index("ix_processing_tasks_job_id", "processing_tasks", ["job_id"])
op.create_index("ix_processing_tasks_status", "processing_tasks", ["status"])
def downgrade() -> None:
op.drop_index("ix_processing_tasks_status", table_name="processing_tasks")
op.drop_index("ix_processing_tasks_job_id", table_name="processing_tasks")
op.drop_index("ix_processing_tasks_dataset_stage", table_name="processing_tasks")
op.drop_index("ix_processing_tasks_hunt_stage", table_name="processing_tasks")
op.drop_table("processing_tasks")
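As a usage sketch (not part of this commit's diff), a job worker might update rows in this table as a stage progresses; the async session handling below is an assumption, while the ProcessingTask model import appears later in this commit.

from datetime import datetime, timezone
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import ProcessingTask

async def mark_task(db: AsyncSession, task_id: str, status: str,
                    progress: float, message: str | None = None) -> None:
    # Update status/progress and bump updated_at for one processing task.
    await db.execute(
        update(ProcessingTask)
        .where(ProcessingTask.id == task_id)
        .values(status=status, progress=progress, message=message,
                updated_at=datetime.now(timezone.utc))
    )
    await db.commit()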

View File

@@ -1,16 +1,18 @@
"""Analyst-assist agent module for ThreatHunt.
"""Analyst-assist agent module for ThreatHunt.
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
Agents are advisory only and do not execute actions or modify data.
"""
from .core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
from .providers_v2 import OllamaProvider, OpenWebUIProvider, EmbeddingProvider
__all__ = [
"ThreatHuntAgent",
"LLMProvider",
"LocalProvider",
"NetworkedProvider",
"OnlineProvider",
"AgentContext",
"AgentResponse",
"Perspective",
"OllamaProvider",
"OpenWebUIProvider",
"EmbeddingProvider",
]

View File

@@ -1,208 +0,0 @@
"""Core ThreatHunt analyst-assist agent.
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
Agents are advisory only - no execution, no alerts, no data modifications.
"""
import logging
from typing import Optional
from pydantic import BaseModel, Field
from .providers import LLMProvider, get_provider
logger = logging.getLogger(__name__)
class AgentContext(BaseModel):
"""Context for agent guidance requests."""
query: str = Field(
..., description="Analyst question or request for guidance"
)
dataset_name: Optional[str] = Field(None, description="Name of CSV dataset")
artifact_type: Optional[str] = Field(None, description="Artifact type (e.g., file, process, network)")
host_identifier: Optional[str] = Field(
None, description="Host name, IP, or identifier"
)
data_summary: Optional[str] = Field(
None, description="Brief description of uploaded data"
)
conversation_history: Optional[list[dict]] = Field(
default_factory=list, description="Previous messages in conversation"
)
class AgentResponse(BaseModel):
"""Response from analyst-assist agent."""
guidance: str = Field(..., description="Advisory guidance for analyst")
confidence: float = Field(
..., ge=0.0, le=1.0, description="Confidence in guidance (0-1)"
)
suggested_pivots: list[str] = Field(
default_factory=list, description="Suggested analytical directions"
)
suggested_filters: list[str] = Field(
default_factory=list, description="Suggested data filters or queries"
)
caveats: Optional[str] = Field(
None, description="Assumptions, limitations, or caveats"
)
reasoning: Optional[str] = Field(
None, description="Explanation of how guidance was generated"
)
class ThreatHuntAgent:
"""Analyst-assist agent for ThreatHunt.
Provides guidance on:
- Interpreting CSV artifact data
- Suggesting analytical pivots and filters
- Forming and testing hypotheses
Policy:
- Advisory guidance only (no execution)
- No database or schema changes
- No alert escalation
- Transparent reasoning
"""
def __init__(self, provider: Optional[LLMProvider] = None):
"""Initialize agent with LLM provider.
Args:
provider: LLM provider instance. If None, uses get_provider() with auto mode.
"""
if provider is None:
try:
provider = get_provider("auto")
except RuntimeError as e:
logger.warning(f"Could not initialize default provider: {e}")
provider = None
self.provider = provider
self.system_prompt = self._build_system_prompt()
def _build_system_prompt(self) -> str:
"""Build the system prompt that governs agent behavior."""
return """You are an analyst-assist agent for ThreatHunt, a threat hunting platform.
Your role:
- Interpret and explain CSV artifact data from Velociraptor
- Suggest analytical pivots, filters, and hypotheses
- Highlight anomalies, patterns, or points of interest
- Guide analysts without replacing their judgment
Your constraints:
- You ONLY provide guidance and suggestions
- You do NOT execute actions or tools
- You do NOT modify data or escalate alerts
- You do NOT make autonomous decisions
- You ONLY analyze data presented to you
- You explain your reasoning transparently
- You acknowledge limitations and assumptions
- You suggest next investigative steps
When responding:
1. Start with a clear, direct answer to the query
2. Explain your reasoning based on the data context provided
3. Suggest 2-4 analytical pivots the analyst might explore
4. Suggest 2-4 data filters or queries that might be useful
5. Include relevant caveats or assumptions
6. Be honest about what you cannot determine from the data
Remember: The analyst is the decision-maker. You are an assistant."""
async def assist(self, context: AgentContext) -> AgentResponse:
"""Provide guidance on artifact data and analysis.
Args:
context: Request context including query and data context.
Returns:
Guidance response with suggestions and reasoning.
Raises:
RuntimeError: If no provider is available.
"""
if not self.provider:
raise RuntimeError(
"No LLM provider available. Configure at least one of: "
"THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
"or THREAT_HUNT_ONLINE_API_KEY"
)
# Build prompt with context
prompt = self._build_prompt(context)
try:
# Get guidance from LLM provider
guidance = await self.provider.generate(prompt, max_tokens=1024)
# Parse response into structured format
response = self._parse_response(guidance, context)
logger.info(
f"Agent assisted with query: {context.query[:50]}... "
f"(dataset: {context.dataset_name})"
)
return response
except Exception as e:
logger.error(f"Error generating guidance: {e}")
raise
def _build_prompt(self, context: AgentContext) -> str:
"""Build the prompt for the LLM."""
prompt_parts = [
f"Analyst query: {context.query}",
]
if context.dataset_name:
prompt_parts.append(f"Dataset: {context.dataset_name}")
if context.artifact_type:
prompt_parts.append(f"Artifact type: {context.artifact_type}")
if context.host_identifier:
prompt_parts.append(f"Host: {context.host_identifier}")
if context.data_summary:
prompt_parts.append(f"Data summary: {context.data_summary}")
if context.conversation_history:
prompt_parts.append("\nConversation history:")
for msg in context.conversation_history[-5:]: # Last 5 messages for context
prompt_parts.append(f" {msg.get('role', 'unknown')}: {msg.get('content', '')}")
return "\n".join(prompt_parts)
def _parse_response(self, response_text: str, context: AgentContext) -> AgentResponse:
"""Parse LLM response into structured format.
Note: This is a simplified parser. In production, use structured output
from the LLM (JSON mode, function calling, etc.) for better reliability.
"""
# For now, return a structured response based on the raw guidance
# In production, parse JSON or use structured output from LLM
return AgentResponse(
guidance=response_text,
confidence=0.8, # Placeholder
suggested_pivots=[
"Analyze temporal patterns",
"Cross-reference with known indicators",
"Examine outliers in the dataset",
"Compare with baseline behavior",
],
suggested_filters=[
"Filter by high-risk indicators",
"Sort by timestamp for timeline analysis",
"Group by host or user",
"Filter by anomaly score",
],
caveats="Guidance is based on available data context. "
"Analysts should verify findings with additional sources.",
reasoning="Analysis generated based on artifact data patterns and analyst query.",
)
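Although this module is removed by the commit, the note in _parse_response gestures at structured output; a minimal sketch of that approach, reusing the AgentResponse model defined above and assuming the model is asked to reply in JSON:

import json

def parse_structured(response_text: str) -> AgentResponse:
    # Validate model JSON against the Pydantic schema; fall back to raw text.
    try:
        return AgentResponse(**json.loads(response_text))
    except (json.JSONDecodeError, ValueError):
        return AgentResponse(guidance=response_text, confidence=0.5)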

View File

@@ -1,190 +0,0 @@
"""Pluggable LLM provider interface for analyst-assist agents.
Supports three provider types:
- Local: On-device or on-prem models
- Networked: Shared internal inference services
- Online: External hosted APIs
"""
import os
from abc import ABC, abstractmethod
from typing import Optional
class LLMProvider(ABC):
"""Abstract base class for LLM providers."""
@abstractmethod
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
"""Generate a response from the LLM.
Args:
prompt: The input prompt
max_tokens: Maximum tokens in response
Returns:
Generated text response
"""
pass
@abstractmethod
def is_available(self) -> bool:
"""Check if provider backend is available."""
pass
class LocalProvider(LLMProvider):
"""Local LLM provider (on-device or on-prem models)."""
def __init__(self, model_path: Optional[str] = None):
"""Initialize local provider.
Args:
model_path: Path to local model. If None, uses THREAT_HUNT_LOCAL_MODEL_PATH env var.
"""
self.model_path = model_path or os.getenv("THREAT_HUNT_LOCAL_MODEL_PATH")
self.model = None
def is_available(self) -> bool:
"""Check if local model is available."""
if not self.model_path:
return False
# In production, would verify model file exists and can be loaded
return os.path.exists(str(self.model_path))
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
"""Generate response using local model.
Note: This is a placeholder. In production, integrate with:
- llama-cpp-python for GGML models
- Ollama API
- vLLM
- Other local inference engines
"""
if not self.is_available():
raise RuntimeError("Local model not available")
# Placeholder implementation
return f"[Local model response to: {prompt[:50]}...]"
class NetworkedProvider(LLMProvider):
"""Networked LLM provider (shared internal inference services)."""
def __init__(
self,
api_endpoint: Optional[str] = None,
api_key: Optional[str] = None,
model_name: str = "default",
):
"""Initialize networked provider.
Args:
api_endpoint: URL to inference service. Defaults to env var THREAT_HUNT_NETWORKED_ENDPOINT.
api_key: API key for service. Defaults to env var THREAT_HUNT_NETWORKED_KEY.
model_name: Model name/ID on the service.
"""
self.api_endpoint = api_endpoint or os.getenv("THREAT_HUNT_NETWORKED_ENDPOINT")
self.api_key = api_key or os.getenv("THREAT_HUNT_NETWORKED_KEY")
self.model_name = model_name
def is_available(self) -> bool:
"""Check if networked service is available."""
return bool(self.api_endpoint)
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
"""Generate response using networked service.
Note: This is a placeholder. In production, integrate with:
- Internal inference service API
- LLM inference container cluster
- Enterprise inference gateway
"""
if not self.is_available():
raise RuntimeError("Networked service not available")
# Placeholder implementation
return f"[Networked response from {self.model_name}: {prompt[:50]}...]"
class OnlineProvider(LLMProvider):
"""Online LLM provider (external hosted APIs)."""
def __init__(
self,
api_provider: str = "openai",
api_key: Optional[str] = None,
model_name: Optional[str] = None,
):
"""Initialize online provider.
Args:
api_provider: Provider name (openai, anthropic, google, etc.)
api_key: API key. Defaults to env var THREAT_HUNT_ONLINE_API_KEY.
model_name: Model name. Defaults to env var THREAT_HUNT_ONLINE_MODEL.
"""
self.api_provider = api_provider
self.api_key = api_key or os.getenv("THREAT_HUNT_ONLINE_API_KEY")
self.model_name = model_name or os.getenv(
"THREAT_HUNT_ONLINE_MODEL", f"{api_provider}-default"
)
def is_available(self) -> bool:
"""Check if online API is available."""
return bool(self.api_key)
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
"""Generate response using online API.
Note: This is a placeholder. In production, integrate with:
- OpenAI API (GPT-3.5, GPT-4, etc.)
- Anthropic Claude API
- Google Gemini API
- Other hosted LLM services
"""
if not self.is_available():
raise RuntimeError("Online API not available or API key not set")
# Placeholder implementation
return f"[Online {self.api_provider} response: {prompt[:50]}...]"
def get_provider(provider_type: str = "auto") -> LLMProvider:
"""Get an LLM provider based on configuration.
Args:
provider_type: Type of provider to use: 'local', 'networked', 'online', or 'auto'.
'auto' attempts to use the first available provider in order:
local -> networked -> online.
Returns:
Configured LLM provider instance.
Raises:
RuntimeError: If no provider is available.
"""
# Explicit provider selection
if provider_type == "local":
provider = LocalProvider()
elif provider_type == "networked":
provider = NetworkedProvider()
elif provider_type == "online":
provider = OnlineProvider()
elif provider_type == "auto":
# Try providers in order of preference
for Provider in [LocalProvider, NetworkedProvider, OnlineProvider]:
provider = Provider()
if provider.is_available():
return provider
raise RuntimeError(
"No LLM provider available. Configure at least one of: "
"THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
"or THREAT_HUNT_ONLINE_API_KEY"
)
else:
raise ValueError(f"Unknown provider type: {provider_type}")
if not provider.is_available():
raise RuntimeError(f"{provider_type} provider not available")
return provider
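A usage sketch for this factory (also removed by the commit, in favor of providers_v2); which provider is returned depends on the environment variables described above:

import asyncio

async def main() -> None:
    provider = get_provider("auto")  # tries local -> networked -> online
    text = await provider.generate("Suggest pivots for odd logon times", max_tokens=256)
    print(text)

asyncio.run(main())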

View File

@@ -1,170 +0,0 @@
"""API routes for analyst-assist agent."""
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from app.agents.core import ThreatHuntAgent, AgentContext, AgentResponse
from app.agents.config import AgentConfig
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/agent", tags=["agent"])
# Global agent instance (lazy-loaded)
_agent: ThreatHuntAgent | None = None
def get_agent() -> ThreatHuntAgent:
"""Get or create the agent instance."""
global _agent
if _agent is None:
if not AgentConfig.is_agent_enabled():
raise HTTPException(
status_code=503,
detail="Analyst-assist agent is not configured. "
"Please configure an LLM provider.",
)
_agent = ThreatHuntAgent()
return _agent
class AssistRequest(BaseModel):
"""Request for agent assistance."""
query: str = Field(
..., description="Analyst question or request for guidance"
)
dataset_name: str | None = Field(
None, description="Name of CSV dataset being analyzed"
)
artifact_type: str | None = Field(
None, description="Type of artifact (e.g., FileList, ProcessList, NetworkConnections)"
)
host_identifier: str | None = Field(
None, description="Host name, IP address, or identifier"
)
data_summary: str | None = Field(
None, description="Brief summary or context about the uploaded data"
)
conversation_history: list[dict] | None = Field(
None, description="Previous messages for context"
)
class AssistResponse(BaseModel):
"""Response with agent guidance."""
guidance: str
confidence: float
suggested_pivots: list[str]
suggested_filters: list[str]
caveats: str | None = None
reasoning: str | None = None
@router.post(
"/assist",
response_model=AssistResponse,
summary="Get analyst-assist guidance",
description="Request guidance on CSV artifact data, analytical pivots, and hypotheses. "
"Agent provides advisory guidance only - no execution.",
)
async def agent_assist(request: AssistRequest) -> AssistResponse:
"""Provide analyst-assist guidance on artifact data.
The agent will:
- Explain and interpret the provided data context
- Suggest analytical pivots the analyst might explore
- Suggest data filters or queries that might be useful
- Highlight assumptions, limitations, and caveats
The agent will NOT:
- Execute any tools or actions
- Escalate findings to alerts
- Modify any data or schema
- Make autonomous decisions
Args:
request: Assistance request with query and context
Returns:
Guidance response with suggestions and reasoning
Raises:
HTTPException: If agent is not configured (503) or request fails
"""
try:
agent = get_agent()
# Build context
context = AgentContext(
query=request.query,
dataset_name=request.dataset_name,
artifact_type=request.artifact_type,
host_identifier=request.host_identifier,
data_summary=request.data_summary,
conversation_history=request.conversation_history or [],
)
# Get guidance
response = await agent.assist(context)
logger.info(
f"Agent assisted analyst with query: {request.query[:50]}... "
f"(host: {request.host_identifier}, artifact: {request.artifact_type})"
)
return AssistResponse(
guidance=response.guidance,
confidence=response.confidence,
suggested_pivots=response.suggested_pivots,
suggested_filters=response.suggested_filters,
caveats=response.caveats,
reasoning=response.reasoning,
)
except RuntimeError as e:
logger.error(f"Agent error: {e}")
raise HTTPException(
status_code=503,
detail=f"Agent unavailable: {str(e)}",
)
except Exception as e:
logger.exception(f"Unexpected error in agent_assist: {e}")
raise HTTPException(
status_code=500,
detail="Error generating guidance. Please try again.",
)
@router.get(
"/health",
summary="Check agent health",
description="Check if agent is configured and ready to assist.",
)
async def agent_health() -> dict:
"""Check agent availability and configuration.
Returns:
Health status with configuration details
"""
try:
agent = get_agent()
provider_type = agent.provider.__class__.__name__ if agent.provider else "None"
return {
"status": "healthy",
"provider": provider_type,
"max_tokens": AgentConfig.MAX_RESPONSE_TOKENS,
"reasoning_enabled": AgentConfig.ENABLE_REASONING,
}
except HTTPException:
return {
"status": "unavailable",
"reason": "No LLM provider configured",
"configured_providers": {
"local": bool(AgentConfig.LOCAL_MODEL_PATH),
"networked": bool(AgentConfig.NETWORKED_ENDPOINT),
"online": bool(AgentConfig.ONLINE_API_KEY),
},
}

View File

@@ -1,4 +1,4 @@
"""API routes for analyst-assist agent v2.
"""API routes for analyst-assist agent v2.
Supports quick, deep, and debate modes with streaming.
Conversations are persisted to the database.
@@ -6,19 +6,25 @@ Conversations are persisted to the database.
import json
import logging
import re
import time
from collections import Counter
from urllib.parse import urlparse
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.db import get_db
from app.db.models import Conversation, Message, Dataset, KeywordTheme
from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
from app.agents.providers_v2 import check_all_nodes
from app.agents.registry import registry
from app.services.sans_rag import sans_rag
from app.services.scanner import KeywordScanner
logger = logging.getLogger(__name__)
@@ -35,7 +41,7 @@ def get_agent() -> ThreatHuntAgent:
return _agent
# Request / Response models
class AssistRequest(BaseModel):
@@ -52,6 +58,8 @@ class AssistRequest(BaseModel):
model_override: str | None = None
conversation_id: str | None = Field(None, description="Persist messages to this conversation")
hunt_id: str | None = None
execution_preference: str = Field(default="auto", description="auto | force | off")
learning_mode: bool = False
class AssistResponseModel(BaseModel):
@@ -66,10 +74,170 @@ class AssistResponseModel(BaseModel):
node_used: str = ""
latency_ms: int = 0
perspectives: list[dict] | None = None
execution: dict | None = None
conversation_id: str | None = None
POLICY_THEME_NAMES = {"Adult Content", "Gambling", "Downloads / Piracy"}
POLICY_QUERY_TERMS = {
"policy", "violating", "violation", "browser history", "web history",
"domain", "domains", "adult", "gambling", "piracy", "aup",
}
WEB_DATASET_HINTS = {
"web", "history", "browser", "url", "visited_url", "domain", "title",
}
def _is_policy_domain_query(query: str) -> bool:
q = (query or "").lower()
if not q:
return False
score = sum(1 for t in POLICY_QUERY_TERMS if t in q)
return score >= 2 and ("domain" in q or "history" in q or "policy" in q)
def _should_execute_policy_scan(request: AssistRequest) -> bool:
pref = (request.execution_preference or "auto").strip().lower()
if pref == "off":
return False
if pref == "force":
return True
return _is_policy_domain_query(request.query)
def _extract_domain(value: str | None) -> str | None:
if not value:
return None
text = value.strip()
if not text:
return None
try:
parsed = urlparse(text)
if parsed.netloc:
return parsed.netloc.lower()
except Exception:
pass
m = re.search(r"([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}", text)
return m.group(0).lower() if m else None
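# Illustration (hypothetical inputs) of the extraction above: URLs with a
# scheme resolve via urlparse; bare domains in free text fall to the regex.
#   _extract_domain("https://casino.example.com/play")   -> "casino.example.com"
#   _extract_domain("visited bettingsite.net at 10:02")  -> "bettingsite.net"
#   _extract_domain("")                                  -> None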
def _dataset_score(ds: Dataset) -> int:
score = 0
name = (ds.name or "").lower()
cols_l = {c.lower() for c in (ds.column_schema or {}).keys()}
norm_vals_l = {str(v).lower() for v in (ds.normalized_columns or {}).values()}
for h in WEB_DATASET_HINTS:
if h in name:
score += 2
if h in cols_l:
score += 3
if h in norm_vals_l:
score += 3
if "visited_url" in cols_l or "url" in cols_l:
score += 8
if "user" in cols_l or "username" in cols_l:
score += 2
if "clientid" in cols_l or "fqdn" in cols_l:
score += 2
if (ds.row_count or 0) > 0:
score += 1
return score
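# Scoring illustration for a hypothetical dataset named "chrome_web_history"
# with columns {visited_url, title, username} and 1200 rows:
#   name contains "web" and "history" (+4), columns hit "visited_url" and
#   "title" (+6), visited_url present (+8), username (+2), rows > 0 (+1)
#   => score 21, ranking it well ahead of non-web datasets.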
async def _run_policy_domain_execution(request: AssistRequest, db: AsyncSession) -> dict:
scanner = KeywordScanner(db)
theme_result = await db.execute(
select(KeywordTheme).where(
KeywordTheme.enabled == True, # noqa: E712
KeywordTheme.name.in_(list(POLICY_THEME_NAMES)),
)
)
themes = list(theme_result.scalars().all())
theme_ids = [t.id for t in themes]
theme_names = [t.name for t in themes] or sorted(POLICY_THEME_NAMES)
ds_query = select(Dataset).where(Dataset.processing_status.in_(["completed", "ready", "processing"]))
if request.hunt_id:
ds_query = ds_query.where(Dataset.hunt_id == request.hunt_id)
ds_result = await db.execute(ds_query)
candidates = list(ds_result.scalars().all())
if request.dataset_name:
needle = request.dataset_name.lower().strip()
candidates = [d for d in candidates if needle in (d.name or "").lower()]
scored = sorted(
((d, _dataset_score(d)) for d in candidates),
key=lambda x: x[1],
reverse=True,
)
selected = [d for d, s in scored if s > 0][:8]
dataset_ids = [d.id for d in selected]
if not dataset_ids:
return {
"mode": "policy_scan",
"themes": theme_names,
"datasets_scanned": 0,
"dataset_names": [],
"total_hits": 0,
"policy_hits": 0,
"top_user_hosts": [],
"top_domains": [],
"sample_hits": [],
"note": "No suitable browser/web-history datasets found in current scope.",
}
result = await scanner.scan(
dataset_ids=dataset_ids,
theme_ids=theme_ids or None,
scan_hunts=False,
scan_annotations=False,
scan_messages=False,
)
hits = result.get("hits", [])
user_host_counter = Counter()
domain_counter = Counter()
for h in hits:
user = h.get("username") or "(unknown-user)"
host = h.get("hostname") or "(unknown-host)"
user_host_counter[f"{user}|{host}"] += 1
dom = _extract_domain(h.get("matched_value"))
if dom:
domain_counter[dom] += 1
top_user_hosts = [
{"user_host": k, "count": v}
for k, v in user_host_counter.most_common(10)
]
top_domains = [
{"domain": k, "count": v}
for k, v in domain_counter.most_common(10)
]
return {
"mode": "policy_scan",
"themes": theme_names,
"datasets_scanned": len(dataset_ids),
"dataset_names": [d.name for d in selected],
"total_hits": int(result.get("total_hits", 0)),
"policy_hits": int(result.get("total_hits", 0)),
"rows_scanned": int(result.get("rows_scanned", 0)),
"top_user_hosts": top_user_hosts,
"top_domains": top_domains,
"sample_hits": hits[:20],
}
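# Example request that would take the execution path handled below (values
# are hypothetical): POST the assist endpoint with
#   {"query": "Which domains in browser history suggest policy violations?",
#    "execution_preference": "auto"}
# "policy", "violation", "browser history", and "domains" all match
# POLICY_QUERY_TERMS, so _is_policy_domain_query() returns True and the
# keyword scan runs instead of the LLM.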
# Routes
@router.post(
@@ -84,6 +252,76 @@ async def agent_assist(
db: AsyncSession = Depends(get_db),
) -> AssistResponseModel:
try:
# Deterministic execution mode for policy-domain investigations.
if _should_execute_policy_scan(request):
t0 = time.monotonic()
exec_payload = await _run_policy_domain_execution(request, db)
latency_ms = int((time.monotonic() - t0) * 1000)
policy_hits = exec_payload.get("policy_hits", 0)
datasets_scanned = exec_payload.get("datasets_scanned", 0)
if policy_hits > 0:
guidance = (
f"Policy-violation scan complete: {policy_hits} hits across "
f"{datasets_scanned} dataset(s). Top user/host pairs and domains are included "
f"in execution results for triage."
)
confidence = 0.95
caveats = "Keyword-based matching can include false positives; validate with full URL context."
else:
guidance = (
f"No policy-violation hits found in current scope "
f"({datasets_scanned} dataset(s) scanned)."
)
confidence = 0.9
caveats = exec_payload.get("note") or "Try expanding scope to additional hunts/datasets."
response = AssistResponseModel(
guidance=guidance,
confidence=confidence,
suggested_pivots=["username", "hostname", "domain", "dataset_name"],
suggested_filters=[
"theme_name in ['Adult Content','Gambling','Downloads / Piracy']",
"username != null",
"hostname != null",
],
caveats=caveats,
reasoning=(
"Intent matched policy-domain investigation; executed local keyword scan pipeline."
if _is_policy_domain_query(request.query)
else "Execution mode was forced by user preference; ran policy-domain scan pipeline."
),
sans_references=["SANS FOR508", "SANS SEC504"],
model_used="execution:keyword_scanner",
node_used="local",
latency_ms=latency_ms,
execution=exec_payload,
)
conv_id = request.conversation_id
if conv_id or request.hunt_id:
conv_id = await _persist_conversation(
db,
conv_id,
request,
AgentResponse(
guidance=response.guidance,
confidence=response.confidence,
suggested_pivots=response.suggested_pivots,
suggested_filters=response.suggested_filters,
caveats=response.caveats,
reasoning=response.reasoning,
sans_references=response.sans_references,
model_used=response.model_used,
node_used=response.node_used,
latency_ms=response.latency_ms,
),
)
response.conversation_id = conv_id
return response
agent = get_agent()
context = AgentContext(
query=request.query,
@@ -97,6 +335,7 @@ async def agent_assist(
enrichment_summary=request.enrichment_summary,
mode=request.mode,
model_override=request.model_override,
learning_mode=request.learning_mode,
)
response = await agent.assist(context)
@@ -129,6 +368,7 @@ async def agent_assist(
}
for p in response.perspectives
] if response.perspectives else None,
execution=None,
conversation_id=conv_id,
)
@@ -208,7 +448,7 @@ async def list_models():
}
# Conversation persistence
async def _persist_conversation(
@@ -263,3 +503,4 @@ async def _persist_conversation(
await db.flush()
return conv.id
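A minimal client sketch for the assist endpoint above; the base URL, route path (assumed to follow the v1 "/api/agent" prefix), and hunt id are assumptions.

import httpx

resp = httpx.post(
    "http://localhost:8000/api/agent/assist",  # hypothetical host/port
    json={
        "query": "Which domains in browser history suggest policy violations?",
        "execution_preference": "auto",
        "hunt_id": "hunt-123",  # hypothetical id
    },
)
body = resp.json()
print(body["guidance"], body.get("execution"))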

View File

@@ -381,6 +381,10 @@ async def submit_job(
detail=f"Invalid job_type: {job_type}. Valid: {[t.value for t in JobType]}",
)
if not job_queue.can_accept():
raise HTTPException(status_code=429, detail="Job queue is busy. Retry shortly.")
job = job_queue.submit(jt, **params)
return {"job_id": job.id, "status": job.status.value, "job_type": job_type}

View File

@@ -1,4 +1,4 @@
"""API routes for authentication register, login, refresh, profile."""
"""API routes for authentication — register, login, refresh, profile."""
import logging
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/auth", tags=["auth"])
# ── Request / Response models ─────────────────────────────────────────
class RegisterRequest(BaseModel):
@@ -57,7 +57,7 @@ class AuthResponse(BaseModel):
tokens: TokenPair
# ── Routes ────────────────────────────────────────────────────────────
@router.post(
@@ -86,7 +86,7 @@ async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
user = User(
username=body.username,
email=body.email,
hashed_password=hash_password(body.password),
display_name=body.display_name or body.username,
role="analyst", # Default role
)
@@ -120,13 +120,13 @@ async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(User).where(User.username == body.username))
user = result.scalar_one_or_none()
if not user or not user.hashed_password:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password",
)
if not verify_password(body.password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid username or password",
@@ -165,7 +165,7 @@ async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db)
if token_data.type != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type use refresh token",
detail="Invalid token type — use refresh token",
)
result = await db.execute(select(User).where(User.id == token_data.sub))
@@ -195,3 +195,4 @@ async def get_profile(user: User = Depends(get_current_user)):
is_active=user.is_active,
created_at=user.created_at.isoformat() if hasattr(user.created_at, 'isoformat') else str(user.created_at),
)

View File

@@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.db import get_db
from app.db.models import ProcessingTask
from app.db.repositories.datasets import DatasetRepository
from app.services.csv_parser import parse_csv_bytes, infer_column_types
from app.services.normalizer import (
@@ -18,15 +19,20 @@ from app.services.normalizer import (
detect_ioc_columns,
detect_time_range,
)
from app.services.artifact_classifier import classify_artifact, get_artifact_category
logger = logging.getLogger(__name__)
from app.services.job_queue import job_queue, JobType
from app.services.host_inventory import inventory_cache
from app.services.scanner import keyword_scan_cache
router = APIRouter(prefix="/api/datasets", tags=["datasets"])
ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"}
# -- Response models --
class DatasetSummary(BaseModel):
@@ -43,6 +49,8 @@ class DatasetSummary(BaseModel):
delimiter: str | None = None
time_range_start: str | None = None
time_range_end: str | None = None
artifact_type: str | None = None
processing_status: str | None = None
hunt_id: str | None = None
created_at: str
@@ -67,10 +75,13 @@ class UploadResponse(BaseModel):
column_types: dict
normalized_columns: dict
ioc_columns: dict
artifact_type: str | None = None
processing_status: str
jobs_queued: list[str]
message: str
# -- Routes --
@router.post(
@@ -78,7 +89,7 @@ class UploadResponse(BaseModel):
response_model=UploadResponse,
summary="Upload a CSV dataset",
description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, "
"IOCs auto-detected, and rows stored in the database.",
"IOCs auto-detected, artifact type classified, and all processing jobs queued automatically.",
)
async def upload_dataset(
file: UploadFile = File(...),
@@ -87,7 +98,7 @@ async def upload_dataset(
hunt_id: str | None = Query(None, description="Hunt ID to associate with"),
db: AsyncSession = Depends(get_db),
):
"""Upload and parse a CSV dataset."""
"""Upload and parse a CSV dataset, then trigger full processing pipeline."""
# Validate file
if not file.filename:
raise HTTPException(status_code=400, detail="No filename provided")
@@ -136,7 +147,12 @@ async def upload_dataset(
# Detect time range
time_start, time_end = detect_time_range(rows, column_mapping)
# Classify artifact type from column headers
artifact_type = classify_artifact(columns)
artifact_category = get_artifact_category(artifact_type)
logger.info(f"Artifact classification: {artifact_type} (category: {artifact_category})")
# Store in DB with processing_status = "processing"
repo = DatasetRepository(db)
dataset = await repo.create_dataset(
name=name or Path(file.filename).stem,
@@ -152,6 +168,8 @@ async def upload_dataset(
time_range_start=time_start,
time_range_end=time_end,
hunt_id=hunt_id,
artifact_type=artifact_type,
processing_status="processing",
)
await repo.bulk_insert_rows(
@@ -162,9 +180,88 @@ async def upload_dataset(
logger.info(
f"Uploaded dataset '{dataset.name}': {len(rows)} rows, "
f"{len(columns)} columns, {len(ioc_columns)} IOC columns detected"
f"{len(columns)} columns, {len(ioc_columns)} IOC columns, "
f"artifact={artifact_type}"
)
# -- Queue full processing pipeline --
jobs_queued = []
task_rows: list[ProcessingTask] = []
# 1. AI Triage (chains to HOST_PROFILE automatically on completion)
triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
jobs_queued.append("triage")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=triage_job.id,
stage="triage",
status="queued",
progress=0.0,
message="Queued",
))
# 2. Anomaly detection (embedding-based outlier detection)
anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
jobs_queued.append("anomaly")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=anomaly_job.id,
stage="anomaly",
status="queued",
progress=0.0,
message="Queued",
))
# 3. AUP keyword scan
kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
jobs_queued.append("keyword_scan")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=kw_job.id,
stage="keyword_scan",
status="queued",
progress=0.0,
message="Queued",
))
# 4. IOC extraction
ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
jobs_queued.append("ioc_extract")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=ioc_job.id,
stage="ioc_extract",
status="queued",
progress=0.0,
message="Queued",
))
# 5. Host inventory (network map) - requires hunt_id
if hunt_id:
inventory_cache.invalidate(hunt_id)
inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
jobs_queued.append("host_inventory")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=inv_job.id,
stage="host_inventory",
status="queued",
progress=0.0,
message="Queued",
))
if task_rows:
db.add_all(task_rows)
await db.flush()
logger.info(f"Queued {len(jobs_queued)} processing jobs for dataset {dataset.id}: {jobs_queued}")
return UploadResponse(
id=dataset.id,
name=dataset.name,
@@ -173,7 +270,10 @@ async def upload_dataset(
column_types=column_types,
normalized_columns=column_mapping,
ioc_columns=ioc_columns,
message=f"Successfully uploaded {len(rows)} rows with {len(ioc_columns)} IOC columns detected",
artifact_type=artifact_type,
processing_status="processing",
jobs_queued=jobs_queued,
message=f"Successfully uploaded {len(rows)} rows. {len(jobs_queued)} processing jobs queued.",
)
@@ -208,6 +308,8 @@ async def list_datasets(
delimiter=ds.delimiter,
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
artifact_type=ds.artifact_type,
processing_status=ds.processing_status,
hunt_id=ds.hunt_id,
created_at=ds.created_at.isoformat(),
)
@@ -244,6 +346,8 @@ async def get_dataset(
delimiter=ds.delimiter,
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
artifact_type=ds.artifact_type,
processing_status=ds.processing_status,
hunt_id=ds.hunt_id,
created_at=ds.created_at.isoformat(),
)
@@ -292,4 +396,5 @@ async def delete_dataset(
deleted = await repo.delete_dataset(dataset_id)
if not deleted:
raise HTTPException(status_code=404, detail="Dataset not found")
keyword_scan_cache.invalidate_dataset(dataset_id)
return {"message": "Dataset deleted", "id": dataset_id}

View File

@@ -8,16 +8,15 @@ from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Hunt, Dataset, ProcessingTask
from app.services.job_queue import job_queue
from app.services.host_inventory import inventory_cache
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/hunts", tags=["hunts"])
class HuntCreate(BaseModel):
name: str = Field(..., max_length=256)
description: str | None = None
@@ -26,7 +25,7 @@ class HuntCreate(BaseModel):
class HuntUpdate(BaseModel):
name: str | None = None
description: str | None = None
status: str | None = None
class HuntResponse(BaseModel):
@@ -46,7 +45,18 @@ class HuntListResponse(BaseModel):
total: int
class HuntProgressResponse(BaseModel):
hunt_id: str
status: str
progress_percent: float
dataset_total: int
dataset_completed: int
dataset_processing: int
dataset_errors: int
active_jobs: int
queued_jobs: int
network_status: str
stages: dict
@router.post("", response_model=HuntResponse, summary="Create a new hunt")
@@ -122,6 +132,125 @@ async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
)
@router.get("/{hunt_id}/progress", response_model=HuntProgressResponse, summary="Get hunt processing progress")
async def get_hunt_progress(hunt_id: str, db: AsyncSession = Depends(get_db)):
hunt = await db.get(Hunt, hunt_id)
if not hunt:
raise HTTPException(status_code=404, detail="Hunt not found")
ds_rows = await db.execute(
select(Dataset.id, Dataset.processing_status)
.where(Dataset.hunt_id == hunt_id)
)
datasets = ds_rows.all()
dataset_ids = {row[0] for row in datasets}
dataset_total = len(datasets)
dataset_completed = sum(1 for _, st in datasets if st == "completed")
dataset_errors = sum(1 for _, st in datasets if st == "completed_with_errors")
dataset_processing = max(0, dataset_total - dataset_completed - dataset_errors)
jobs = job_queue.list_jobs(limit=5000)
relevant_jobs = [
j for j in jobs
if j.get("params", {}).get("hunt_id") == hunt_id
or j.get("params", {}).get("dataset_id") in dataset_ids
]
active_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "running")
queued_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "queued")
task_rows = await db.execute(
select(ProcessingTask.stage, ProcessingTask.status, ProcessingTask.progress)
.where(ProcessingTask.hunt_id == hunt_id)
)
tasks = task_rows.all()
task_total = len(tasks)
task_done = sum(1 for _, st, _ in tasks if st in ("completed", "failed", "cancelled"))
task_running = sum(1 for _, st, _ in tasks if st == "running")
task_queued = sum(1 for _, st, _ in tasks if st == "queued")
task_ratio = (task_done / task_total) if task_total > 0 else None
active_jobs = max(active_jobs_mem, task_running)
queued_jobs = max(queued_jobs_mem, task_queued)
stage_rollup: dict[str, dict] = {}
for stage, status, progress in tasks:
bucket = stage_rollup.setdefault(stage, {"total": 0, "done": 0, "running": 0, "queued": 0, "progress_sum": 0.0})
bucket["total"] += 1
if status in ("completed", "failed", "cancelled"):
bucket["done"] += 1
elif status == "running":
bucket["running"] += 1
elif status == "queued":
bucket["queued"] += 1
bucket["progress_sum"] += float(progress or 0.0)
for stage_name, bucket in stage_rollup.items():
total = max(1, bucket["total"])
bucket["percent"] = round(bucket["progress_sum"] / total, 1)
if inventory_cache.get(hunt_id) is not None:
network_status = "ready"
network_ratio = 1.0
elif inventory_cache.is_building(hunt_id):
network_status = "building"
network_ratio = 0.5
else:
network_status = "none"
network_ratio = 0.0
dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
if task_ratio is None:
overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
else:
overall_ratio = min(1.0, (dataset_ratio * 0.50) + (task_ratio * 0.35) + (network_ratio * 0.15))
progress_percent = round(overall_ratio * 100.0, 1)
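# Worked example of the weighting above (assumed figures): 4 datasets with
# 3 completed and 1 still processing (dataset_ratio = 0.75), 10 of 20
# ProcessingTask rows finished (task_ratio = 0.5), and the network map
# ready (network_ratio = 1.0):
#   overall = 0.75*0.50 + 0.5*0.35 + 1.0*0.15 = 0.375 + 0.175 + 0.15 = 0.70
# so progress_percent == 70.0 and status resolves to "processing".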
status = "ready"
if dataset_total == 0:
status = "idle"
elif progress_percent < 100:
status = "processing"
stages = {
"datasets": {
"total": dataset_total,
"completed": dataset_completed,
"processing": dataset_processing,
"errors": dataset_errors,
"percent": round(dataset_ratio * 100.0, 1),
},
"network": {
"status": network_status,
"percent": round(network_ratio * 100.0, 1),
},
"jobs": {
"active": active_jobs,
"queued": queued_jobs,
"total_seen": len(relevant_jobs),
"task_total": task_total,
"task_done": task_done,
"task_percent": round((task_ratio or 0.0) * 100.0, 1) if task_total else None,
},
"task_stages": stage_rollup,
}
return HuntProgressResponse(
hunt_id=hunt_id,
status=status,
progress_percent=progress_percent,
dataset_total=dataset_total,
dataset_completed=dataset_completed,
dataset_processing=dataset_processing,
dataset_errors=dataset_errors,
active_jobs=active_jobs,
queued_jobs=queued_jobs,
network_status=network_status,
stages=stages,
)
@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
async def update_hunt(
hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)

View File

@@ -1,25 +1,21 @@
"""API routes for AUP keyword themes, keyword CRUD, and scanning."""
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import KeywordTheme, Keyword
from app.services.scanner import KeywordScanner, keyword_scan_cache
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/keywords", tags=["keywords"])
class ThemeCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=128)
color: str = Field(default="#9e9e9e", max_length=16)
@@ -67,23 +63,27 @@ class KeywordBulkCreate(BaseModel):
class ScanRequest(BaseModel):
dataset_ids: list[str] | None = None
theme_ids: list[str] | None = None
scan_hunts: bool = False
scan_annotations: bool = False
scan_messages: bool = False
prefer_cache: bool = True
force_rescan: bool = False
class ScanHit(BaseModel):
theme_name: str
theme_color: str
keyword: str
source_type: str
source_id: str | int
field: str
matched_value: str
row_index: int | None = None
dataset_name: str | None = None
hostname: str | None = None
username: str | None = None
class ScanResponse(BaseModel):
@@ -92,9 +92,9 @@ class ScanResponse(BaseModel):
themes_scanned: int
keywords_scanned: int
rows_scanned: int
cache_used: bool = False
cache_status: str = "miss"
cached_at: str | None = None
def _theme_to_out(t: KeywordTheme) -> ThemeOut:
@@ -119,49 +119,58 @@ def _theme_to_out(t: KeywordTheme) -> ThemeOut:
)
def _merge_cached_results(entries: list[dict], allowed_theme_names: set[str] | None = None) -> dict:
hits: list[dict] = []
total_rows = 0
cached_at: str | None = None
for entry in entries:
result = entry["result"]
total_rows += int(result.get("rows_scanned", 0) or 0)
if entry.get("built_at"):
if not cached_at or entry["built_at"] > cached_at:
cached_at = entry["built_at"]
for h in result.get("hits", []):
if allowed_theme_names is not None and h.get("theme_name") not in allowed_theme_names:
continue
hits.append(h)
return {
"total_hits": len(hits),
"hits": hits,
"rows_scanned": total_rows,
"cached_at": cached_at,
}
@router.get("/themes", response_model=ThemeListResponse)
async def list_themes(db: AsyncSession = Depends(get_db)):
"""List all keyword themes with their keywords."""
result = await db.execute(select(KeywordTheme).order_by(KeywordTheme.name))
themes = result.scalars().all()
return ThemeListResponse(themes=[_theme_to_out(t) for t in themes], total=len(themes))
@router.post("/themes", response_model=ThemeOut, status_code=201)
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
"""Create a new keyword theme."""
exists = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == body.name))
if exists:
raise HTTPException(409, f"Theme '{body.name}' already exists")
theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
db.add(theme)
await db.flush()
await db.refresh(theme)
keyword_scan_cache.clear()
return _theme_to_out(theme)
@router.put("/themes/{theme_id}", response_model=ThemeOut)
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
"""Update theme name, color, or enabled status."""
theme = await db.get(KeywordTheme, theme_id)
if not theme:
raise HTTPException(404, "Theme not found")
if body.name is not None:
# check uniqueness
dup = await db.scalar(
select(KeywordTheme.id).where(KeywordTheme.name == body.name, KeywordTheme.id != theme_id)
)
if dup:
raise HTTPException(409, f"Theme '{body.name}' already exists")
@@ -172,24 +181,21 @@ async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depe
theme.enabled = body.enabled
await db.flush()
await db.refresh(theme)
keyword_scan_cache.clear()
return _theme_to_out(theme)
@router.delete("/themes/{theme_id}", status_code=204)
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
"""Delete a theme and all its keywords."""
theme = await db.get(KeywordTheme, theme_id)
if not theme:
raise HTTPException(404, "Theme not found")
await db.delete(theme)
keyword_scan_cache.clear()
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
"""Add a single keyword to a theme."""
theme = await db.get(KeywordTheme, theme_id)
if not theme:
raise HTTPException(404, "Theme not found")
@@ -197,6 +203,7 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
db.add(kw)
await db.flush()
await db.refresh(kw)
keyword_scan_cache.clear()
return KeywordOut(
id=kw.id, theme_id=kw.theme_id, value=kw.value,
is_regex=kw.is_regex, created_at=kw.created_at.isoformat(),
@@ -205,7 +212,6 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
"""Add multiple keywords to a theme at once."""
theme = await db.get(KeywordTheme, theme_id)
if not theme:
raise HTTPException(404, "Theme not found")
@@ -217,25 +223,88 @@ async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSes
db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex))
added += 1
await db.flush()
keyword_scan_cache.clear()
return {"added": added, "theme_id": theme_id}
@router.delete("/keywords/{keyword_id}", status_code=204)
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
"""Delete a single keyword."""
kw = await db.get(Keyword, keyword_id)
if not kw:
raise HTTPException(404, "Keyword not found")
await db.delete(kw)
keyword_scan_cache.clear()
@router.post("/scan", response_model=ScanResponse)
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
"""Run AUP keyword scan across selected data sources."""
scanner = KeywordScanner(db)
if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
return {
"total_hits": 0,
"hits": [],
"themes_scanned": 0,
"keywords_scanned": 0,
"rows_scanned": 0,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}
can_use_cache = (
body.prefer_cache
and not body.force_rescan
and bool(body.dataset_ids)
and not body.scan_hunts
and not body.scan_annotations
and not body.scan_messages
)
if can_use_cache:
themes = await scanner._load_themes(body.theme_ids)
allowed_theme_names = {t.name for t in themes}
keywords_scanned = sum(len(theme.keywords) for theme in themes)
cached_entries: list[dict] = []
missing: list[str] = []
for dataset_id in (body.dataset_ids or []):
entry = keyword_scan_cache.get(dataset_id)
if not entry:
missing.append(dataset_id)
continue
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
if not missing and cached_entries:
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": True,
"cache_status": "hit",
"cached_at": merged["cached_at"],
}
if missing:
partial = await scanner.scan(dataset_ids=missing, theme_ids=body.theme_ids)
merged = _merge_cached_results(
cached_entries + [{"result": partial, "built_at": None}],
allowed_theme_names if body.theme_ids else None,
)
return {
"total_hits": merged["total_hits"],
"hits": merged["hits"],
"themes_scanned": len(themes),
"keywords_scanned": keywords_scanned,
"rows_scanned": merged["rows_scanned"],
"cache_used": len(cached_entries) > 0,
"cache_status": "partial" if cached_entries else "miss",
"cached_at": merged["cached_at"],
}
result = await scanner.scan(
dataset_ids=body.dataset_ids,
theme_ids=body.theme_ids,
@@ -243,7 +312,13 @@ async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
scan_annotations=body.scan_annotations,
scan_messages=body.scan_messages,
)
return {
**result,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}
@router.get("/scan/quick", response_model=ScanResponse)
@@ -251,7 +326,22 @@ async def quick_scan(
dataset_id: str = Query(..., description="Dataset to scan"),
db: AsyncSession = Depends(get_db),
):
"""Quick scan a single dataset with all enabled themes."""
entry = keyword_scan_cache.get(dataset_id)
if entry is not None:
result = entry.result
return {
**result,
"cache_used": True,
"cache_status": "hit",
"cached_at": entry.built_at,
}
scanner = KeywordScanner(db)
result = await scanner.scan(dataset_ids=[dataset_id])
keyword_scan_cache.put(dataset_id, result)
return {
**result,
"cache_used": False,
"cache_status": "miss",
"cached_at": None,
}
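A quick illustration of the cache behavior above; the host, port, and dataset id are assumptions.

import httpx

url = "http://localhost:8000/api/keywords/scan/quick"
first = httpx.get(url, params={"dataset_id": "ds-abc123"}).json()   # hypothetical id
print(first["cache_status"])   # "miss": the scan runs and the result is cached
second = httpx.get(url, params={"dataset_id": "ds-abc123"}).json()
print(second["cache_status"])  # "hit": served from keyword_scan_cache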

View File

@@ -0,0 +1,146 @@
"""API routes for MITRE ATT&CK coverage visualization."""
import logging
from collections import defaultdict
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import (
TriageResult, HostProfile, Hypothesis, HuntReport, Dataset, Hunt
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/mitre", tags=["mitre"])
# Canonical MITRE ATT&CK tactics in kill-chain order
TACTICS = [
"Reconnaissance", "Resource Development", "Initial Access",
"Execution", "Persistence", "Privilege Escalation",
"Defense Evasion", "Credential Access", "Discovery",
"Lateral Movement", "Collection", "Command and Control",
"Exfiltration", "Impact",
]
# Simplified technique-to-tactic mapping (top techniques)
TECHNIQUE_TACTIC: dict[str, str] = {
"T1059": "Execution", "T1059.001": "Execution", "T1059.003": "Execution",
"T1059.005": "Execution", "T1059.006": "Execution", "T1059.007": "Execution",
"T1053": "Persistence", "T1053.005": "Persistence",
"T1547": "Persistence", "T1547.001": "Persistence",
"T1543": "Persistence", "T1543.003": "Persistence",
"T1078": "Privilege Escalation", "T1078.001": "Privilege Escalation",
"T1078.002": "Privilege Escalation", "T1078.003": "Privilege Escalation",
"T1055": "Privilege Escalation", "T1055.001": "Privilege Escalation",
"T1548": "Privilege Escalation", "T1548.002": "Privilege Escalation",
"T1070": "Defense Evasion", "T1070.001": "Defense Evasion",
"T1070.004": "Defense Evasion",
"T1036": "Defense Evasion", "T1036.005": "Defense Evasion",
"T1027": "Defense Evasion", "T1140": "Defense Evasion",
"T1218": "Defense Evasion", "T1218.011": "Defense Evasion",
"T1003": "Credential Access", "T1003.001": "Credential Access",
"T1110": "Credential Access", "T1558": "Credential Access",
"T1087": "Discovery", "T1087.001": "Discovery", "T1087.002": "Discovery",
"T1082": "Discovery", "T1083": "Discovery", "T1057": "Discovery",
"T1018": "Discovery", "T1049": "Discovery", "T1016": "Discovery",
"T1021": "Lateral Movement", "T1021.001": "Lateral Movement",
"T1021.002": "Lateral Movement", "T1021.006": "Lateral Movement",
"T1570": "Lateral Movement",
"T1560": "Collection", "T1074": "Collection", "T1005": "Collection",
"T1071": "Command and Control", "T1071.001": "Command and Control",
"T1105": "Command and Control", "T1572": "Command and Control",
"T1095": "Command and Control",
"T1048": "Exfiltration", "T1041": "Exfiltration",
"T1486": "Impact", "T1490": "Impact", "T1489": "Impact",
"T1566": "Initial Access", "T1566.001": "Initial Access",
"T1566.002": "Initial Access",
"T1190": "Initial Access", "T1133": "Initial Access",
"T1195": "Initial Access", "T1195.002": "Initial Access",
}
def _get_tactic(technique_id: str) -> str:
"""Map a technique ID to its tactic."""
tech = technique_id.strip().upper()
if tech in TECHNIQUE_TACTIC:
return TECHNIQUE_TACTIC[tech]
# Try parent technique
if "." in tech:
parent = tech.split(".")[0]
if parent in TECHNIQUE_TACTIC:
return TECHNIQUE_TACTIC[parent]
return "Unknown"
@router.get("/coverage")
async def get_mitre_coverage(
hunt_id: str | None = None,
db: AsyncSession = Depends(get_db),
):
"""Aggregate all MITRE techniques from triage, host profiles, hypotheses, and reports."""
techniques: dict[str, dict] = {}
# Collect from triage results
triage_q = select(TriageResult)
if hunt_id:
triage_q = triage_q.join(Dataset).where(Dataset.hunt_id == hunt_id)
result = await db.execute(triage_q.limit(500))
for t in result.scalars().all():
for tech in (t.mitre_techniques or []):
if tech not in techniques:
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
techniques[tech]["count"] += 1
techniques[tech]["sources"].append({"type": "triage", "risk_score": t.risk_score})
# Collect from host profiles
profile_q = select(HostProfile)
if hunt_id:
profile_q = profile_q.where(HostProfile.hunt_id == hunt_id)
result = await db.execute(profile_q.limit(200))
for p in result.scalars().all():
for tech in (p.mitre_techniques or []):
if tech not in techniques:
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
techniques[tech]["count"] += 1
techniques[tech]["sources"].append({"type": "host_profile", "hostname": p.hostname})
# Collect from hypotheses
hyp_q = select(Hypothesis)
if hunt_id:
hyp_q = hyp_q.where(Hypothesis.hunt_id == hunt_id)
result = await db.execute(hyp_q.limit(200))
for h in result.scalars().all():
tech = h.mitre_technique
if tech:
if tech not in techniques:
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
techniques[tech]["count"] += 1
techniques[tech]["sources"].append({"type": "hypothesis", "title": h.title})
# Build tactic-grouped response
tactic_groups: dict[str, list] = {t: [] for t in TACTICS}
tactic_groups["Unknown"] = []
for tech in techniques.values():
tactic = tech["tactic"]
if tactic not in tactic_groups:
tactic_groups[tactic] = []
tactic_groups[tactic].append(tech)
total_techniques = len(techniques)
total_detections = sum(t["count"] for t in techniques.values())
return {
"tactics": TACTICS,
"technique_count": total_techniques,
"detection_count": total_detections,
"tactic_coverage": {
t: {"techniques": techs, "count": len(techs)}
for t, techs in tactic_groups.items()
if techs
},
"all_techniques": list(techniques.values()),
}
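# -- Editor's note -----------------------------------------------------------
# Illustrative response shape (values invented for the example):
#
# {
#   "tactics": ["Reconnaissance", "Resource Development", ..., "Impact"],
#   "technique_count": 7,
#   "detection_count": 19,
#   "tactic_coverage": {
#     "Execution": {
#       "count": 1,
#       "techniques": [{"id": "T1059.001", "tactic": "Execution", "count": 11,
#                       "sources": [{"type": "triage", "risk_score": 8.5}]}]
#     }
#   },
#   "all_techniques": [...]
# }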

View File

@@ -1,12 +1,15 @@
"""Network topology API - host inventory endpoint."""
"""Network topology API - host inventory endpoint with background caching."""
import logging
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.db import get_db
from app.services.host_inventory import build_host_inventory
from app.services.host_inventory import build_host_inventory, inventory_cache
from app.services.job_queue import job_queue, JobType
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/network", tags=["network"])
@@ -15,14 +18,158 @@ router = APIRouter(prefix="/api/network", tags=["network"])
@router.get("/host-inventory")
async def get_host_inventory(
hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
force: bool = Query(False, description="Force rebuild, ignoring cache"),
db: AsyncSession = Depends(get_db),
):
"""Build a deduplicated host inventory from all datasets in a hunt.
"""Return a deduplicated host inventory for the hunt.
Returns unique hosts with hostname, IPs, OS, logged-in users, and
network connections derived from netstat/connection data.
Returns instantly from cache if available (pre-built after upload or on startup).
If the cache is cold, triggers a background build and returns 202 so the
frontend can poll /inventory-status and re-request when ready.
"""
result = await build_host_inventory(hunt_id, db)
if result["stats"]["total_hosts"] == 0:
return result
return result
# Force rebuild: invalidate cache, queue background job, return 202
if force:
inventory_cache.invalidate(hunt_id)
if not inventory_cache.is_building(hunt_id):
if job_queue.is_backlogged():
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(
status_code=202,
content={"status": "building", "message": "Rebuild queued"},
)
# Try cache first
cached = inventory_cache.get(hunt_id)
if cached is not None:
logger.info(f"Serving cached host inventory for {hunt_id}")
return cached
# Cache miss: trigger background build instead of blocking for 90+ seconds
if not inventory_cache.is_building(hunt_id):
logger.info(f"Cache miss for {hunt_id}, triggering background build")
if job_queue.is_backlogged():
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(
status_code=202,
content={"status": "building", "message": "Inventory is being built in the background"},
)
def _build_summary(inv: dict, top_n: int = 20) -> dict:
hosts = inv.get("hosts", [])
conns = inv.get("connections", [])
top_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:top_n]
top_edges = sorted(conns, key=lambda c: c.get("count", 0), reverse=True)[:top_n]
return {
"stats": inv.get("stats", {}),
"top_hosts": [
{
"id": h.get("id"),
"hostname": h.get("hostname"),
"row_count": h.get("row_count", 0),
"ip_count": len(h.get("ips", [])),
"user_count": len(h.get("users", [])),
}
for h in top_hosts
],
"top_edges": top_edges,
}
def _build_subgraph(inv: dict, node_id: str | None, max_hosts: int, max_edges: int) -> dict:
hosts = inv.get("hosts", [])
conns = inv.get("connections", [])
max_hosts = max(1, min(max_hosts, settings.NETWORK_SUBGRAPH_MAX_HOSTS))
max_edges = max(1, min(max_edges, settings.NETWORK_SUBGRAPH_MAX_EDGES))
if node_id:
rel_edges = [c for c in conns if c.get("source") == node_id or c.get("target") == node_id]
rel_edges = sorted(rel_edges, key=lambda c: c.get("count", 0), reverse=True)[:max_edges]
ids = {node_id}
for c in rel_edges:
ids.add(c.get("source"))
ids.add(c.get("target"))
rel_hosts = [h for h in hosts if h.get("id") in ids][:max_hosts]
else:
rel_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:max_hosts]
allowed = {h.get("id") for h in rel_hosts}
rel_edges = [
c for c in sorted(conns, key=lambda c: c.get("count", 0), reverse=True)
if c.get("source") in allowed and c.get("target") in allowed
][:max_edges]
return {
"hosts": rel_hosts,
"connections": rel_edges,
"stats": {
**inv.get("stats", {}),
"subgraph_hosts": len(rel_hosts),
"subgraph_connections": len(rel_edges),
"truncated": len(rel_hosts) < len(hosts) or len(rel_edges) < len(conns),
},
}
@router.get("/summary")
async def get_inventory_summary(
hunt_id: str = Query(..., description="Hunt ID"),
top_n: int = Query(20, ge=1, le=200),
):
"""Return a lightweight summary view for large hunts."""
cached = inventory_cache.get(hunt_id)
if cached is None:
if not inventory_cache.is_building(hunt_id):
if job_queue.is_backlogged():
return JSONResponse(
status_code=202,
content={"status": "deferred", "message": "Queue busy, retry shortly"},
)
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(status_code=202, content={"status": "building"})
return _build_summary(cached, top_n=top_n)
@router.get("/subgraph")
async def get_inventory_subgraph(
hunt_id: str = Query(..., description="Hunt ID"),
node_id: str | None = Query(None, description="Optional focal node"),
max_hosts: int = Query(200, ge=1, le=5000),
max_edges: int = Query(1500, ge=1, le=20000),
):
"""Return a bounded subgraph for scale-safe rendering."""
cached = inventory_cache.get(hunt_id)
if cached is None:
if not inventory_cache.is_building(hunt_id):
if job_queue.is_backlogged():
return JSONResponse(
status_code=202,
content={"status": "deferred", "message": "Queue busy, retry shortly"},
)
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return JSONResponse(status_code=202, content={"status": "building"})
return _build_subgraph(cached, node_id=node_id, max_hosts=max_hosts, max_edges=max_edges)
@router.get("/inventory-status")
async def get_inventory_status(
hunt_id: str = Query(..., description="Hunt ID to check"),
):
"""Check whether pre-computed host inventory is ready for a hunt.
Returns: { status: "ready" | "building" | "none" }
"""
return {"hunt_id": hunt_id, "status": inventory_cache.status(hunt_id)}
@router.post("/rebuild-inventory")
async def trigger_rebuild(
hunt_id: str = Query(..., description="Hunt to rebuild inventory for"),
):
"""Trigger a background rebuild of the host inventory cache."""
inventory_cache.invalidate(hunt_id)
job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
return {"job_id": job.id, "status": "queued"}

View File

@@ -0,0 +1,217 @@
"""API routes for investigation playbooks."""
import logging
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Playbook, PlaybookStep
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/playbooks", tags=["playbooks"])
# -- Request / Response schemas ---
class StepCreate(BaseModel):
title: str
description: str | None = None
step_type: str = "manual"
target_route: str | None = None
class PlaybookCreate(BaseModel):
name: str
description: str | None = None
hunt_id: str | None = None
is_template: bool = False
steps: list[StepCreate] = []
class PlaybookUpdate(BaseModel):
name: str | None = None
description: str | None = None
status: str | None = None
class StepUpdate(BaseModel):
is_completed: bool | None = None
notes: str | None = None
# -- Default investigation templates ---
DEFAULT_TEMPLATES = [
{
"name": "Standard Threat Hunt",
"description": "Step-by-step investigation workflow for a typical threat hunting engagement.",
"steps": [
{"title": "Upload Artifacts", "description": "Import CSV exports from Velociraptor or other tools", "step_type": "upload", "target_route": "/upload"},
{"title": "Create Hunt", "description": "Create a new hunt and associate uploaded datasets", "step_type": "action", "target_route": "/hunts"},
{"title": "AUP Keyword Scan", "description": "Run AUP keyword scanner for policy violations", "step_type": "analysis", "target_route": "/aup"},
{"title": "Auto-Triage", "description": "Trigger AI triage on all datasets", "step_type": "analysis", "target_route": "/analysis"},
{"title": "Review Triage Results", "description": "Review flagged rows and risk scores", "step_type": "review", "target_route": "/analysis"},
{"title": "Enrich IOCs", "description": "Enrich flagged IPs, hashes, and domains via external sources", "step_type": "analysis", "target_route": "/enrichment"},
{"title": "Host Profiling", "description": "Generate deep host profiles for suspicious hosts", "step_type": "analysis", "target_route": "/analysis"},
{"title": "Cross-Hunt Correlation", "description": "Identify shared IOCs and patterns across hunts", "step_type": "analysis", "target_route": "/correlation"},
{"title": "Document Hypotheses", "description": "Record investigation hypotheses with MITRE mappings", "step_type": "action", "target_route": "/hypotheses"},
{"title": "Generate Report", "description": "Generate final AI-assisted hunt report", "step_type": "action", "target_route": "/analysis"},
],
},
{
"name": "Incident Response Triage",
"description": "Fast-track workflow for active incident response.",
"steps": [
{"title": "Upload Artifacts", "description": "Import forensic data from affected hosts", "step_type": "upload", "target_route": "/upload"},
{"title": "Auto-Triage", "description": "Immediate AI triage for threat indicators", "step_type": "analysis", "target_route": "/analysis"},
{"title": "IOC Extraction", "description": "Extract all IOCs from flagged data", "step_type": "analysis", "target_route": "/analysis"},
{"title": "Enrich Critical IOCs", "description": "Priority enrichment of high-risk indicators", "step_type": "analysis", "target_route": "/enrichment"},
{"title": "Network Map", "description": "Visualize host connections and lateral movement", "step_type": "review", "target_route": "/network"},
{"title": "Generate Situation Report", "description": "Create executive summary for incident command", "step_type": "action", "target_route": "/analysis"},
],
},
]
# -- Routes ---
@router.get("")
async def list_playbooks(
include_templates: bool = True,
hunt_id: str | None = None,
db: AsyncSession = Depends(get_db),
):
q = select(Playbook)
if hunt_id:
q = q.where(Playbook.hunt_id == hunt_id)
if not include_templates:
q = q.where(Playbook.is_template == False)
q = q.order_by(Playbook.created_at.desc())
result = await db.execute(q.limit(100))
playbooks = result.scalars().all()
return {"playbooks": [
{
"id": p.id, "name": p.name, "description": p.description,
"is_template": p.is_template, "hunt_id": p.hunt_id,
"status": p.status,
"total_steps": len(p.steps),
"completed_steps": sum(1 for s in p.steps if s.is_completed),
"created_at": p.created_at.isoformat() if p.created_at else None,
}
for p in playbooks
]}
@router.get("/templates")
async def get_templates():
"""Return built-in investigation templates."""
return {"templates": DEFAULT_TEMPLATES}
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_playbook(body: PlaybookCreate, db: AsyncSession = Depends(get_db)):
pb = Playbook(
name=body.name,
description=body.description,
hunt_id=body.hunt_id,
is_template=body.is_template,
)
db.add(pb)
await db.flush()
created_steps = []
for i, step in enumerate(body.steps):
s = PlaybookStep(
playbook_id=pb.id,
order_index=i,
title=step.title,
description=step.description,
step_type=step.step_type,
target_route=step.target_route,
)
db.add(s)
created_steps.append(s)
await db.flush()
return {
"id": pb.id, "name": pb.name, "description": pb.description,
"hunt_id": pb.hunt_id, "is_template": pb.is_template,
"steps": [
{"id": s.id, "order_index": s.order_index, "title": s.title,
"description": s.description, "step_type": s.step_type,
"target_route": s.target_route, "is_completed": False}
for s in created_steps
],
}
@router.get("/{playbook_id}")
async def get_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
pb = result.scalar_one_or_none()
if not pb:
raise HTTPException(status_code=404, detail="Playbook not found")
return {
"id": pb.id, "name": pb.name, "description": pb.description,
"is_template": pb.is_template, "hunt_id": pb.hunt_id,
"status": pb.status,
"created_at": pb.created_at.isoformat() if pb.created_at else None,
"steps": [
{
"id": s.id, "order_index": s.order_index, "title": s.title,
"description": s.description, "step_type": s.step_type,
"target_route": s.target_route,
"is_completed": s.is_completed,
"completed_at": s.completed_at.isoformat() if s.completed_at else None,
"notes": s.notes,
}
for s in sorted(pb.steps, key=lambda x: x.order_index)
],
}
@router.put("/{playbook_id}")
async def update_playbook(playbook_id: str, body: PlaybookUpdate, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
pb = result.scalar_one_or_none()
if not pb:
raise HTTPException(status_code=404, detail="Playbook not found")
if body.name is not None:
pb.name = body.name
if body.description is not None:
pb.description = body.description
if body.status is not None:
pb.status = body.status
return {"status": "updated"}
@router.delete("/{playbook_id}")
async def delete_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
pb = result.scalar_one_or_none()
if not pb:
raise HTTPException(status_code=404, detail="Playbook not found")
await db.delete(pb)
return {"status": "deleted"}
@router.put("/steps/{step_id}")
async def update_step(step_id: int, body: StepUpdate, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(PlaybookStep).where(PlaybookStep.id == step_id))
step = result.scalar_one_or_none()
if not step:
raise HTTPException(status_code=404, detail="Step not found")
if body.is_completed is not None:
step.is_completed = body.is_completed
step.completed_at = datetime.now(timezone.utc) if body.is_completed else None
if body.notes is not None:
step.notes = body.notes
return {"status": "updated", "is_completed": step.is_completed}

View File

@@ -0,0 +1,164 @@
"""API routes for saved searches and bookmarked queries."""
import logging
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import SavedSearch
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/searches", tags=["saved-searches"])
class SearchCreate(BaseModel):
name: str
description: str | None = None
search_type: str # "nlp_query", "ioc_search", "keyword_scan", "correlation"
query_params: dict
threshold: float | None = None
class SearchUpdate(BaseModel):
name: str | None = None
description: str | None = None
query_params: dict | None = None
threshold: float | None = None
@router.get("")
async def list_searches(
search_type: str | None = None,
db: AsyncSession = Depends(get_db),
):
q = select(SavedSearch).order_by(SavedSearch.created_at.desc())
if search_type:
q = q.where(SavedSearch.search_type == search_type)
result = await db.execute(q.limit(100))
searches = result.scalars().all()
return {"searches": [
{
"id": s.id, "name": s.name, "description": s.description,
"search_type": s.search_type, "query_params": s.query_params,
"threshold": s.threshold,
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
"last_result_count": s.last_result_count,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
for s in searches
]}
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_search(body: SearchCreate, db: AsyncSession = Depends(get_db)):
s = SavedSearch(
name=body.name,
description=body.description,
search_type=body.search_type,
query_params=body.query_params,
threshold=body.threshold,
)
db.add(s)
await db.flush()
return {
"id": s.id, "name": s.name, "search_type": s.search_type,
"query_params": s.query_params, "threshold": s.threshold,
}
@router.get("/{search_id}")
async def get_search(search_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
s = result.scalar_one_or_none()
if not s:
raise HTTPException(status_code=404, detail="Saved search not found")
return {
"id": s.id, "name": s.name, "description": s.description,
"search_type": s.search_type, "query_params": s.query_params,
"threshold": s.threshold,
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
"last_result_count": s.last_result_count,
"created_at": s.created_at.isoformat() if s.created_at else None,
}
@router.put("/{search_id}")
async def update_search(search_id: str, body: SearchUpdate, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
s = result.scalar_one_or_none()
if not s:
raise HTTPException(status_code=404, detail="Saved search not found")
if body.name is not None:
s.name = body.name
if body.description is not None:
s.description = body.description
if body.query_params is not None:
s.query_params = body.query_params
if body.threshold is not None:
s.threshold = body.threshold
return {"status": "updated"}
@router.delete("/{search_id}")
async def delete_search(search_id: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
s = result.scalar_one_or_none()
if not s:
raise HTTPException(status_code=404, detail="Saved search not found")
await db.delete(s)
return {"status": "deleted"}
@router.post("/{search_id}/run")
async def run_saved_search(search_id: str, db: AsyncSession = Depends(get_db)):
"""Execute a saved search and return results with delta from last run."""
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
s = result.scalar_one_or_none()
if not s:
raise HTTPException(status_code=404, detail="Saved search not found")
previous_count = s.last_result_count or 0
results = []
count = 0
if s.search_type == "ioc_search":
from app.db.models import EnrichmentResult
ioc_value = s.query_params.get("ioc_value", "")
if ioc_value:
q = select(EnrichmentResult).where(
EnrichmentResult.ioc_value.contains(ioc_value)
)
res = await db.execute(q.limit(100))
for er in res.scalars().all():
results.append({
"ioc_value": er.ioc_value, "ioc_type": er.ioc_type,
"source": er.source, "verdict": er.verdict,
})
count = len(results)
elif s.search_type == "keyword_scan":
from app.db.models import KeywordTheme
res = await db.execute(select(KeywordTheme).where(KeywordTheme.enabled == True))
themes = res.scalars().all()
count = sum(len(t.keywords) for t in themes)
results = [{"theme": t.name, "keyword_count": len(t.keywords)} for t in themes]
# Update last run metadata
s.last_run_at = datetime.now(timezone.utc)
s.last_result_count = count
delta = count - previous_count
return {
"search_id": s.id, "search_name": s.name,
"search_type": s.search_type,
"result_count": count,
"previous_count": previous_count,
"delta": delta,
"results": results[:50],
}
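# -- Editor's note -----------------------------------------------------------
# The previous_count/delta fields make saved searches usable as lightweight
# recurring scans: run on a schedule, alert only when delta > 0. Nothing in
# this commit wires up a scheduler, so this poller is a hedged sketch:
#
# import asyncio
#
# async def poll_saved_searches(client, interval_seconds: int = 3600):
#     while True:
#         searches = (await client.get("/api/searches")).json()["searches"]
#         for s in searches:
#             run = (await client.post(f"/api/searches/{s['id']}/run")).json()
#             if run["delta"] > 0:
#                 # New results since last run: surface to the analyst
#                 print(f"{run['search_name']}: +{run['delta']} new results")
#         await asyncio.sleep(interval_seconds)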

View File

@@ -0,0 +1,184 @@
"""STIX 2.1 export endpoint.
Aggregates hunt data (IOCs, techniques, host profiles, hypotheses) into a
STIX 2.1 Bundle JSON download. No external dependencies required; we
build the JSON directly following the OASIS STIX 2.1 spec.
"""
import json
import uuid
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import Response
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import (
Hunt, Dataset, Hypothesis, TriageResult, HostProfile,
EnrichmentResult, HuntReport,
)
router = APIRouter(prefix="/api/export", tags=["export"])
STIX_SPEC_VERSION = "2.1"
def _stix_id(stype: str) -> str:
return f"{stype}--{uuid.uuid4()}"
def _now_iso() -> str:
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
def _build_identity(hunt_name: str) -> dict:
return {
"type": "identity",
"spec_version": STIX_SPEC_VERSION,
"id": _stix_id("identity"),
"created": _now_iso(),
"modified": _now_iso(),
"name": f"ThreatHunt - {hunt_name}",
"identity_class": "system",
}
def _ioc_to_indicator(ioc_value: str, ioc_type: str, identity_id: str, verdict: str | None = None) -> dict:
pattern_map = {
"ipv4": f"[ipv4-addr:value = '{ioc_value}']",
"ipv6": f"[ipv6-addr:value = '{ioc_value}']",
"domain": f"[domain-name:value = '{ioc_value}']",
"url": f"[url:value = '{ioc_value}']",
"hash_md5": f"[file:hashes.'MD5' = '{ioc_value}']",
"hash_sha1": f"[file:hashes.'SHA-1' = '{ioc_value}']",
"hash_sha256": f"[file:hashes.'SHA-256' = '{ioc_value}']",
"email": f"[email-addr:value = '{ioc_value}']",
}
pattern = pattern_map.get(ioc_type, f"[artifact:payload_bin = '{ioc_value}']")
now = _now_iso()
return {
"type": "indicator",
"spec_version": STIX_SPEC_VERSION,
"id": _stix_id("indicator"),
"created": now,
"modified": now,
"name": f"{ioc_type}: {ioc_value}",
"pattern": pattern,
"pattern_type": "stix",
"valid_from": now,
"created_by_ref": identity_id,
"labels": [verdict or "suspicious"],
}
def _technique_to_attack_pattern(technique_id: str, identity_id: str) -> dict:
now = _now_iso()
return {
"type": "attack-pattern",
"spec_version": STIX_SPEC_VERSION,
"id": _stix_id("attack-pattern"),
"created": now,
"modified": now,
"name": technique_id,
"created_by_ref": identity_id,
"external_references": [
{
"source_name": "mitre-attack",
"external_id": technique_id,
"url": f"https://attack.mitre.org/techniques/{technique_id.replace('.', '/')}/",
}
],
}
def _hypothesis_to_report(hyp, identity_id: str) -> dict:
now = _now_iso()
return {
"type": "report",
"spec_version": STIX_SPEC_VERSION,
"id": _stix_id("report"),
"created": now,
"modified": now,
"name": hyp.title,
"description": hyp.description or "",
"published": now,
"created_by_ref": identity_id,
"labels": ["threat-hunt-hypothesis"],
"object_refs": [],
}
@router.get("/stix/{hunt_id}")
async def export_stix(hunt_id: str, db: AsyncSession = Depends(get_db)):
"""Export hunt data as a STIX 2.1 Bundle JSON file."""
# Fetch hunt
hunt = (await db.execute(select(Hunt).where(Hunt.id == hunt_id))).scalar_one_or_none()
if not hunt:
raise HTTPException(404, "Hunt not found")
identity = _build_identity(hunt.name)
objects: list[dict] = [identity]
seen_techniques: set[str] = set()
seen_iocs: set[str] = set()
# Gather IOCs from enrichment results for hunt's datasets
datasets_q = await db.execute(select(Dataset.id).where(Dataset.hunt_id == hunt_id))
ds_ids = [r[0] for r in datasets_q.all()]
if ds_ids:
enrichments = (await db.execute(
select(EnrichmentResult).where(EnrichmentResult.dataset_id.in_(ds_ids))
)).scalars().all()
for e in enrichments:
key = f"{e.ioc_type}:{e.ioc_value}"
if key not in seen_iocs:
seen_iocs.add(key)
objects.append(_ioc_to_indicator(e.ioc_value, e.ioc_type, identity["id"], e.verdict))
# Gather techniques from triage results
triages = (await db.execute(
select(TriageResult).where(TriageResult.dataset_id.in_(ds_ids))
)).scalars().all()
for t in triages:
for tech in (t.mitre_techniques or []):
tid = tech if isinstance(tech, str) else tech.get("technique_id", str(tech))
if tid not in seen_techniques:
seen_techniques.add(tid)
objects.append(_technique_to_attack_pattern(tid, identity["id"]))
# Gather techniques from host profiles
profiles = (await db.execute(
select(HostProfile).where(HostProfile.hunt_id == hunt_id)
)).scalars().all()
for p in profiles:
for tech in (p.mitre_techniques or []):
tid = tech if isinstance(tech, str) else tech.get("technique_id", str(tech))
if tid not in seen_techniques:
seen_techniques.add(tid)
objects.append(_technique_to_attack_pattern(tid, identity["id"]))
# Gather hypotheses
hypos = (await db.execute(
select(Hypothesis).where(Hypothesis.hunt_id == hunt_id)
)).scalars().all()
for h in hypos:
objects.append(_hypothesis_to_report(h, identity["id"]))
if h.mitre_technique and h.mitre_technique not in seen_techniques:
seen_techniques.add(h.mitre_technique)
objects.append(_technique_to_attack_pattern(h.mitre_technique, identity["id"]))
bundle = {
"type": "bundle",
"id": _stix_id("bundle"),
"objects": objects,
}
filename = f"threathunt-{hunt.name.replace(' ', '_')}-stix.json"
return Response(
content=json.dumps(bundle, indent=2),
media_type="application/json",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
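# -- Editor's note -----------------------------------------------------------
# The bundle uses random UUIDs and emits no relationship objects (the report's
# object_refs stays empty), so downstream tooling should key on names and
# patterns. A stdlib-only sketch of reading the export back, matching the
# no-dependency stance above:
#
# def summarize_bundle(path: str) -> dict[str, int]:
#     with open(path) as f:
#         bundle = json.load(f)
#     by_type: dict[str, int] = {}
#     for obj in bundle["objects"]:
#         by_type[obj["type"]] = by_type.get(obj["type"], 0) + 1
#     # e.g. {"identity": 1, "indicator": 42, "attack-pattern": 9, "report": 3}
#     return by_type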

View File

@@ -0,0 +1,128 @@
"""API routes for forensic timeline visualization."""
import logging
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_db
from app.db.models import Dataset, DatasetRow, Hunt
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/timeline", tags=["timeline"])
def _parse_timestamp(val: str | None) -> str | None:
"""Try to parse a timestamp string, return ISO format or None."""
if not val:
return None
val = str(val).strip()
if not val:
return None
# Try common formats
for fmt in [
"%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ",
"%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S",
"%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S",
"%Y/%m/%d %H:%M:%S", "%m/%d/%Y %H:%M:%S",
]:
try:
return datetime.strptime(val, fmt).isoformat() + "Z"
except ValueError:
continue
return None
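# For example:
#   _parse_timestamp("2026-02-21 10:00:00")  # -> "2026-02-21T10:00:00Z"
#   _parse_timestamp("02/21/2026 10:00:00")  # -> "2026-02-21T10:00:00Z"
#   _parse_timestamp("not a date")           # -> None
# The trailing "Z" is appended unconditionally, so inputs are assumed UTC.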
# Columns likely to contain timestamps
TIME_COLUMNS = {
"timestamp", "time", "datetime", "date", "created", "modified",
"eventtime", "event_time", "start_time", "end_time",
"lastmodified", "last_modified", "created_at", "updated_at",
"mtime", "atime", "ctime", "btime",
"timecreated", "timegenerated", "sourcetime",
}
@router.get("/hunt/{hunt_id}")
async def get_hunt_timeline(
hunt_id: str,
limit: int = 2000,
db: AsyncSession = Depends(get_db),
):
"""Build a timeline of events across all datasets in a hunt."""
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
hunt = result.scalar_one_or_none()
if not hunt:
raise HTTPException(status_code=404, detail="Hunt not found")
result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
datasets = result.scalars().all()
if not datasets:
return {"hunt_id": hunt_id, "events": [], "datasets": []}
events = []
dataset_info = []
for ds in datasets:
artifact_type = getattr(ds, "artifact_type", None) or "Unknown"
dataset_info.append({
"id": ds.id, "name": ds.name, "artifact_type": artifact_type,
"row_count": ds.row_count,
})
# Find time columns for this dataset
schema = ds.column_schema or {}
time_cols = []
for col in (ds.normalized_columns or {}).values():
if col.lower() in TIME_COLUMNS:
time_cols.append(col)
if not time_cols:
for col in schema:
if col.lower() in TIME_COLUMNS or "time" in col.lower() or "date" in col.lower():
time_cols.append(col)
if not time_cols:
continue
# Fetch rows
rows_result = await db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == ds.id)
.order_by(DatasetRow.row_index)
.limit(limit // max(len(datasets), 1))
)
for r in rows_result.scalars().all():
data = r.normalized_data or r.data
ts = None
for tc in time_cols:
ts = _parse_timestamp(data.get(tc))
if ts:
break
if ts:
hostname = data.get("hostname") or data.get("Hostname") or data.get("Fqdn") or ""
process = data.get("process_name") or data.get("Name") or data.get("ProcessName") or ""
summary = data.get("command_line") or data.get("CommandLine") or data.get("Details") or ""
events.append({
"timestamp": ts,
"dataset_id": ds.id,
"dataset_name": ds.name,
"artifact_type": artifact_type,
"row_index": r.row_index,
"hostname": str(hostname)[:128],
"process": str(process)[:128],
"summary": str(summary)[:256],
"data": {k: str(v)[:100] for k, v in list(data.items())[:8]},
})
# Sort by timestamp
events.sort(key=lambda e: e["timestamp"])
return {
"hunt_id": hunt_id,
"hunt_name": hunt.name,
"event_count": len(events),
"datasets": dataset_info,
"events": events[:limit],
}
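# -- Editor's note -----------------------------------------------------------
# Events come back flat and ISO-sorted, so a zoomable chart can bucket them by
# truncating the timestamp prefix. A hedged sketch of hourly bucketing (helper
# name is illustrative):
#
# from collections import Counter
#
# def bucket_events_by_hour(events: list[dict]) -> dict[str, int]:
#     # "2026-02-21T10:23:45Z"[:13] -> "2026-02-21T10", one bucket per hour
#     return dict(Counter(e["timestamp"][:13] for e in events))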

View File

@@ -1,4 +1,4 @@
"""Application configuration single source of truth for all settings.
"""Application configuration - single source of truth for all settings.
Loads from environment variables with sensible defaults for local dev.
"""
@@ -13,12 +13,12 @@ from pydantic import Field
class AppConfig(BaseSettings):
"""Central configuration for the entire ThreatHunt application."""
# ── General ────────────────────────────────────────────────────────
# -- General --------------------------------------------------------
APP_NAME: str = "ThreatHunt"
APP_VERSION: str = "0.3.0"
DEBUG: bool = Field(default=False, description="Enable debug mode")
# ── Database ───────────────────────────────────────────────────────
# -- Database -------------------------------------------------------
DATABASE_URL: str = Field(
default="sqlite+aiosqlite:///./threathunt.db",
description="Async SQLAlchemy database URL. "
@@ -26,17 +26,17 @@ class AppConfig(BaseSettings):
"postgresql+asyncpg://user:pass@host/db for production.",
)
# ── CORS ───────────────────────────────────────────────────────────
# -- CORS -----------------------------------------------------------
ALLOWED_ORIGINS: str = Field(
default="http://localhost:3000,http://localhost:8000",
description="Comma-separated list of allowed CORS origins",
)
# ── File uploads ───────────────────────────────────────────────────
# -- File uploads ---------------------------------------------------
MAX_UPLOAD_SIZE_MB: int = Field(default=500, description="Max CSV upload in MB")
UPLOAD_DIR: str = Field(default="./uploads", description="Directory for uploaded files")
# ── LLM Cluster Wile & Roadrunner ────────────────────────────────
# -- LLM Cluster - Wile & Roadrunner --------------------------------
OPENWEBUI_URL: str = Field(
default="https://ai.guapo613.beer",
description="Open WebUI cluster endpoint (OpenAI-compatible API)",
@@ -58,7 +58,7 @@ class AppConfig(BaseSettings):
default=11434, description="Ollama port on Roadrunner"
)
# ── LLM Routing defaults ──────────────────────────────────────────
# -- LLM Routing defaults ------------------------------------------
DEFAULT_FAST_MODEL: str = Field(
default="llama3.1:latest",
description="Default model for quick chat / simple queries",
@@ -80,18 +80,18 @@ class AppConfig(BaseSettings):
description="Default embedding model",
)
# ── Agent behaviour ───────────────────────────────────────────────
# -- Agent behaviour ------------------------------------------------
AGENT_MAX_TOKENS: int = Field(default=2048, description="Max tokens per agent response")
AGENT_TEMPERATURE: float = Field(default=0.3, description="LLM temperature for guidance")
AGENT_HISTORY_LENGTH: int = Field(default=10, description="Messages to keep in context")
FILTER_SENSITIVE_DATA: bool = Field(default=True, description="Redact sensitive patterns")
# ── Enrichment API keys ───────────────────────────────────────────
# -- Enrichment API keys --------------------------------------------
VIRUSTOTAL_API_KEY: str = Field(default="", description="VirusTotal API key")
ABUSEIPDB_API_KEY: str = Field(default="", description="AbuseIPDB API key")
SHODAN_API_KEY: str = Field(default="", description="Shodan API key")
# ── Auth ──────────────────────────────────────────────────────────
# -- Auth -----------------------------------------------------------
JWT_SECRET: str = Field(
default="CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET",
description="Secret for JWT signing",
@@ -99,6 +99,73 @@ class AppConfig(BaseSettings):
JWT_ACCESS_TOKEN_MINUTES: int = Field(default=60, description="Access token lifetime")
JWT_REFRESH_TOKEN_DAYS: int = Field(default=7, description="Refresh token lifetime")
# -- Triage settings ------------------------------------------------
TRIAGE_BATCH_SIZE: int = Field(default=25, description="Rows per triage LLM batch")
TRIAGE_MAX_SUSPICIOUS_ROWS: int = Field(
default=200, description="Stop triage after this many suspicious rows"
)
TRIAGE_ESCALATION_THRESHOLD: float = Field(
default=5.0, description="Risk score threshold for escalation counting"
)
# -- Host profiler settings -----------------------------------------
HOST_PROFILE_CONCURRENCY: int = Field(
default=3, description="Max concurrent host profile LLM calls"
)
# -- Scanner settings -----------------------------------------------
SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch")
SCANNER_MAX_ROWS_PER_SCAN: int = Field(
default=120000,
description="Global row budget for a single AUP scan request (0 = unlimited)",
)
# -- Job queue settings ----------------------------------------------
JOB_QUEUE_MAX_BACKLOG: int = Field(
default=2000, description="Soft cap for queued background jobs"
)
JOB_QUEUE_RETAIN_COMPLETED: int = Field(
default=3000, description="Maximum completed/failed jobs to retain in memory"
)
JOB_QUEUE_CLEANUP_INTERVAL_SECONDS: int = Field(
default=60, description="How often to run in-memory job cleanup"
)
JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(
default=3600, description="Age threshold for in-memory completed job cleanup"
)
# -- Startup throttling ------------------------------------------------
STARTUP_WARMUP_MAX_HUNTS: int = Field(
default=5, description="Max hunts to warm inventory cache for at startup"
)
STARTUP_REPROCESS_MAX_DATASETS: int = Field(
default=25, description="Max unprocessed datasets to enqueue at startup"
)
STARTUP_RECONCILE_STALE_TASKS: bool = Field(
default=True,
description="Mark stale queued/running processing tasks as failed on startup",
)
# -- Network API scale guards -----------------------------------------
NETWORK_SUBGRAPH_MAX_HOSTS: int = Field(
default=400, description="Hard cap for hosts returned by network subgraph endpoint"
)
NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
)
NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
default=5000,
description="Row budget per dataset when building host inventory (0 = unlimited)",
)
NETWORK_INVENTORY_MAX_TOTAL_ROWS: int = Field(
default=120000,
description="Global row budget across all datasets for host inventory build (0 = unlimited)",
)
NETWORK_INVENTORY_MAX_CONNECTIONS: int = Field(
default=120000,
description="Max unique connection tuples retained during host inventory build",
)
model_config = {"env_prefix": "TH_", "env_file": ".env", "extra": "ignore"}
@property
@@ -119,3 +186,4 @@ class AppConfig(BaseSettings):
settings = AppConfig()
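# -- Editor's note -----------------------------------------------------------
# With env_prefix "TH_", any field above can be overridden per deployment, e.g.:
#
#   TH_NETWORK_INVENTORY_MAX_TOTAL_ROWS=0   # lift the global row budget (0 = unlimited)
#   TH_JOB_QUEUE_MAX_BACKLOG=500            # tighten queue backpressure
#   TH_STARTUP_WARMUP_MAX_HUNTS=0           # hunt_ids[:0] is empty, so 0 skips warm-up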

View File

@@ -21,9 +21,14 @@ _engine_kwargs: dict = dict(
)
if _is_sqlite:
_engine_kwargs["connect_args"] = {"timeout": 30}
_engine_kwargs["pool_size"] = 1
_engine_kwargs["max_overflow"] = 0
_engine_kwargs["connect_args"] = {"timeout": 60, "check_same_thread": False}
# NullPool: each session gets its own connection.
# Combined with WAL mode, this allows concurrent reads while a write is in progress.
from sqlalchemy.pool import NullPool
_engine_kwargs["poolclass"] = NullPool
else:
_engine_kwargs["pool_size"] = 5
_engine_kwargs["max_overflow"] = 10
engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)
@@ -34,7 +39,7 @@ def _set_sqlite_pragmas(dbapi_conn, connection_record):
if _is_sqlite:
cursor = dbapi_conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA busy_timeout=5000")
cursor.execute("PRAGMA busy_timeout=30000")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.close()
@@ -46,6 +51,10 @@ async_session_factory = async_sessionmaker(
)
# Alias expected by other modules
async_session = async_session_factory
class Base(DeclarativeBase):
"""Base class for all ORM models."""
pass
@@ -71,5 +80,5 @@ async def init_db() -> None:
async def dispose_db() -> None:
"""Dispose of the engine connection pool."""
await engine.dispose()
"""Dispose of the engine on shutdown."""
await engine.dispose()

View File

@@ -1,4 +1,4 @@
"""SQLAlchemy ORM models for ThreatHunt.
"""SQLAlchemy ORM models for ThreatHunt.
All persistent entities: datasets, hunts, conversations, annotations,
hypotheses, enrichment results, users, and AI analysis tables.
@@ -43,6 +43,7 @@ class User(Base):
hashed_password: Mapped[str] = mapped_column(String(256), nullable=False)
role: Mapped[str] = mapped_column(String(16), default="analyst")
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
display_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
hunts: Mapped[list["Hunt"]] = relationship(back_populates="owner", lazy="selectin")
@@ -399,4 +400,108 @@ class AnomalyResult(Base):
cluster_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
is_outlier: Mapped[bool] = mapped_column(Boolean, default=False)
explanation: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
# -- Persistent Processing Tasks (Phase 2) ---
class ProcessingTask(Base):
__tablename__ = "processing_tasks"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
hunt_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True, index=True
)
dataset_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True, index=True
)
job_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
status: Mapped[str] = mapped_column(String(20), default="queued", index=True)
progress: Mapped[float] = mapped_column(Float, default=0.0)
message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
__table_args__ = (
Index("ix_processing_tasks_hunt_stage", "hunt_id", "stage"),
Index("ix_processing_tasks_dataset_stage", "dataset_id", "stage"),
)
# -- Playbook / Investigation Templates (Feature 3) ---
class Playbook(Base):
__tablename__ = "playbooks"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_by: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("users.id"), nullable=True
)
is_template: Mapped[bool] = mapped_column(Boolean, default=False)
hunt_id: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("hunts.id"), nullable=True
)
status: Mapped[str] = mapped_column(String(20), default="active")
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
)
steps: Mapped[list["PlaybookStep"]] = relationship(
back_populates="playbook", lazy="selectin", cascade="all, delete-orphan",
order_by="PlaybookStep.order_index",
)
class PlaybookStep(Base):
__tablename__ = "playbook_steps"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
playbook_id: Mapped[str] = mapped_column(
String(32), ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False
)
order_index: Mapped[int] = mapped_column(Integer, nullable=False)
title: Mapped[str] = mapped_column(String(256), nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
step_type: Mapped[str] = mapped_column(String(32), default="manual")
target_route: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
is_completed: Mapped[bool] = mapped_column(Boolean, default=False)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
playbook: Mapped["Playbook"] = relationship(back_populates="steps")
__table_args__ = (
Index("ix_playbook_steps_playbook", "playbook_id"),
)
# -- Saved Searches (Feature 5) ---
class SavedSearch(Base):
__tablename__ = "saved_searches"
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
search_type: Mapped[str] = mapped_column(String(32), nullable=False)
query_params: Mapped[dict] = mapped_column(JSON, nullable=False)
threshold: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
created_by: Mapped[Optional[str]] = mapped_column(
String(32), ForeignKey("users.id"), nullable=True
)
last_run_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
last_result_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
__table_args__ = (
Index("ix_saved_searches_type", "search_type"),
)

View File

@@ -25,6 +25,11 @@ from app.api.routes.auth import router as auth_router
from app.api.routes.keywords import router as keywords_router
from app.api.routes.analysis import router as analysis_router
from app.api.routes.network import router as network_router
from app.api.routes.mitre import router as mitre_router
from app.api.routes.timeline import router as timeline_router
from app.api.routes.playbooks import router as playbooks_router
from app.api.routes.saved_searches import router as searches_router
from app.api.routes.stix_export import router as stix_router
logger = logging.getLogger(__name__)
@@ -47,13 +52,80 @@ async def lifespan(app: FastAPI):
await seed_defaults(seed_db)
logger.info("AUP keyword defaults checked")
# Start job queue (Phase 10)
from app.services.job_queue import job_queue, register_all_handlers
# Start job queue
from app.services.job_queue import (
job_queue,
register_all_handlers,
reconcile_stale_processing_tasks,
JobType,
)
if settings.STARTUP_RECONCILE_STALE_TASKS:
reconciled = await reconcile_stale_processing_tasks()
if reconciled:
logger.info("Startup reconciliation marked %d stale tasks", reconciled)
register_all_handlers()
await job_queue.start()
logger.info("Job queue started (%d workers)", job_queue._max_workers)
# Start load balancer health loop (Phase 10)
# Pre-warm host inventory cache for existing hunts
from app.services.host_inventory import inventory_cache
async with async_session_factory() as warm_db:
from sqlalchemy import select, func
from app.db.models import Hunt, Dataset
stmt = (
select(Hunt.id)
.join(Dataset, Dataset.hunt_id == Hunt.id)
.group_by(Hunt.id)
.having(func.count(Dataset.id) > 0)
)
result = await warm_db.execute(stmt)
hunt_ids = [row[0] for row in result.all()]
warm_hunts = hunt_ids[: settings.STARTUP_WARMUP_MAX_HUNTS]
for hid in warm_hunts:
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)
if warm_hunts:
logger.info(f"Queued host inventory warm-up for {len(warm_hunts)} hunts (total hunts with data: {len(hunt_ids)})")
# Check which datasets still need processing
# (no anomaly results = never fully processed)
async with async_session_factory() as reprocess_db:
from sqlalchemy import select, exists
from app.db.models import Dataset, AnomalyResult
# Find datasets that have zero anomaly results (pipeline never ran or failed)
has_anomaly = (
select(AnomalyResult.id)
.where(AnomalyResult.dataset_id == Dataset.id)
.limit(1)
.correlate(Dataset)
.exists()
)
stmt = select(Dataset.id).where(~has_anomaly)
result = await reprocess_db.execute(stmt)
unprocessed_ids = [row[0] for row in result.all()]
if unprocessed_ids:
to_reprocess = unprocessed_ids[: settings.STARTUP_REPROCESS_MAX_DATASETS]
for ds_id in to_reprocess:
job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)
job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)
logger.info(f"Queued processing pipeline for {len(to_reprocess)} datasets at startup (unprocessed total: {len(unprocessed_ids)})")
async with async_session_factory() as update_db:
from sqlalchemy import update
from app.db.models import Dataset
await update_db.execute(
update(Dataset)
.where(Dataset.id.in_(to_reprocess))
.values(processing_status="processing")
)
await update_db.commit()
else:
logger.info("All datasets already processed - skipping startup pipeline")
# Start load balancer health loop
from app.services.load_balancer import lb
await lb.start_health_loop(interval=30.0)
logger.info("Load balancer health loop started")
@@ -61,12 +133,10 @@ async def lifespan(app: FastAPI):
yield
logger.info("Shutting down ...")
# Stop job queue
from app.services.job_queue import job_queue as jq
await jq.stop()
logger.info("Job queue stopped")
# Stop load balancer
from app.services.load_balancer import lb as _lb
await _lb.stop_health_loop()
logger.info("Load balancer stopped")
@@ -106,6 +176,11 @@ app.include_router(reports_router)
app.include_router(keywords_router)
app.include_router(analysis_router)
app.include_router(network_router)
app.include_router(mitre_router)
app.include_router(timeline_router)
app.include_router(playbooks_router)
app.include_router(searches_router)
app.include_router(stix_router)
@app.get("/", tags=["health"])
@@ -120,4 +195,4 @@ async def root():
"roadrunner": settings.roadrunner_url,
"openwebui": settings.OPENWEBUI_URL,
},
}

View File

@@ -13,6 +13,7 @@ from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import Dataset, DatasetRow
from app.config import settings
logger = logging.getLogger(__name__)
@@ -79,6 +80,55 @@ def _extract_username(raw: str) -> str:
return name or ''
# In-memory host inventory cache
# Pre-computed results stored per hunt_id, built in background after upload.
import time as _time
class _InventoryCache:
"""Simple in-memory cache for pre-computed host inventories."""
def __init__(self):
self._data: dict[str, dict] = {} # hunt_id -> result dict
self._timestamps: dict[str, float] = {} # hunt_id -> epoch
self._building: set[str] = set() # hunt_ids currently being built
def get(self, hunt_id: str) -> dict | None:
"""Return cached result if present. Never expires; only invalidated on new upload."""
return self._data.get(hunt_id)
def put(self, hunt_id: str, result: dict):
self._data[hunt_id] = result
self._timestamps[hunt_id] = _time.time()
self._building.discard(hunt_id)
logger.info(f"Cached host inventory for hunt {hunt_id} "
f"({result['stats']['total_hosts']} hosts)")
def invalidate(self, hunt_id: str):
self._data.pop(hunt_id, None)
self._timestamps.pop(hunt_id, None)
def is_building(self, hunt_id: str) -> bool:
return hunt_id in self._building
def set_building(self, hunt_id: str):
self._building.add(hunt_id)
def clear_building(self, hunt_id: str):
self._building.discard(hunt_id)
def status(self, hunt_id: str) -> str:
if hunt_id in self._building:
return "building"
if hunt_id in self._data:
return "ready"
return "none"
inventory_cache = _InventoryCache()
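# -- Editor's note -----------------------------------------------------------
# The lifecycle implied by the API layer: a background handler marks the hunt
# as building, builds, then publishes. A hedged sketch of such a handler
# (handler name and session import are assumptions):
#
# async def _handle_host_inventory(job):
#     from app.db import async_session_factory
#     hunt_id = job.params["hunt_id"]
#     inventory_cache.set_building(hunt_id)
#     try:
#         async with async_session_factory() as db:
#             result = await build_host_inventory(hunt_id, db)
#         inventory_cache.put(hunt_id, result)  # put() also clears the building flag
#     except Exception:
#         inventory_cache.clear_building(hunt_id)  # never leave a stuck "building"
#         raise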
def _infer_os(fqdn: str) -> str:
u = fqdn.upper()
if 'W10-' in u or 'WIN10' in u:
@@ -151,33 +201,61 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
}}
hosts: dict[str, dict] = {} # fqdn -> host record
ip_to_host: dict[str, str] = {}  # local-ip -> fqdn
connections: dict[tuple, int] = defaultdict(int)
total_rows = 0
ds_with_hosts = 0
sampled_dataset_count = 0
total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
global_budget_reached = False
dropped_connections = 0
for ds in all_datasets:
if total_row_budget and total_rows >= total_row_budget:
global_budget_reached = True
break
cols = _identify_columns(ds)
if not cols['fqdn'] and not cols['host_id']:
continue
ds_with_hosts += 1
batch_size = 5000
offset = 0
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
rows_scanned_this_dataset = 0
sampled_dataset = False
last_row_index = -1
while True:
if total_row_budget and total_rows >= total_row_budget:
sampled_dataset = True
global_budget_reached = True
break
rr = await db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == ds.id)
.where(DatasetRow.row_index > last_row_index)
.order_by(DatasetRow.row_index)
.offset(offset).limit(batch_size)
.limit(batch_size)
)
rows = rr.scalars().all()
if not rows:
break
for ro in rows:
if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
sampled_dataset = True
break
if total_row_budget and total_rows >= total_row_budget:
sampled_dataset = True
global_budget_reached = True
break
data = ro.data or {}
total_rows += 1
rows_scanned_this_dataset += 1
fqdn = ''
for c in cols['fqdn']:
@@ -239,12 +317,33 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
rport = _clean(data.get(pc))
if rport:
break
connections[(host_key, rip, rport)] += 1
conn_key = (host_key, rip, rport)
if max_connections and len(connections) >= max_connections and conn_key not in connections:
dropped_connections += 1
continue
connections[conn_key] += 1
offset += batch_size
if sampled_dataset:
sampled_dataset_count += 1
logger.info(
"Host inventory sampling for dataset %s (%d rows scanned)",
ds.id,
rows_scanned_this_dataset,
)
break
last_row_index = rows[-1].row_index
if len(rows) < batch_size:
break
if global_budget_reached:
logger.info(
"Host inventory global row budget reached for hunt %s at %d rows",
hunt_id,
total_rows,
)
break
# Post-process hosts
for h in hosts.values():
if not h['os'] and h['fqdn']:
@@ -286,5 +385,12 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
"total_rows_scanned": total_rows,
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
"hosts_with_users": sum(1 for h in host_list if h['users']),
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
"row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
"connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
"sampled_datasets": sampled_dataset_count,
"global_budget_reached": global_budget_reached,
"dropped_connections": dropped_connections,
},
}

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import re
import json
import logging
@@ -18,6 +19,9 @@ logger = logging.getLogger(__name__)
HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
WILE_URL = f"{settings.wile_url}/api/generate"
# Velociraptor client IDs (C.hex) are not real hostnames
CLIENTID_RE = re.compile(r"^C\.[0-9a-fA-F]{8,}$")
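# For example, "C.0a1b2c3d" matches (a client ID, so it is skipped below),
# while "WS-0A1B2C3D" and "C.xyz" do not match and are kept as real hostnames.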
async def _get_triage_summary(db, dataset_id: str) -> str:
result = await db.execute(
@@ -154,7 +158,7 @@ async def profile_host(
logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level)
except Exception as e:
logger.error("Failed to profile host %s: %s", hostname, e)
logger.error("Failed to profile host %s: %r", hostname, e)
profile = HostProfile(
hunt_id=hunt_id, hostname=hostname, fqdn=fqdn,
risk_score=0.0, risk_level="unknown",
@@ -185,6 +189,13 @@ async def profile_all_hosts(hunt_id: str) -> None:
if h not in hostnames:
hostnames[h] = data.get("fqdn") or data.get("Fqdn")
# Filter out Velociraptor client IDs - not real hostnames
real_hosts = {h: f for h, f in hostnames.items() if not CLIENTID_RE.match(h)}
skipped = len(hostnames) - len(real_hosts)
if skipped:
logger.info("Skipped %d Velociraptor client IDs", skipped)
hostnames = real_hosts
logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id)
semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY)

View File

@@ -1,8 +1,8 @@
"""Async job queue for background AI tasks.
Manages triage, profiling, report generation, anomaly detection,
keyword scanning, IOC extraction, and data queries as trackable
jobs with status, progress, and cancellation support.
"""
from __future__ import annotations
@@ -15,6 +15,8 @@ from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Coroutine, Optional
from app.config import settings
logger = logging.getLogger(__name__)
@@ -32,6 +34,18 @@ class JobType(str, Enum):
REPORT = "report"
ANOMALY = "anomaly"
QUERY = "query"
HOST_INVENTORY = "host_inventory"
KEYWORD_SCAN = "keyword_scan"
IOC_EXTRACT = "ioc_extract"
# Job types that form the automatic upload pipeline
PIPELINE_JOB_TYPES = frozenset({
JobType.TRIAGE,
JobType.ANOMALY,
JobType.KEYWORD_SCAN,
JobType.IOC_EXTRACT,
})
@dataclass
@@ -82,11 +96,7 @@ class Job:
class JobQueue:
"""In-memory async job queue with concurrency control.
Jobs are tracked by ID and can be listed, polled, or cancelled.
A configurable number of workers process jobs from the queue.
"""
"""In-memory async job queue with concurrency control."""
def __init__(self, max_workers: int = 3):
self._jobs: dict[str, Job] = {}
@@ -95,47 +105,56 @@ class JobQueue:
self._workers: list[asyncio.Task] = []
self._handlers: dict[JobType, Callable] = {}
self._started = False
self._completion_callbacks: list[Callable[[Job], Coroutine]] = []
self._cleanup_task: asyncio.Task | None = None
def register_handler(self, job_type: JobType, handler: Callable[[Job], Coroutine]):
self._handlers[job_type] = handler
logger.info(f"Registered handler for {job_type.value}")
def on_completion(self, callback: Callable[[Job], Coroutine]):
"""Register a callback invoked after any job completes or fails."""
self._completion_callbacks.append(callback)
async def start(self):
"""Start worker tasks."""
if self._started:
return
self._started = True
for i in range(self._max_workers):
task = asyncio.create_task(self._worker(i))
self._workers.append(task)
if not self._cleanup_task or self._cleanup_task.done():
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info(f"Job queue started with {self._max_workers} workers")
async def stop(self):
"""Stop all workers."""
self._started = False
for w in self._workers:
w.cancel()
await asyncio.gather(*self._workers, return_exceptions=True)
self._workers.clear()
if self._cleanup_task:
self._cleanup_task.cancel()
await asyncio.gather(self._cleanup_task, return_exceptions=True)
self._cleanup_task = None
logger.info("Job queue stopped")
def submit(self, job_type: JobType, **params) -> Job:
"""Submit a new job. Returns the Job object immediately."""
# Soft backpressure: prefer dedupe over queue amplification
dedupe_job = self._find_active_duplicate(job_type, params)
if dedupe_job is not None:
logger.info(
f"Job deduped: reusing {dedupe_job.id} ({job_type.value}) params={params}"
)
return dedupe_job
if self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG:
logger.warning(
"Job queue backlog high (%d >= %d). Accepting job but system may be degraded.",
self._queue.qsize(), settings.JOB_QUEUE_MAX_BACKLOG,
)
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
self._jobs[job.id] = job
self._queue.put_nowait(job.id)
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
@@ -144,6 +163,22 @@ class JobQueue:
def get_job(self, job_id: str) -> Job | None:
return self._jobs.get(job_id)
def _find_active_duplicate(self, job_type: JobType, params: dict) -> Job | None:
"""Return queued/running job with same key workload to prevent duplicate storms."""
key_fields = ["dataset_id", "hunt_id", "hostname", "question", "mode"]
sig = tuple((k, params.get(k)) for k in key_fields if params.get(k) is not None)
if not sig:
return None
for j in self._jobs.values():
if j.job_type != job_type:
continue
if j.status not in (JobStatus.QUEUED, JobStatus.RUNNING):
continue
other_sig = tuple((k, j.params.get(k)) for k in key_fields if j.params.get(k) is not None)
if sig == other_sig:
return j
return None
def cancel_job(self, job_id: str) -> bool:
job = self._jobs.get(job_id)
if not job:
@@ -153,13 +188,7 @@ class JobQueue:
job.cancel()
return True
def list_jobs(self, status=None, job_type=None, limit=50) -> list[dict]:
jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
if status:
jobs = [j for j in jobs if j.status == status]
@@ -168,7 +197,6 @@ class JobQueue:
return [j.to_dict() for j in jobs[:limit]]
def get_stats(self) -> dict:
"""Get queue statistics."""
by_status = {}
for j in self._jobs.values():
by_status[j.status.value] = by_status.get(j.status.value, 0) + 1
@@ -177,26 +205,58 @@ class JobQueue:
"queued": self._queue.qsize(),
"by_status": by_status,
"workers": self._max_workers,
"active_workers": sum(
1 for j in self._jobs.values() if j.status == JobStatus.RUNNING
),
"active_workers": sum(1 for j in self._jobs.values() if j.status == JobStatus.RUNNING),
}
def is_backlogged(self) -> bool:
return self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG
def can_accept(self, reserve: int = 0) -> bool:
return (self._queue.qsize() + max(0, reserve)) < settings.JOB_QUEUE_MAX_BACKLOG
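# Illustrative use (hypothetical caller, not part of this module): a route can
# check capacity before fanning out several jobs, e.g.
#   if not job_queue.can_accept(reserve=4):
#       return JSONResponse(status_code=429, content={"detail": "queue backlogged"})
# as exercised by the backpressure tests further down.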
def cleanup(self, max_age_seconds: float = 3600):
"""Remove old completed/failed/cancelled jobs."""
now = time.time()
terminal_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
to_remove = [
jid for jid, j in self._jobs.items()
if j.status in terminal_states and (now - j.created_at) > max_age_seconds
]
# Also cap retained terminal jobs to avoid unbounded memory growth
terminal_jobs = sorted(
[j for j in self._jobs.values() if j.status in terminal_states],
key=lambda j: j.created_at,
reverse=True,
)
overflow = terminal_jobs[settings.JOB_QUEUE_RETAIN_COMPLETED :]
to_remove.extend([j.id for j in overflow])
removed = 0
for jid in set(to_remove):
if jid in self._jobs:
del self._jobs[jid]
removed += 1
if removed:
logger.info(f"Cleaned up {removed} old jobs")
async def _cleanup_loop(self):
interval = max(10, settings.JOB_QUEUE_CLEANUP_INTERVAL_SECONDS)
while self._started:
try:
self.cleanup(max_age_seconds=settings.JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS)
except Exception as e:
logger.warning(f"Job queue cleanup loop error: {e}")
await asyncio.sleep(interval)
def find_pipeline_jobs(self, dataset_id: str) -> list[Job]:
"""Find all pipeline jobs for a given dataset_id."""
return [
j for j in self._jobs.values()
if j.job_type in PIPELINE_JOB_TYPES
and j.params.get("dataset_id") == dataset_id
]
async def _worker(self, worker_id: int):
"""Worker loop: pull jobs from queue and execute handlers."""
logger.info(f"Worker {worker_id} started")
while self._started:
try:
@@ -220,7 +280,10 @@ class JobQueue:
job.status = JobStatus.RUNNING
job.started_at = time.time()
if job.progress <= 0:
job.progress = 5.0
job.message = "Running..."
await _sync_processing_task(job)
logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
try:
@@ -231,38 +294,111 @@ class JobQueue:
job.result = result
job.message = "Completed"
job.completed_at = time.time()
logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
except Exception as e:
if not job.is_cancelled:
job.status = JobStatus.FAILED
job.error = str(e)
job.message = f"Failed: {e}"
job.completed_at = time.time()
logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
if job.is_cancelled and not job.completed_at:
job.completed_at = time.time()
await _sync_processing_task(job)
# Fire completion callbacks
for cb in self._completion_callbacks:
try:
await cb(job)
except Exception as cb_err:
logger.error(f"Completion callback error: {cb_err}", exc_info=True)
async def _sync_processing_task(job: Job):
"""Persist latest job state into processing_tasks (if linked by job_id)."""
from datetime import datetime, timezone
from sqlalchemy import update
try:
from app.db import async_session_factory
from app.db.models import ProcessingTask
values = {
"status": job.status.value,
"progress": float(job.progress),
"message": job.message,
"error": job.error,
}
if job.started_at:
values["started_at"] = datetime.fromtimestamp(job.started_at, tz=timezone.utc)
if job.completed_at:
values["completed_at"] = datetime.fromtimestamp(job.completed_at, tz=timezone.utc)
async with async_session_factory() as db:
await db.execute(
update(ProcessingTask)
.where(ProcessingTask.job_id == job.id)
.values(**values)
)
await db.commit()
except Exception as e:
logger.warning(f"Failed to sync processing task for job {job.id}: {e}")
# -- Singleton + job handlers --
job_queue = JobQueue(max_workers=5)
async def _handle_triage(job: Job):
"""Triage handler."""
"""Triage handler - chains HOST_PROFILE after completion."""
from app.services.triage import triage_dataset
dataset_id = job.params.get("dataset_id")
job.message = f"Triaging dataset {dataset_id}"
await triage_dataset(dataset_id)
# Chain: trigger host profiling now that triage results exist
from app.db import async_session_factory
from app.db.models import Dataset
from sqlalchemy import select
try:
async with async_session_factory() as db:
ds = await db.execute(select(Dataset.hunt_id).where(Dataset.id == dataset_id))
row = ds.first()
hunt_id = row[0] if row else None
if hunt_id:
hp_job = job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
try:
from sqlalchemy import select
from app.db.models import ProcessingTask
async with async_session_factory() as db:
existing = await db.execute(
select(ProcessingTask.id).where(ProcessingTask.job_id == hp_job.id)
)
if existing.first() is None:
db.add(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset_id,
job_id=hp_job.id,
stage="host_profile",
status="queued",
progress=0.0,
message="Queued",
))
await db.commit()
except Exception as persist_err:
logger.warning(f"Failed to persist chained HOST_PROFILE task: {persist_err}")
logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
except Exception as e:
logger.warning(f"Failed to chain host profile after triage: {e}")
return {"dataset_id": dataset_id}
async def _handle_host_profile(job: Job):
"""Host profiling handler."""
from app.services.host_profiler import profile_all_hosts, profile_host
hunt_id = job.params.get("hunt_id")
hostname = job.params.get("hostname")
@@ -277,7 +413,6 @@ async def _handle_host_profile(job: Job):
async def _handle_report(job: Job):
"""Report generation handler."""
from app.services.report_generator import generate_report
hunt_id = job.params.get("hunt_id")
job.message = f"Generating report for hunt {hunt_id}"
@@ -286,7 +421,6 @@ async def _handle_report(job: Job):
async def _handle_anomaly(job: Job):
"""Anomaly detection handler."""
from app.services.anomaly_detector import detect_anomalies
dataset_id = job.params.get("dataset_id")
k = job.params.get("k", 3)
@@ -297,7 +431,6 @@ async def _handle_anomaly(job: Job):
async def _handle_query(job: Job):
"""Data query handler (non-streaming)."""
from app.services.data_query import query_dataset
dataset_id = job.params.get("dataset_id")
question = job.params.get("question", "")
@@ -307,10 +440,152 @@ async def _handle_query(job: Job):
return {"answer": answer}
async def _handle_host_inventory(job: Job):
from app.db import async_session_factory
from app.services.host_inventory import build_host_inventory, inventory_cache
hunt_id = job.params.get("hunt_id")
if not hunt_id:
raise ValueError("hunt_id required")
inventory_cache.set_building(hunt_id)
job.message = f"Building host inventory for hunt {hunt_id}"
try:
async with async_session_factory() as db:
result = await build_host_inventory(hunt_id, db)
inventory_cache.put(hunt_id, result)
job.message = f"Built inventory: {result['stats']['total_hosts']} hosts"
return {"hunt_id": hunt_id, "total_hosts": result["stats"]["total_hosts"]}
except Exception:
inventory_cache.clear_building(hunt_id)
raise
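# Cache lifecycle (descriptive): set_building() marks the hunt "building" for
# pollers; put() flips it to "ready"; clear_building() on failure lets the
# status fall back to "none" (see the network inventory tests below).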
async def _handle_keyword_scan(job: Job):
"""AUP keyword scan handler."""
from app.db import async_session_factory
from app.services.scanner import KeywordScanner, keyword_scan_cache
dataset_id = job.params.get("dataset_id")
job.message = f"Running AUP keyword scan on dataset {dataset_id}"
async with async_session_factory() as db:
scanner = KeywordScanner(db)
result = await scanner.scan(dataset_ids=[dataset_id])
# Cache dataset-only result for fast API reuse
if dataset_id:
keyword_scan_cache.put(dataset_id, result)
hits = result.get("total_hits", 0)
job.message = f"Keyword scan complete: {hits} hits"
logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
async def _handle_ioc_extract(job: Job):
"""IOC extraction handler."""
from app.db import async_session_factory
from app.services.ioc_extractor import extract_iocs_from_dataset
dataset_id = job.params.get("dataset_id")
job.message = f"Extracting IOCs from dataset {dataset_id}"
async with async_session_factory() as db:
iocs = await extract_iocs_from_dataset(dataset_id, db)
total = sum(len(v) for v in iocs.values())
job.message = f"IOC extraction complete: {total} IOCs found"
logger.info(f"IOC extract for {dataset_id}: {total} IOCs")
return {"dataset_id": dataset_id, "total_iocs": total, "breakdown": {k: len(v) for k, v in iocs.items()}}
async def _on_pipeline_job_complete(job: Job):
"""Update Dataset.processing_status when all pipeline jobs finish."""
if job.job_type not in PIPELINE_JOB_TYPES:
return
dataset_id = job.params.get("dataset_id")
if not dataset_id:
return
pipeline_jobs = job_queue.find_pipeline_jobs(dataset_id)
if not pipeline_jobs:
return
all_done = all(
j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
for j in pipeline_jobs
)
if not all_done:
return
any_failed = any(j.status == JobStatus.FAILED for j in pipeline_jobs)
new_status = "completed_with_errors" if any_failed else "completed"
try:
from app.db import async_session_factory
from app.db.models import Dataset
from sqlalchemy import update
async with async_session_factory() as db:
await db.execute(
update(Dataset)
.where(Dataset.id == dataset_id)
.values(processing_status=new_status)
)
await db.commit()
logger.info(f"Dataset {dataset_id} processing_status -> {new_status}")
except Exception as e:
logger.error(f"Failed to update processing_status for {dataset_id}: {e}")
async def reconcile_stale_processing_tasks() -> int:
"""Mark queued/running processing tasks from prior runs as failed."""
from datetime import datetime, timezone
from sqlalchemy import update
try:
from app.db import async_session_factory
from app.db.models import ProcessingTask
now = datetime.now(timezone.utc)
async with async_session_factory() as db:
result = await db.execute(
update(ProcessingTask)
.where(ProcessingTask.status.in_(["queued", "running"]))
.values(
status="failed",
error="Recovered after service restart before task completion",
message="Recovered stale task after restart",
completed_at=now,
)
)
await db.commit()
updated = int(result.rowcount or 0)
if updated:
logger.warning(
"Reconciled %d stale processing tasks (queued/running -> failed) during startup",
updated,
)
return updated
except Exception as e:
logger.warning(f"Failed to reconcile stale processing tasks: {e}")
return 0
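# Startup wiring sketch (illustrative; the actual lifespan hook lives in the
# FastAPI app setup):
#   await reconcile_stale_processing_tasks()
#   register_all_handlers()
#   await job_queue.start()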
def register_all_handlers():
"""Register all job handlers."""
"""Register all job handlers and completion callbacks."""
job_queue.register_handler(JobType.TRIAGE, _handle_triage)
job_queue.register_handler(JobType.HOST_PROFILE, _handle_host_profile)
job_queue.register_handler(JobType.REPORT, _handle_report)
job_queue.register_handler(JobType.ANOMALY, _handle_anomaly)
job_queue.register_handler(JobType.QUERY, _handle_query)
job_queue.register_handler(JobType.HOST_INVENTORY, _handle_host_inventory)
job_queue.register_handler(JobType.KEYWORD_SCAN, _handle_keyword_scan)
job_queue.register_handler(JobType.IOC_EXTRACT, _handle_ioc_extract)
job_queue.on_completion(_on_pipeline_job_complete)

View File

@@ -1,4 +1,4 @@
"""AUP Keyword Scanner searches dataset rows, hunts, annotations, and
"""AUP Keyword Scanner searches dataset rows, hunts, annotations, and
messages for keyword matches.
Scanning is done in Python (not SQL LIKE on JSON columns) for portability
@@ -8,24 +8,49 @@ across SQLite / PostgreSQL and to provide per-cell match context.
import logging
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.db.models import (
KeywordTheme,
Keyword,
DatasetRow,
Dataset,
Hunt,
Annotation,
Message,
Conversation,
)
logger = logging.getLogger(__name__)
BATCH_SIZE = 200
def _infer_hostname_and_user(data: dict) -> tuple[str | None, str | None]:
"""Best-effort extraction of hostname and user from a dataset row."""
if not data:
return None, None
host_keys = (
'hostname', 'host_name', 'host', 'computer_name', 'computer',
'fqdn', 'client_id', 'agent_id', 'endpoint_id',
)
user_keys = (
'username', 'user_name', 'user', 'account_name',
'logged_in_user', 'samaccountname', 'sam_account_name',
)
def pick(keys):
for k in keys:
for actual_key, v in data.items():
if actual_key.lower() == k and v not in (None, ''):
return str(v)
return None
return pick(host_keys), pick(user_keys)
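# Example (illustrative): keys match case-insensitively, first non-empty wins:
#   _infer_hostname_and_user({"Hostname": "HOST-1", "User": "alice"})
#   -> ("HOST-1", "alice")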
@dataclass
@@ -39,6 +64,8 @@ class ScanHit:
matched_value: str
row_index: int | None = None
dataset_name: str | None = None
hostname: str | None = None
username: str | None = None
@dataclass
@@ -50,21 +77,54 @@ class ScanResult:
rows_scanned: int = 0
@dataclass
class KeywordScanCacheEntry:
dataset_id: str
result: dict
built_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
class KeywordScanCache:
"""In-memory per-dataset cache for dataset-only keyword scans.
This enables fast-path reads when users run AUP scans against datasets that
were already scanned during upload pipeline processing.
"""
def __init__(self):
self._entries: dict[str, KeywordScanCacheEntry] = {}
def put(self, dataset_id: str, result: dict):
self._entries[dataset_id] = KeywordScanCacheEntry(dataset_id=dataset_id, result=result)
def get(self, dataset_id: str) -> KeywordScanCacheEntry | None:
return self._entries.get(dataset_id)
def invalidate_dataset(self, dataset_id: str):
self._entries.pop(dataset_id, None)
def clear(self):
self._entries.clear()
keyword_scan_cache = KeywordScanCache()
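# Usage sketch (illustrative): the upload pipeline warms this cache so a later
# quick scan of the same dataset is served without rescanning rows:
#   keyword_scan_cache.put("ds-1", result_dict)      # after a pipeline scan
#   entry = keyword_scan_cache.get("ds-1")           # KeywordScanCacheEntry | None
#   keyword_scan_cache.invalidate_dataset("ds-1")    # e.g. after a re-upload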
class KeywordScanner:
"""Scans multiple data sources for keyword/regex matches."""
def __init__(self, db: AsyncSession):
self.db = db
# Public API
async def scan(
self,
dataset_ids: list[str] | None = None,
theme_ids: list[str] | None = None,
scan_hunts: bool = False,
scan_annotations: bool = False,
scan_messages: bool = False,
) -> dict:
"""Run a full AUP scan and return dict matching ScanResponse."""
# Load themes + keywords
@@ -103,7 +163,7 @@ class KeywordScanner:
"rows_scanned": result.rows_scanned,
}
# Internal
async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712
@@ -143,6 +203,8 @@ class KeywordScanner:
hits: list[ScanHit],
row_index: int | None = None,
dataset_name: str | None = None,
hostname: str | None = None,
username: str | None = None,
) -> None:
"""Check text against all compiled patterns, append hits."""
if not text:
@@ -150,8 +212,7 @@ class KeywordScanner:
for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
for kw_value, pat in keyword_patterns:
if pat.search(text):
matched_preview = text[:200] + ("…" if len(text) > 200 else "")
hits.append(ScanHit(
theme_name=theme_name,
theme_color=theme_color,
@@ -162,13 +223,14 @@ class KeywordScanner:
matched_value=matched_preview,
row_index=row_index,
dataset_name=dataset_name,
hostname=hostname,
username=username,
))
async def _scan_datasets(
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
) -> None:
"""Scan dataset rows in batches."""
# Build dataset name lookup
"""Scan dataset rows in batches using keyset pagination (no OFFSET)."""
ds_q = select(Dataset.id, Dataset.name)
if dataset_ids:
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
@@ -178,37 +240,66 @@ class KeywordScanner:
if not ds_map:
return
import asyncio
max_rows = max(0, int(settings.SCANNER_MAX_ROWS_PER_SCAN))
budget_reached = False
for ds_id, ds_name in ds_map.items():
if max_rows and result.rows_scanned >= max_rows:
budget_reached = True
break
last_id = 0
while True:
if max_rows and result.rows_scanned >= max_rows:
budget_reached = True
break
rows_result = await self.db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == ds_id)
.where(DatasetRow.id > last_id)
.order_by(DatasetRow.id)
.limit(BATCH_SIZE)
)
rows = rows_result.scalars().all()
if not rows:
break
for row in rows:
result.rows_scanned += 1
data = row.data or {}
hostname, username = _infer_hostname_and_user(data)
for col_name, cell_value in data.items():
if cell_value is None:
continue
text = str(cell_value)
self._match_text(
text,
patterns,
"dataset_row",
row.id,
col_name,
result.hits,
row_index=row.row_index,
dataset_name=ds_name,
hostname=hostname,
username=username,
)
last_id = rows[-1].id
await asyncio.sleep(0)
if len(rows) < BATCH_SIZE:
break
if budget_reached:
break
if budget_reached:
logger.warning(
"AUP scan row budget reached (%d rows). Returning partial results.",
result.rows_scanned,
)
async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
"""Scan hunt names and descriptions."""

View File

@@ -1,4 +1,4 @@
"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""
"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""
from __future__ import annotations
@@ -15,7 +15,7 @@ from app.db.models import Dataset, DatasetRow, TriageResult
logger = logging.getLogger(__name__)
DEFAULT_FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M"
DEFAULT_FAST_MODEL = settings.DEFAULT_FAST_MODEL
ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate"
ARTIFACT_FOCUS = {
@@ -80,7 +80,7 @@ async def triage_dataset(dataset_id: str) -> None:
rows_result = await db.execute(
select(DatasetRow)
.where(DatasetRow.dataset_id == dataset_id)
.order_by(DatasetRow.row_index)
.offset(offset)
.limit(batch_size)
)
@@ -167,4 +167,4 @@ Be precise. Only flag genuinely suspicious items. Respond with valid JSON only."
offset += batch_size
logger.info("Triage complete for dataset %s", dataset_id)
logger.info("Triage complete for dataset %s", dataset_id)

View File

@@ -0,0 +1,124 @@
"""Tests for execution-mode behavior in /api/agent/assist."""
import io
import pytest
@pytest.mark.asyncio
async def test_agent_assist_policy_query_executes_scan(client):
# 1) Create hunt
h = await client.post("/api/hunts", json={"name": "Policy Hunt"})
assert h.status_code == 200
hunt_id = h.json()["id"]
# 2) Upload browser-history-like CSV
csv_bytes = (
b"User,visited_url,title,ClientId,Fqdn\n"
b"Alice,https://www.pornhub.com/view_video.php,site,HOST-A,host-a.local\n"
b"Bob,https://news.example.org/article,news,HOST-B,host-b.local\n"
)
files = {"file": ("web_history.csv", io.BytesIO(csv_bytes), "text/csv")}
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
assert up.status_code == 200
# 3) Ensure policy theme/keyword exists
t = await client.post(
"/api/keywords/themes",
json={
"name": "Adult Content",
"color": "#e91e63",
"enabled": True,
},
)
assert t.status_code in (201, 409)
themes = await client.get("/api/keywords/themes")
assert themes.status_code == 200
adult = next(x for x in themes.json()["themes"] if x["name"] == "Adult Content")
k = await client.post(
f"/api/keywords/themes/{adult['id']}/keywords",
json={"value": "pornhub", "is_regex": False},
)
assert k.status_code in (201, 409)
# 4) Execution-mode query
q = await client.post(
"/api/agent/assist",
json={
"query": "Analyze browser history for policy-violating domains and summarize by user and host.",
"hunt_id": hunt_id,
},
)
assert q.status_code == 200
body = q.json()
assert body["model_used"] == "execution:keyword_scanner"
assert body["execution"] is not None
assert body["execution"]["policy_hits"] >= 1
assert len(body["execution"]["top_user_hosts"]) >= 1
@pytest.mark.asyncio
async def test_agent_assist_execution_preference_off_stays_advisory(client):
h = await client.post("/api/hunts", json={"name": "No Exec Hunt"})
assert h.status_code == 200
hunt_id = h.json()["id"]
q = await client.post(
"/api/agent/assist",
json={
"query": "Analyze browser history for policy-violating domains and summarize by user and host.",
"hunt_id": hunt_id,
"execution_preference": "off",
},
)
assert q.status_code == 200
body = q.json()
assert body["model_used"] != "execution:keyword_scanner"
assert body["execution"] is None
@pytest.mark.asyncio
async def test_agent_assist_execution_preference_force_executes(client):
# Create hunt + dataset even when the query text is not policy-specific
h = await client.post("/api/hunts", json={"name": "Force Exec Hunt"})
assert h.status_code == 200
hunt_id = h.json()["id"]
csv_bytes = (
b"User,visited_url,title,ClientId,Fqdn\n"
b"Alice,https://www.pornhub.com/view_video.php,site,HOST-A,host-a.local\n"
)
files = {"file": ("web_history.csv", io.BytesIO(csv_bytes), "text/csv")}
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
assert up.status_code == 200
t = await client.post(
"/api/keywords/themes",
json={"name": "Adult Content", "color": "#e91e63", "enabled": True},
)
assert t.status_code in (201, 409)
themes = await client.get("/api/keywords/themes")
assert themes.status_code == 200
adult = next(x for x in themes.json()["themes"] if x["name"] == "Adult Content")
k = await client.post(
f"/api/keywords/themes/{adult['id']}/keywords",
json={"value": "pornhub", "is_regex": False},
)
assert k.status_code in (201, 409)
q = await client.post(
"/api/agent/assist",
json={
"query": "Summarize notable activity in this hunt.",
"hunt_id": hunt_id,
"execution_preference": "force",
},
)
assert q.status_code == 200
body = q.json()
assert body["model_used"] == "execution:keyword_scanner"
assert body["execution"] is not None

View File

@@ -77,6 +77,26 @@ class TestHuntEndpoints:
assert resp.status_code == 404
async def test_hunt_progress(self, client):
create = await client.post("/api/hunts", json={"name": "Progress Hunt"})
hunt_id = create.json()["id"]
# attach one dataset so progress has scope
from tests.conftest import SAMPLE_CSV
import io
files = {"file": ("progress.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
assert up.status_code == 200
res = await client.get(f"/api/hunts/{hunt_id}/progress")
assert res.status_code == 200
body = res.json()
assert body["hunt_id"] == hunt_id
assert "progress_percent" in body
assert "dataset_total" in body
assert "network_status" in body
@pytest.mark.asyncio
class TestDatasetEndpoints:
"""Test dataset upload and retrieval."""

View File

@@ -1,4 +1,4 @@
"""Tests for CSV parser and normalizer services."""
"""Tests for CSV parser and normalizer services."""
import pytest
from app.services.csv_parser import parse_csv_bytes, detect_encoding, detect_delimiter, infer_column_types
@@ -43,8 +43,9 @@ class TestCSVParser:
assert len(rows) == 2
def test_parse_empty_file(self):
rows, meta = parse_csv_bytes(b"")
assert len(rows) == 0
assert meta["row_count"] == 0
def test_detect_encoding_utf8(self):
enc = detect_encoding(SAMPLE_CSV)
@@ -53,17 +54,15 @@ class TestCSVParser:
def test_infer_column_types(self):
types = infer_column_types(
["192.168.1.1", "10.0.0.1", "8.8.8.8"],
"src_ip",
[{"src_ip": "192.168.1.1"}, {"src_ip": "10.0.0.1"}, {"src_ip": "8.8.8.8"}],
)
assert types == "ip"
assert types["src_ip"] == "ip"
def test_infer_column_types_hash(self):
types = infer_column_types(
["d41d8cd98f00b204e9800998ecf8427e"],
"hash",
[{"hash": "d41d8cd98f00b204e9800998ecf8427e"}],
)
assert types == "hash_md5"
assert types["hash"] == "hash_md5"
class TestNormalizer:
@@ -94,7 +93,7 @@ class TestNormalizer:
start, end = detect_time_range(rows, column_mapping)
# Should detect time range from timestamp column
if start:
assert "2025" in start
assert "2025" in str(start)
def test_normalize_rows(self):
rows = [{"SourceAddr": "10.0.0.1", "ProcessName": "cmd.exe"}]
@@ -102,3 +101,6 @@ class TestNormalizer:
normalized = normalize_rows(rows, mapping)
assert len(normalized) == 1
assert normalized[0].get("src_ip") == "10.0.0.1"

View File

@@ -197,3 +197,27 @@ async def test_quick_scan(client: AsyncClient):
assert "total_hits" in data
# powershell should match at least one row
assert data["total_hits"] > 0
@pytest.mark.asyncio
async def test_quick_scan_cache_hit(client: AsyncClient):
"""Second quick scan should return cache hit metadata."""
theme_res = await client.post("/api/keywords/themes", json={"name": "Quick Cache Theme", "color": "#00aa00"})
tid = theme_res.json()["id"]
await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "chrome.exe"})
from tests.conftest import SAMPLE_CSV
import io
files = {"file": ("cache_quick.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
upload = await client.post("/api/datasets/upload", files=files)
ds_id = upload.json()["id"]
first = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
assert first.status_code == 200
assert first.json().get("cache_status") in ("miss", "hit")
second = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
assert second.status_code == 200
body = second.json()
assert body.get("cache_used") is True
assert body.get("cache_status") == "hit"

View File

@@ -0,0 +1,84 @@
"""Tests for network inventory endpoints and cache/polling behavior."""
import io
import pytest
from app.services.host_inventory import inventory_cache
from tests.conftest import SAMPLE_CSV
@pytest.mark.asyncio
async def test_inventory_status_none_for_unknown_hunt(client):
hunt_id = "hunt-does-not-exist"
inventory_cache.invalidate(hunt_id)
inventory_cache.clear_building(hunt_id)
res = await client.get(f"/api/network/inventory-status?hunt_id={hunt_id}")
assert res.status_code == 200
body = res.json()
assert body["hunt_id"] == hunt_id
assert body["status"] == "none"
@pytest.mark.asyncio
async def test_host_inventory_cold_cache_returns_202(client):
# Create hunt and upload dataset linked to that hunt
hunt = await client.post("/api/hunts", json={"name": "Net Hunt"})
hunt_id = hunt.json()["id"]
files = {"file": ("network.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
up = await client.post("/api/datasets/upload", files=files, params={"hunt_id": hunt_id})
assert up.status_code == 200
# Ensure cache is cold for this hunt
inventory_cache.invalidate(hunt_id)
inventory_cache.clear_building(hunt_id)
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
assert res.status_code == 202
body = res.json()
assert body["status"] == "building"
@pytest.mark.asyncio
async def test_host_inventory_ready_cache_returns_200(client):
hunt = await client.post("/api/hunts", json={"name": "Ready Hunt"})
hunt_id = hunt.json()["id"]
mock_inventory = {
"hosts": [
{
"id": "host-1",
"hostname": "HOST-1",
"fqdn": "HOST-1.local",
"client_id": "C.1234abcd",
"ips": ["10.0.0.10"],
"os": "Windows 10",
"users": ["alice"],
"datasets": ["test"],
"row_count": 5,
}
],
"connections": [],
"stats": {
"total_hosts": 1,
"hosts_with_ips": 1,
"hosts_with_users": 1,
"total_datasets_scanned": 1,
"total_rows_scanned": 5,
},
}
inventory_cache.put(hunt_id, mock_inventory)
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
assert res.status_code == 200
body = res.json()
assert body["stats"]["total_hosts"] == 1
assert len(body["hosts"]) == 1
assert body["hosts"][0]["hostname"] == "HOST-1"
status_res = await client.get(f"/api/network/inventory-status?hunt_id={hunt_id}")
assert status_res.status_code == 200
assert status_res.json()["status"] == "ready"

View File

@@ -0,0 +1,82 @@
"""Scale-oriented network endpoint tests (summary/subgraph/backpressure)."""
import pytest
from app.config import settings
from app.services.host_inventory import inventory_cache
@pytest.mark.asyncio
async def test_network_summary_from_cache(client):
hunt_id = "scale-hunt-summary"
inv = {
"hosts": [
{"id": "h1", "hostname": "H1", "ips": ["10.0.0.1"], "users": ["a"], "row_count": 50},
{"id": "h2", "hostname": "H2", "ips": [], "users": [], "row_count": 10},
],
"connections": [
{"source": "h1", "target": "8.8.8.8", "count": 7},
{"source": "h1", "target": "h2", "count": 3},
],
"stats": {"total_hosts": 2, "total_rows_scanned": 60},
}
inventory_cache.put(hunt_id, inv)
res = await client.get(f"/api/network/summary?hunt_id={hunt_id}&top_n=1")
assert res.status_code == 200
body = res.json()
assert body["stats"]["total_hosts"] == 2
assert len(body["top_hosts"]) == 1
assert body["top_hosts"][0]["id"] == "h1"
@pytest.mark.asyncio
async def test_network_subgraph_truncates(client):
hunt_id = "scale-hunt-subgraph"
inv = {
"hosts": [
{"id": f"h{i}", "hostname": f"H{i}", "ips": [], "users": [], "row_count": 100 - i}
for i in range(1, 8)
],
"connections": [
{"source": "h1", "target": "h2", "count": 20},
{"source": "h1", "target": "h3", "count": 15},
{"source": "h2", "target": "h4", "count": 5},
{"source": "h3", "target": "h5", "count": 4},
],
"stats": {"total_hosts": 7, "total_rows_scanned": 999},
}
inventory_cache.put(hunt_id, inv)
res = await client.get(f"/api/network/subgraph?hunt_id={hunt_id}&max_hosts=3&max_edges=2")
assert res.status_code == 200
body = res.json()
assert len(body["hosts"]) <= 3
assert len(body["connections"]) <= 2
assert body["stats"]["truncated"] is True
@pytest.mark.asyncio
async def test_manual_job_submit_backpressure_returns_429(client):
old = settings.JOB_QUEUE_MAX_BACKLOG
settings.JOB_QUEUE_MAX_BACKLOG = 0
try:
res = await client.post("/api/analysis/jobs/submit/triage", json={"params": {"dataset_id": "abc"}})
assert res.status_code == 429
finally:
settings.JOB_QUEUE_MAX_BACKLOG = old
@pytest.mark.asyncio
async def test_network_host_inventory_deferred_when_queue_backlogged(client):
hunt_id = "deferred-hunt"
inventory_cache.invalidate(hunt_id)
inventory_cache.clear_building(hunt_id)
old = settings.JOB_QUEUE_MAX_BACKLOG
settings.JOB_QUEUE_MAX_BACKLOG = 0
try:
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
assert res.status_code == 202
body = res.json()
assert body["status"] == "deferred"
finally:
settings.JOB_QUEUE_MAX_BACKLOG = old

View File

@@ -0,0 +1,203 @@
"""Tests for new feature API routes: MITRE, Timeline, Playbooks, Saved Searches."""
import pytest
import pytest_asyncio
class TestMitreRoutes:
"""Tests for /api/mitre endpoints."""
@pytest.mark.asyncio
async def test_mitre_coverage_empty(self, client):
resp = await client.get("/api/mitre/coverage")
assert resp.status_code == 200
data = resp.json()
assert "tactics" in data
assert "technique_count" in data
assert data["technique_count"] == 0
assert len(data["tactics"]) == 14 # 14 MITRE tactics
@pytest.mark.asyncio
async def test_mitre_coverage_with_hunt_filter(self, client):
resp = await client.get("/api/mitre/coverage?hunt_id=nonexistent")
assert resp.status_code == 200
assert resp.json()["technique_count"] == 0
class TestTimelineRoutes:
"""Tests for /api/timeline endpoints."""
@pytest.mark.asyncio
async def test_timeline_hunt_not_found(self, client):
resp = await client.get("/api/timeline/hunt/nonexistent")
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_timeline_with_hunt(self, client):
# Create a hunt first
hunt_resp = await client.post("/api/hunts", json={"name": "Timeline Test"})
assert hunt_resp.status_code in (200, 201)
hunt_id = hunt_resp.json()["id"]
resp = await client.get(f"/api/timeline/hunt/{hunt_id}")
assert resp.status_code == 200
data = resp.json()
assert data["hunt_id"] == hunt_id
assert "events" in data
assert "datasets" in data
class TestPlaybookRoutes:
"""Tests for /api/playbooks endpoints."""
@pytest.mark.asyncio
async def test_list_playbooks_empty(self, client):
resp = await client.get("/api/playbooks")
assert resp.status_code == 200
assert resp.json()["playbooks"] == []
@pytest.mark.asyncio
async def test_get_templates(self, client):
resp = await client.get("/api/playbooks/templates")
assert resp.status_code == 200
templates = resp.json()["templates"]
assert len(templates) >= 2
assert templates[0]["name"] == "Standard Threat Hunt"
@pytest.mark.asyncio
async def test_create_playbook(self, client):
resp = await client.post("/api/playbooks", json={
"name": "My Investigation",
"description": "Test playbook",
"steps": [
{"title": "Step 1", "description": "Upload data", "step_type": "upload", "target_route": "/upload"},
{"title": "Step 2", "description": "Triage", "step_type": "analysis", "target_route": "/analysis"},
],
})
assert resp.status_code == 201
data = resp.json()
assert data["name"] == "My Investigation"
assert len(data["steps"]) == 2
@pytest.mark.asyncio
async def test_playbook_crud(self, client):
# Create
resp = await client.post("/api/playbooks", json={
"name": "CRUD Test",
"steps": [{"title": "Do something"}],
})
assert resp.status_code == 201
pb_id = resp.json()["id"]
# Get
resp = await client.get(f"/api/playbooks/{pb_id}")
assert resp.status_code == 200
assert resp.json()["name"] == "CRUD Test"
assert len(resp.json()["steps"]) == 1
# Update
resp = await client.put(f"/api/playbooks/{pb_id}", json={"name": "Updated"})
assert resp.status_code == 200
# Delete
resp = await client.delete(f"/api/playbooks/{pb_id}")
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_playbook_step_completion(self, client):
# Create with step
resp = await client.post("/api/playbooks", json={
"name": "Step Test",
"steps": [{"title": "Task 1"}],
})
pb_id = resp.json()["id"]
# Get to find step ID
resp = await client.get(f"/api/playbooks/{pb_id}")
steps = resp.json()["steps"]
step_id = steps[0]["id"]
assert steps[0]["is_completed"] is False
# Mark complete
resp = await client.put(f"/api/playbooks/steps/{step_id}", json={"is_completed": True, "notes": "Done!"})
assert resp.status_code == 200
assert resp.json()["is_completed"] is True
class TestSavedSearchRoutes:
"""Tests for /api/searches endpoints."""
@pytest.mark.asyncio
async def test_list_empty(self, client):
resp = await client.get("/api/searches")
assert resp.status_code == 200
assert resp.json()["searches"] == []
@pytest.mark.asyncio
async def test_create_saved_search(self, client):
resp = await client.post("/api/searches", json={
"name": "Suspicious IPs",
"search_type": "ioc_search",
"query_params": {"ioc_value": "203.0.113"},
})
assert resp.status_code == 201
data = resp.json()
assert data["name"] == "Suspicious IPs"
assert data["search_type"] == "ioc_search"
@pytest.mark.asyncio
async def test_search_crud(self, client):
# Create
resp = await client.post("/api/searches", json={
"name": "Test Query",
"search_type": "keyword_scan",
"query_params": {"theme": "malware"},
})
s_id = resp.json()["id"]
# Get
resp = await client.get(f"/api/searches/{s_id}")
assert resp.status_code == 200
# Update
resp = await client.put(f"/api/searches/{s_id}", json={"name": "Updated Query"})
assert resp.status_code == 200
# Run
resp = await client.post(f"/api/searches/{s_id}/run")
assert resp.status_code == 200
data = resp.json()
assert "result_count" in data
assert "delta" in data
# Delete
resp = await client.delete(f"/api/searches/{s_id}")
assert resp.status_code == 200
class TestStixExport:
"""Tests for /api/export/stix endpoints."""
@pytest.mark.asyncio
async def test_stix_export_hunt_not_found(self, client):
resp = await client.get("/api/export/stix/nonexistent-id")
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_stix_export_empty_hunt(self, client):
"""Export from a real hunt with no data returns valid but minimal bundle."""
hunt_resp = await client.post("/api/hunts", json={"name": "STIX Test Hunt"})
assert hunt_resp.status_code in (200, 201)
hunt_id = hunt_resp.json()["id"]
resp = await client.get(f"/api/export/stix/{hunt_id}")
assert resp.status_code == 200
data = resp.json()
assert data["type"] == "bundle"
assert data["objects"][0]["spec_version"] == "2.1" # spec_version is on objects, not bundle
assert "objects" in data
# At minimum should have the identity object
types = [o["type"] for o in data["objects"]]
assert "identity" in types

Binary file not shown.