mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 14:00:20 -05:00
feat: Add Playbook Manager, Saved Searches, and Timeline View components
- Implemented PlaybookManager for creating and managing investigation playbooks with templates. - Added SavedSearches component for managing bookmarked queries and recurring scans. - Introduced TimelineView for visualizing forensic event timelines with zoomable charts. - Enhanced backend processing with auto-queued jobs for dataset uploads and improved database concurrency. - Updated frontend components for better user experience and performance optimizations. - Documented changes in update log for future reference.
This commit is contained in:
@@ -0,0 +1,78 @@
|
||||
"""add playbooks, playbook_steps, saved_searches tables
|
||||
|
||||
Revision ID: b2c3d4e5f6a7
|
||||
Revises: a1b2c3d4e5f6
|
||||
Create Date: 2026-02-21 10:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision: str = "b2c3d4e5f6a7"
|
||||
down_revision: Union[str, Sequence[str], None] = "a1b2c3d4e5f6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add display_name to users table
|
||||
with op.batch_alter_table("users") as batch_op:
|
||||
batch_op.add_column(sa.Column("display_name", sa.String(128), nullable=True))
|
||||
|
||||
# Create playbooks table
|
||||
op.create_table(
|
||||
"playbooks",
|
||||
sa.Column("id", sa.String(32), primary_key=True),
|
||||
sa.Column("name", sa.String(256), nullable=False, index=True),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
|
||||
sa.Column("is_template", sa.Boolean(), server_default="0"),
|
||||
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
|
||||
sa.Column("status", sa.String(20), server_default="active"),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
|
||||
# Create playbook_steps table
|
||||
op.create_table(
|
||||
"playbook_steps",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("playbook_id", sa.String(32), sa.ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("order_index", sa.Integer(), nullable=False),
|
||||
sa.Column("title", sa.String(256), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("step_type", sa.String(32), server_default="manual"),
|
||||
sa.Column("target_route", sa.String(256), nullable=True),
|
||||
sa.Column("is_completed", sa.Boolean(), server_default="0"),
|
||||
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("notes", sa.Text(), nullable=True),
|
||||
)
|
||||
op.create_index("ix_playbook_steps_playbook", "playbook_steps", ["playbook_id"])
|
||||
|
||||
# Create saved_searches table
|
||||
op.create_table(
|
||||
"saved_searches",
|
||||
sa.Column("id", sa.String(32), primary_key=True),
|
||||
sa.Column("name", sa.String(256), nullable=False, index=True),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("search_type", sa.String(32), nullable=False),
|
||||
sa.Column("query_params", sa.JSON(), nullable=False),
|
||||
sa.Column("threshold", sa.Float(), nullable=True),
|
||||
sa.Column("created_by", sa.String(32), sa.ForeignKey("users.id"), nullable=True),
|
||||
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id"), nullable=True),
|
||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("last_result_count", sa.Integer(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
op.create_index("ix_saved_searches_type", "saved_searches", ["search_type"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("saved_searches")
|
||||
op.drop_table("playbook_steps")
|
||||
op.drop_table("playbooks")
|
||||
with op.batch_alter_table("users") as batch_op:
|
||||
batch_op.drop_column("display_name")
|
||||
@@ -0,0 +1,48 @@
|
||||
"""add processing_tasks table
|
||||
|
||||
Revision ID: c3d4e5f6a7b8
|
||||
Revises: b2c3d4e5f6a7
|
||||
Create Date: 2026-02-22 00:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision: str = "c3d4e5f6a7b8"
|
||||
down_revision: Union[str, Sequence[str], None] = "b2c3d4e5f6a7"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"processing_tasks",
|
||||
sa.Column("id", sa.String(32), primary_key=True),
|
||||
sa.Column("hunt_id", sa.String(32), sa.ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True),
|
||||
sa.Column("dataset_id", sa.String(32), sa.ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True),
|
||||
sa.Column("job_id", sa.String(64), nullable=True),
|
||||
sa.Column("stage", sa.String(64), nullable=False),
|
||||
sa.Column("status", sa.String(20), nullable=False, server_default="queued"),
|
||||
sa.Column("progress", sa.Float(), nullable=False, server_default="0.0"),
|
||||
sa.Column("message", sa.Text(), nullable=True),
|
||||
sa.Column("error", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.func.now()),
|
||||
)
|
||||
op.create_index("ix_processing_tasks_hunt_stage", "processing_tasks", ["hunt_id", "stage"])
|
||||
op.create_index("ix_processing_tasks_dataset_stage", "processing_tasks", ["dataset_id", "stage"])
|
||||
op.create_index("ix_processing_tasks_job_id", "processing_tasks", ["job_id"])
|
||||
op.create_index("ix_processing_tasks_status", "processing_tasks", ["status"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_processing_tasks_status", table_name="processing_tasks")
|
||||
op.drop_index("ix_processing_tasks_job_id", table_name="processing_tasks")
|
||||
op.drop_index("ix_processing_tasks_dataset_stage", table_name="processing_tasks")
|
||||
op.drop_index("ix_processing_tasks_hunt_stage", table_name="processing_tasks")
|
||||
op.drop_table("processing_tasks")
|
||||
@@ -1,16 +1,18 @@
|
||||
"""Analyst-assist agent module for ThreatHunt.
|
||||
"""Analyst-assist agent module for ThreatHunt.
|
||||
|
||||
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
|
||||
Agents are advisory only and do not execute actions or modify data.
|
||||
"""
|
||||
|
||||
from .core import ThreatHuntAgent
|
||||
from .providers import LLMProvider, LocalProvider, NetworkedProvider, OnlineProvider
|
||||
from .core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
|
||||
from .providers_v2 import OllamaProvider, OpenWebUIProvider, EmbeddingProvider
|
||||
|
||||
__all__ = [
|
||||
"ThreatHuntAgent",
|
||||
"LLMProvider",
|
||||
"LocalProvider",
|
||||
"NetworkedProvider",
|
||||
"OnlineProvider",
|
||||
"AgentContext",
|
||||
"AgentResponse",
|
||||
"Perspective",
|
||||
"OllamaProvider",
|
||||
"OpenWebUIProvider",
|
||||
"EmbeddingProvider",
|
||||
]
|
||||
|
||||
@@ -1,208 +0,0 @@
|
||||
"""Core ThreatHunt analyst-assist agent.
|
||||
|
||||
Provides read-only guidance on CSV artifact data, analytical pivots, and hypotheses.
|
||||
Agents are advisory only - no execution, no alerts, no data modifications.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .providers import LLMProvider, get_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentContext(BaseModel):
|
||||
"""Context for agent guidance requests."""
|
||||
|
||||
query: str = Field(
|
||||
..., description="Analyst question or request for guidance"
|
||||
)
|
||||
dataset_name: Optional[str] = Field(None, description="Name of CSV dataset")
|
||||
artifact_type: Optional[str] = Field(None, description="Artifact type (e.g., file, process, network)")
|
||||
host_identifier: Optional[str] = Field(
|
||||
None, description="Host name, IP, or identifier"
|
||||
)
|
||||
data_summary: Optional[str] = Field(
|
||||
None, description="Brief description of uploaded data"
|
||||
)
|
||||
conversation_history: Optional[list[dict]] = Field(
|
||||
default_factory=list, description="Previous messages in conversation"
|
||||
)
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
|
||||
"""Response from analyst-assist agent."""
|
||||
|
||||
guidance: str = Field(..., description="Advisory guidance for analyst")
|
||||
confidence: float = Field(
|
||||
..., ge=0.0, le=1.0, description="Confidence in guidance (0-1)"
|
||||
)
|
||||
suggested_pivots: list[str] = Field(
|
||||
default_factory=list, description="Suggested analytical directions"
|
||||
)
|
||||
suggested_filters: list[str] = Field(
|
||||
default_factory=list, description="Suggested data filters or queries"
|
||||
)
|
||||
caveats: Optional[str] = Field(
|
||||
None, description="Assumptions, limitations, or caveats"
|
||||
)
|
||||
reasoning: Optional[str] = Field(
|
||||
None, description="Explanation of how guidance was generated"
|
||||
)
|
||||
|
||||
|
||||
class ThreatHuntAgent:
|
||||
"""Analyst-assist agent for ThreatHunt.
|
||||
|
||||
Provides guidance on:
|
||||
- Interpreting CSV artifact data
|
||||
- Suggesting analytical pivots and filters
|
||||
- Forming and testing hypotheses
|
||||
|
||||
Policy:
|
||||
- Advisory guidance only (no execution)
|
||||
- No database or schema changes
|
||||
- No alert escalation
|
||||
- Transparent reasoning
|
||||
"""
|
||||
|
||||
def __init__(self, provider: Optional[LLMProvider] = None):
|
||||
"""Initialize agent with LLM provider.
|
||||
|
||||
Args:
|
||||
provider: LLM provider instance. If None, uses get_provider() with auto mode.
|
||||
"""
|
||||
if provider is None:
|
||||
try:
|
||||
provider = get_provider("auto")
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Could not initialize default provider: {e}")
|
||||
provider = None
|
||||
|
||||
self.provider = provider
|
||||
self.system_prompt = self._build_system_prompt()
|
||||
|
||||
def _build_system_prompt(self) -> str:
|
||||
"""Build the system prompt that governs agent behavior."""
|
||||
return """You are an analyst-assist agent for ThreatHunt, a threat hunting platform.
|
||||
|
||||
Your role:
|
||||
- Interpret and explain CSV artifact data from Velociraptor
|
||||
- Suggest analytical pivots, filters, and hypotheses
|
||||
- Highlight anomalies, patterns, or points of interest
|
||||
- Guide analysts without replacing their judgment
|
||||
|
||||
Your constraints:
|
||||
- You ONLY provide guidance and suggestions
|
||||
- You do NOT execute actions or tools
|
||||
- You do NOT modify data or escalate alerts
|
||||
- You do NOT make autonomous decisions
|
||||
- You ONLY analyze data presented to you
|
||||
- You explain your reasoning transparently
|
||||
- You acknowledge limitations and assumptions
|
||||
- You suggest next investigative steps
|
||||
|
||||
When responding:
|
||||
1. Start with a clear, direct answer to the query
|
||||
2. Explain your reasoning based on the data context provided
|
||||
3. Suggest 2-4 analytical pivots the analyst might explore
|
||||
4. Suggest 2-4 data filters or queries that might be useful
|
||||
5. Include relevant caveats or assumptions
|
||||
6. Be honest about what you cannot determine from the data
|
||||
|
||||
Remember: The analyst is the decision-maker. You are an assistant."""
|
||||
|
||||
async def assist(self, context: AgentContext) -> AgentResponse:
|
||||
"""Provide guidance on artifact data and analysis.
|
||||
|
||||
Args:
|
||||
context: Request context including query and data context.
|
||||
|
||||
Returns:
|
||||
Guidance response with suggestions and reasoning.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no provider is available.
|
||||
"""
|
||||
if not self.provider:
|
||||
raise RuntimeError(
|
||||
"No LLM provider available. Configure at least one of: "
|
||||
"THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
|
||||
"or THREAT_HUNT_ONLINE_API_KEY"
|
||||
)
|
||||
|
||||
# Build prompt with context
|
||||
prompt = self._build_prompt(context)
|
||||
|
||||
try:
|
||||
# Get guidance from LLM provider
|
||||
guidance = await self.provider.generate(prompt, max_tokens=1024)
|
||||
|
||||
# Parse response into structured format
|
||||
response = self._parse_response(guidance, context)
|
||||
|
||||
logger.info(
|
||||
f"Agent assisted with query: {context.query[:50]}... "
|
||||
f"(dataset: {context.dataset_name})"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating guidance: {e}")
|
||||
raise
|
||||
|
||||
def _build_prompt(self, context: AgentContext) -> str:
|
||||
"""Build the prompt for the LLM."""
|
||||
prompt_parts = [
|
||||
f"Analyst query: {context.query}",
|
||||
]
|
||||
|
||||
if context.dataset_name:
|
||||
prompt_parts.append(f"Dataset: {context.dataset_name}")
|
||||
|
||||
if context.artifact_type:
|
||||
prompt_parts.append(f"Artifact type: {context.artifact_type}")
|
||||
|
||||
if context.host_identifier:
|
||||
prompt_parts.append(f"Host: {context.host_identifier}")
|
||||
|
||||
if context.data_summary:
|
||||
prompt_parts.append(f"Data summary: {context.data_summary}")
|
||||
|
||||
if context.conversation_history:
|
||||
prompt_parts.append("\nConversation history:")
|
||||
for msg in context.conversation_history[-5:]: # Last 5 messages for context
|
||||
prompt_parts.append(f" {msg.get('role', 'unknown')}: {msg.get('content', '')}")
|
||||
|
||||
return "\n".join(prompt_parts)
|
||||
|
||||
def _parse_response(self, response_text: str, context: AgentContext) -> AgentResponse:
|
||||
"""Parse LLM response into structured format.
|
||||
|
||||
Note: This is a simplified parser. In production, use structured output
|
||||
from the LLM (JSON mode, function calling, etc.) for better reliability.
|
||||
"""
|
||||
# For now, return a structured response based on the raw guidance
|
||||
# In production, parse JSON or use structured output from LLM
|
||||
return AgentResponse(
|
||||
guidance=response_text,
|
||||
confidence=0.8, # Placeholder
|
||||
suggested_pivots=[
|
||||
"Analyze temporal patterns",
|
||||
"Cross-reference with known indicators",
|
||||
"Examine outliers in the dataset",
|
||||
"Compare with baseline behavior",
|
||||
],
|
||||
suggested_filters=[
|
||||
"Filter by high-risk indicators",
|
||||
"Sort by timestamp for timeline analysis",
|
||||
"Group by host or user",
|
||||
"Filter by anomaly score",
|
||||
],
|
||||
caveats="Guidance is based on available data context. "
|
||||
"Analysts should verify findings with additional sources.",
|
||||
reasoning="Analysis generated based on artifact data patterns and analyst query.",
|
||||
)
|
||||
@@ -1,190 +0,0 @@
|
||||
"""Pluggable LLM provider interface for analyst-assist agents.
|
||||
|
||||
Supports three provider types:
|
||||
- Local: On-device or on-prem models
|
||||
- Networked: Shared internal inference services
|
||||
- Online: External hosted APIs
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""Abstract base class for LLM providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
|
||||
"""Generate a response from the LLM.
|
||||
|
||||
Args:
|
||||
prompt: The input prompt
|
||||
max_tokens: Maximum tokens in response
|
||||
|
||||
Returns:
|
||||
Generated text response
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_available(self) -> bool:
|
||||
"""Check if provider backend is available."""
|
||||
pass
|
||||
|
||||
|
||||
class LocalProvider(LLMProvider):
|
||||
"""Local LLM provider (on-device or on-prem models)."""
|
||||
|
||||
def __init__(self, model_path: Optional[str] = None):
|
||||
"""Initialize local provider.
|
||||
|
||||
Args:
|
||||
model_path: Path to local model. If None, uses THREAT_HUNT_LOCAL_MODEL_PATH env var.
|
||||
"""
|
||||
self.model_path = model_path or os.getenv("THREAT_HUNT_LOCAL_MODEL_PATH")
|
||||
self.model = None
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if local model is available."""
|
||||
if not self.model_path:
|
||||
return False
|
||||
# In production, would verify model file exists and can be loaded
|
||||
return os.path.exists(str(self.model_path))
|
||||
|
||||
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
|
||||
"""Generate response using local model.
|
||||
|
||||
Note: This is a placeholder. In production, integrate with:
|
||||
- llama-cpp-python for GGML models
|
||||
- Ollama API
|
||||
- vLLM
|
||||
- Other local inference engines
|
||||
"""
|
||||
if not self.is_available():
|
||||
raise RuntimeError("Local model not available")
|
||||
|
||||
# Placeholder implementation
|
||||
return f"[Local model response to: {prompt[:50]}...]"
|
||||
|
||||
|
||||
class NetworkedProvider(LLMProvider):
|
||||
"""Networked LLM provider (shared internal inference services)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_endpoint: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
model_name: str = "default",
|
||||
):
|
||||
"""Initialize networked provider.
|
||||
|
||||
Args:
|
||||
api_endpoint: URL to inference service. Defaults to env var THREAT_HUNT_NETWORKED_ENDPOINT.
|
||||
api_key: API key for service. Defaults to env var THREAT_HUNT_NETWORKED_KEY.
|
||||
model_name: Model name/ID on the service.
|
||||
"""
|
||||
self.api_endpoint = api_endpoint or os.getenv("THREAT_HUNT_NETWORKED_ENDPOINT")
|
||||
self.api_key = api_key or os.getenv("THREAT_HUNT_NETWORKED_KEY")
|
||||
self.model_name = model_name
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if networked service is available."""
|
||||
return bool(self.api_endpoint)
|
||||
|
||||
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
|
||||
"""Generate response using networked service.
|
||||
|
||||
Note: This is a placeholder. In production, integrate with:
|
||||
- Internal inference service API
|
||||
- LLM inference container cluster
|
||||
- Enterprise inference gateway
|
||||
"""
|
||||
if not self.is_available():
|
||||
raise RuntimeError("Networked service not available")
|
||||
|
||||
# Placeholder implementation
|
||||
return f"[Networked response from {self.model_name}: {prompt[:50]}...]"
|
||||
|
||||
|
||||
class OnlineProvider(LLMProvider):
|
||||
"""Online LLM provider (external hosted APIs)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_provider: str = "openai",
|
||||
api_key: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
):
|
||||
"""Initialize online provider.
|
||||
|
||||
Args:
|
||||
api_provider: Provider name (openai, anthropic, google, etc.)
|
||||
api_key: API key. Defaults to env var THREAT_HUNT_ONLINE_API_KEY.
|
||||
model_name: Model name. Defaults to env var THREAT_HUNT_ONLINE_MODEL.
|
||||
"""
|
||||
self.api_provider = api_provider
|
||||
self.api_key = api_key or os.getenv("THREAT_HUNT_ONLINE_API_KEY")
|
||||
self.model_name = model_name or os.getenv(
|
||||
"THREAT_HUNT_ONLINE_MODEL", f"{api_provider}-default"
|
||||
)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if online API is available."""
|
||||
return bool(self.api_key)
|
||||
|
||||
async def generate(self, prompt: str, max_tokens: int = 1024) -> str:
|
||||
"""Generate response using online API.
|
||||
|
||||
Note: This is a placeholder. In production, integrate with:
|
||||
- OpenAI API (GPT-3.5, GPT-4, etc.)
|
||||
- Anthropic Claude API
|
||||
- Google Gemini API
|
||||
- Other hosted LLM services
|
||||
"""
|
||||
if not self.is_available():
|
||||
raise RuntimeError("Online API not available or API key not set")
|
||||
|
||||
# Placeholder implementation
|
||||
return f"[Online {self.api_provider} response: {prompt[:50]}...]"
|
||||
|
||||
|
||||
def get_provider(provider_type: str = "auto") -> LLMProvider:
|
||||
"""Get an LLM provider based on configuration.
|
||||
|
||||
Args:
|
||||
provider_type: Type of provider to use: 'local', 'networked', 'online', or 'auto'.
|
||||
'auto' attempts to use the first available provider in order:
|
||||
local -> networked -> online.
|
||||
|
||||
Returns:
|
||||
Configured LLM provider instance.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no provider is available.
|
||||
"""
|
||||
# Explicit provider selection
|
||||
if provider_type == "local":
|
||||
provider = LocalProvider()
|
||||
elif provider_type == "networked":
|
||||
provider = NetworkedProvider()
|
||||
elif provider_type == "online":
|
||||
provider = OnlineProvider()
|
||||
elif provider_type == "auto":
|
||||
# Try providers in order of preference
|
||||
for Provider in [LocalProvider, NetworkedProvider, OnlineProvider]:
|
||||
provider = Provider()
|
||||
if provider.is_available():
|
||||
return provider
|
||||
raise RuntimeError(
|
||||
"No LLM provider available. Configure at least one of: "
|
||||
"THREAT_HUNT_LOCAL_MODEL_PATH, THREAT_HUNT_NETWORKED_ENDPOINT, "
|
||||
"or THREAT_HUNT_ONLINE_API_KEY"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
|
||||
if not provider.is_available():
|
||||
raise RuntimeError(f"{provider_type} provider not available")
|
||||
|
||||
return provider
|
||||
@@ -1,170 +0,0 @@
|
||||
"""API routes for analyst-assist agent."""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agents.core import ThreatHuntAgent, AgentContext, AgentResponse
|
||||
from app.agents.config import AgentConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/agent", tags=["agent"])
|
||||
|
||||
# Global agent instance (lazy-loaded)
|
||||
_agent: ThreatHuntAgent | None = None
|
||||
|
||||
|
||||
def get_agent() -> ThreatHuntAgent:
|
||||
"""Get or create the agent instance."""
|
||||
global _agent
|
||||
if _agent is None:
|
||||
if not AgentConfig.is_agent_enabled():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Analyst-assist agent is not configured. "
|
||||
"Please configure an LLM provider.",
|
||||
)
|
||||
_agent = ThreatHuntAgent()
|
||||
return _agent
|
||||
|
||||
|
||||
class AssistRequest(BaseModel):
|
||||
"""Request for agent assistance."""
|
||||
|
||||
query: str = Field(
|
||||
..., description="Analyst question or request for guidance"
|
||||
)
|
||||
dataset_name: str | None = Field(
|
||||
None, description="Name of CSV dataset being analyzed"
|
||||
)
|
||||
artifact_type: str | None = Field(
|
||||
None, description="Type of artifact (e.g., FileList, ProcessList, NetworkConnections)"
|
||||
)
|
||||
host_identifier: str | None = Field(
|
||||
None, description="Host name, IP address, or identifier"
|
||||
)
|
||||
data_summary: str | None = Field(
|
||||
None, description="Brief summary or context about the uploaded data"
|
||||
)
|
||||
conversation_history: list[dict] | None = Field(
|
||||
None, description="Previous messages for context"
|
||||
)
|
||||
|
||||
|
||||
class AssistResponse(BaseModel):
|
||||
"""Response with agent guidance."""
|
||||
|
||||
guidance: str
|
||||
confidence: float
|
||||
suggested_pivots: list[str]
|
||||
suggested_filters: list[str]
|
||||
caveats: str | None = None
|
||||
reasoning: str | None = None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/assist",
|
||||
response_model=AssistResponse,
|
||||
summary="Get analyst-assist guidance",
|
||||
description="Request guidance on CSV artifact data, analytical pivots, and hypotheses. "
|
||||
"Agent provides advisory guidance only - no execution.",
|
||||
)
|
||||
async def agent_assist(request: AssistRequest) -> AssistResponse:
|
||||
"""Provide analyst-assist guidance on artifact data.
|
||||
|
||||
The agent will:
|
||||
- Explain and interpret the provided data context
|
||||
- Suggest analytical pivots the analyst might explore
|
||||
- Suggest data filters or queries that might be useful
|
||||
- Highlight assumptions, limitations, and caveats
|
||||
|
||||
The agent will NOT:
|
||||
- Execute any tools or actions
|
||||
- Escalate findings to alerts
|
||||
- Modify any data or schema
|
||||
- Make autonomous decisions
|
||||
|
||||
Args:
|
||||
request: Assistance request with query and context
|
||||
|
||||
Returns:
|
||||
Guidance response with suggestions and reasoning
|
||||
|
||||
Raises:
|
||||
HTTPException: If agent is not configured (503) or request fails
|
||||
"""
|
||||
try:
|
||||
agent = get_agent()
|
||||
|
||||
# Build context
|
||||
context = AgentContext(
|
||||
query=request.query,
|
||||
dataset_name=request.dataset_name,
|
||||
artifact_type=request.artifact_type,
|
||||
host_identifier=request.host_identifier,
|
||||
data_summary=request.data_summary,
|
||||
conversation_history=request.conversation_history or [],
|
||||
)
|
||||
|
||||
# Get guidance
|
||||
response = await agent.assist(context)
|
||||
|
||||
logger.info(
|
||||
f"Agent assisted analyst with query: {request.query[:50]}... "
|
||||
f"(host: {request.host_identifier}, artifact: {request.artifact_type})"
|
||||
)
|
||||
|
||||
return AssistResponse(
|
||||
guidance=response.guidance,
|
||||
confidence=response.confidence,
|
||||
suggested_pivots=response.suggested_pivots,
|
||||
suggested_filters=response.suggested_filters,
|
||||
caveats=response.caveats,
|
||||
reasoning=response.reasoning,
|
||||
)
|
||||
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Agent error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"Agent unavailable: {str(e)}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error in agent_assist: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Error generating guidance. Please try again.",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/health",
|
||||
summary="Check agent health",
|
||||
description="Check if agent is configured and ready to assist.",
|
||||
)
|
||||
async def agent_health() -> dict:
|
||||
"""Check agent availability and configuration.
|
||||
|
||||
Returns:
|
||||
Health status with configuration details
|
||||
"""
|
||||
try:
|
||||
agent = get_agent()
|
||||
provider_type = agent.provider.__class__.__name__ if agent.provider else "None"
|
||||
return {
|
||||
"status": "healthy",
|
||||
"provider": provider_type,
|
||||
"max_tokens": AgentConfig.MAX_RESPONSE_TOKENS,
|
||||
"reasoning_enabled": AgentConfig.ENABLE_REASONING,
|
||||
}
|
||||
except HTTPException:
|
||||
return {
|
||||
"status": "unavailable",
|
||||
"reason": "No LLM provider configured",
|
||||
"configured_providers": {
|
||||
"local": bool(AgentConfig.LOCAL_MODEL_PATH),
|
||||
"networked": bool(AgentConfig.NETWORKED_ENDPOINT),
|
||||
"online": bool(AgentConfig.ONLINE_API_KEY),
|
||||
},
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
"""API routes for analyst-assist agent — v2.
|
||||
"""API routes for analyst-assist agent v2.
|
||||
|
||||
Supports quick, deep, and debate modes with streaming.
|
||||
Conversations are persisted to the database.
|
||||
@@ -6,19 +6,25 @@ Conversations are persisted to the database.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections import Counter
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
from app.db.models import Conversation, Message
|
||||
from app.db.models import Conversation, Message, Dataset, KeywordTheme
|
||||
from app.agents.core_v2 import ThreatHuntAgent, AgentContext, AgentResponse, Perspective
|
||||
from app.agents.providers_v2 import check_all_nodes
|
||||
from app.agents.registry import registry
|
||||
from app.services.sans_rag import sans_rag
|
||||
from app.services.scanner import KeywordScanner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,7 +41,7 @@ def get_agent() -> ThreatHuntAgent:
|
||||
return _agent
|
||||
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
# Request / Response models
|
||||
|
||||
|
||||
class AssistRequest(BaseModel):
|
||||
@@ -52,6 +58,8 @@ class AssistRequest(BaseModel):
|
||||
model_override: str | None = None
|
||||
conversation_id: str | None = Field(None, description="Persist messages to this conversation")
|
||||
hunt_id: str | None = None
|
||||
execution_preference: str = Field(default="auto", description="auto | force | off")
|
||||
learning_mode: bool = False
|
||||
|
||||
|
||||
class AssistResponseModel(BaseModel):
|
||||
@@ -66,10 +74,170 @@ class AssistResponseModel(BaseModel):
|
||||
node_used: str = ""
|
||||
latency_ms: int = 0
|
||||
perspectives: list[dict] | None = None
|
||||
execution: dict | None = None
|
||||
conversation_id: str | None = None
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
POLICY_THEME_NAMES = {"Adult Content", "Gambling", "Downloads / Piracy"}
|
||||
POLICY_QUERY_TERMS = {
|
||||
"policy", "violating", "violation", "browser history", "web history",
|
||||
"domain", "domains", "adult", "gambling", "piracy", "aup",
|
||||
}
|
||||
WEB_DATASET_HINTS = {
|
||||
"web", "history", "browser", "url", "visited_url", "domain", "title",
|
||||
}
|
||||
|
||||
|
||||
def _is_policy_domain_query(query: str) -> bool:
|
||||
q = (query or "").lower()
|
||||
if not q:
|
||||
return False
|
||||
score = sum(1 for t in POLICY_QUERY_TERMS if t in q)
|
||||
return score >= 2 and ("domain" in q or "history" in q or "policy" in q)
|
||||
|
||||
def _should_execute_policy_scan(request: AssistRequest) -> bool:
|
||||
pref = (request.execution_preference or "auto").strip().lower()
|
||||
if pref == "off":
|
||||
return False
|
||||
if pref == "force":
|
||||
return True
|
||||
return _is_policy_domain_query(request.query)
|
||||
|
||||
|
||||
def _extract_domain(value: str | None) -> str | None:
|
||||
if not value:
|
||||
return None
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = urlparse(text)
|
||||
if parsed.netloc:
|
||||
return parsed.netloc.lower()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
m = re.search(r"([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}", text)
|
||||
return m.group(0).lower() if m else None
|
||||
|
||||
|
||||
def _dataset_score(ds: Dataset) -> int:
|
||||
score = 0
|
||||
name = (ds.name or "").lower()
|
||||
cols_l = {c.lower() for c in (ds.column_schema or {}).keys()}
|
||||
norm_vals_l = {str(v).lower() for v in (ds.normalized_columns or {}).values()}
|
||||
|
||||
for h in WEB_DATASET_HINTS:
|
||||
if h in name:
|
||||
score += 2
|
||||
if h in cols_l:
|
||||
score += 3
|
||||
if h in norm_vals_l:
|
||||
score += 3
|
||||
|
||||
if "visited_url" in cols_l or "url" in cols_l:
|
||||
score += 8
|
||||
if "user" in cols_l or "username" in cols_l:
|
||||
score += 2
|
||||
if "clientid" in cols_l or "fqdn" in cols_l:
|
||||
score += 2
|
||||
if (ds.row_count or 0) > 0:
|
||||
score += 1
|
||||
|
||||
return score
|
||||
|
||||
|
||||
async def _run_policy_domain_execution(request: AssistRequest, db: AsyncSession) -> dict:
|
||||
scanner = KeywordScanner(db)
|
||||
|
||||
theme_result = await db.execute(
|
||||
select(KeywordTheme).where(
|
||||
KeywordTheme.enabled == True, # noqa: E712
|
||||
KeywordTheme.name.in_(list(POLICY_THEME_NAMES)),
|
||||
)
|
||||
)
|
||||
themes = list(theme_result.scalars().all())
|
||||
theme_ids = [t.id for t in themes]
|
||||
theme_names = [t.name for t in themes] or sorted(POLICY_THEME_NAMES)
|
||||
|
||||
ds_query = select(Dataset).where(Dataset.processing_status.in_(["completed", "ready", "processing"]))
|
||||
if request.hunt_id:
|
||||
ds_query = ds_query.where(Dataset.hunt_id == request.hunt_id)
|
||||
ds_result = await db.execute(ds_query)
|
||||
candidates = list(ds_result.scalars().all())
|
||||
|
||||
if request.dataset_name:
|
||||
needle = request.dataset_name.lower().strip()
|
||||
candidates = [d for d in candidates if needle in (d.name or "").lower()]
|
||||
|
||||
scored = sorted(
|
||||
((d, _dataset_score(d)) for d in candidates),
|
||||
key=lambda x: x[1],
|
||||
reverse=True,
|
||||
)
|
||||
selected = [d for d, s in scored if s > 0][:8]
|
||||
dataset_ids = [d.id for d in selected]
|
||||
|
||||
if not dataset_ids:
|
||||
return {
|
||||
"mode": "policy_scan",
|
||||
"themes": theme_names,
|
||||
"datasets_scanned": 0,
|
||||
"dataset_names": [],
|
||||
"total_hits": 0,
|
||||
"policy_hits": 0,
|
||||
"top_user_hosts": [],
|
||||
"top_domains": [],
|
||||
"sample_hits": [],
|
||||
"note": "No suitable browser/web-history datasets found in current scope.",
|
||||
}
|
||||
|
||||
result = await scanner.scan(
|
||||
dataset_ids=dataset_ids,
|
||||
theme_ids=theme_ids or None,
|
||||
scan_hunts=False,
|
||||
scan_annotations=False,
|
||||
scan_messages=False,
|
||||
)
|
||||
hits = result.get("hits", [])
|
||||
|
||||
user_host_counter = Counter()
|
||||
domain_counter = Counter()
|
||||
|
||||
for h in hits:
|
||||
user = h.get("username") or "(unknown-user)"
|
||||
host = h.get("hostname") or "(unknown-host)"
|
||||
user_host_counter[f"{user}|{host}"] += 1
|
||||
|
||||
dom = _extract_domain(h.get("matched_value"))
|
||||
if dom:
|
||||
domain_counter[dom] += 1
|
||||
|
||||
top_user_hosts = [
|
||||
{"user_host": k, "count": v}
|
||||
for k, v in user_host_counter.most_common(10)
|
||||
]
|
||||
top_domains = [
|
||||
{"domain": k, "count": v}
|
||||
for k, v in domain_counter.most_common(10)
|
||||
]
|
||||
|
||||
return {
|
||||
"mode": "policy_scan",
|
||||
"themes": theme_names,
|
||||
"datasets_scanned": len(dataset_ids),
|
||||
"dataset_names": [d.name for d in selected],
|
||||
"total_hits": int(result.get("total_hits", 0)),
|
||||
"policy_hits": int(result.get("total_hits", 0)),
|
||||
"rows_scanned": int(result.get("rows_scanned", 0)),
|
||||
"top_user_hosts": top_user_hosts,
|
||||
"top_domains": top_domains,
|
||||
"sample_hits": hits[:20],
|
||||
}
|
||||
|
||||
|
||||
# Routes
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -84,6 +252,76 @@ async def agent_assist(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> AssistResponseModel:
|
||||
try:
|
||||
# Deterministic execution mode for policy-domain investigations.
|
||||
if _should_execute_policy_scan(request):
|
||||
t0 = time.monotonic()
|
||||
exec_payload = await _run_policy_domain_execution(request, db)
|
||||
latency_ms = int((time.monotonic() - t0) * 1000)
|
||||
|
||||
policy_hits = exec_payload.get("policy_hits", 0)
|
||||
datasets_scanned = exec_payload.get("datasets_scanned", 0)
|
||||
|
||||
if policy_hits > 0:
|
||||
guidance = (
|
||||
f"Policy-violation scan complete: {policy_hits} hits across "
|
||||
f"{datasets_scanned} dataset(s). Top user/host pairs and domains are included "
|
||||
f"in execution results for triage."
|
||||
)
|
||||
confidence = 0.95
|
||||
caveats = "Keyword-based matching can include false positives; validate with full URL context."
|
||||
else:
|
||||
guidance = (
|
||||
f"No policy-violation hits found in current scope "
|
||||
f"({datasets_scanned} dataset(s) scanned)."
|
||||
)
|
||||
confidence = 0.9
|
||||
caveats = exec_payload.get("note") or "Try expanding scope to additional hunts/datasets."
|
||||
|
||||
response = AssistResponseModel(
|
||||
guidance=guidance,
|
||||
confidence=confidence,
|
||||
suggested_pivots=["username", "hostname", "domain", "dataset_name"],
|
||||
suggested_filters=[
|
||||
"theme_name in ['Adult Content','Gambling','Downloads / Piracy']",
|
||||
"username != null",
|
||||
"hostname != null",
|
||||
],
|
||||
caveats=caveats,
|
||||
reasoning=(
|
||||
"Intent matched policy-domain investigation; executed local keyword scan pipeline."
|
||||
if _is_policy_domain_query(request.query)
|
||||
else "Execution mode was forced by user preference; ran policy-domain scan pipeline."
|
||||
),
|
||||
sans_references=["SANS FOR508", "SANS SEC504"],
|
||||
model_used="execution:keyword_scanner",
|
||||
node_used="local",
|
||||
latency_ms=latency_ms,
|
||||
execution=exec_payload,
|
||||
)
|
||||
|
||||
conv_id = request.conversation_id
|
||||
if conv_id or request.hunt_id:
|
||||
conv_id = await _persist_conversation(
|
||||
db,
|
||||
conv_id,
|
||||
request,
|
||||
AgentResponse(
|
||||
guidance=response.guidance,
|
||||
confidence=response.confidence,
|
||||
suggested_pivots=response.suggested_pivots,
|
||||
suggested_filters=response.suggested_filters,
|
||||
caveats=response.caveats,
|
||||
reasoning=response.reasoning,
|
||||
sans_references=response.sans_references,
|
||||
model_used=response.model_used,
|
||||
node_used=response.node_used,
|
||||
latency_ms=response.latency_ms,
|
||||
),
|
||||
)
|
||||
response.conversation_id = conv_id
|
||||
|
||||
return response
|
||||
|
||||
agent = get_agent()
|
||||
context = AgentContext(
|
||||
query=request.query,
|
||||
@@ -97,6 +335,7 @@ async def agent_assist(
|
||||
enrichment_summary=request.enrichment_summary,
|
||||
mode=request.mode,
|
||||
model_override=request.model_override,
|
||||
learning_mode=request.learning_mode,
|
||||
)
|
||||
|
||||
response = await agent.assist(context)
|
||||
@@ -129,6 +368,7 @@ async def agent_assist(
|
||||
}
|
||||
for p in response.perspectives
|
||||
] if response.perspectives else None,
|
||||
execution=None,
|
||||
conversation_id=conv_id,
|
||||
)
|
||||
|
||||
@@ -208,7 +448,7 @@ async def list_models():
|
||||
}
|
||||
|
||||
|
||||
# ── Conversation persistence ──────────────────────────────────────────
|
||||
# Conversation persistence
|
||||
|
||||
|
||||
async def _persist_conversation(
|
||||
@@ -263,3 +503,4 @@ async def _persist_conversation(
|
||||
await db.flush()
|
||||
|
||||
return conv.id
|
||||
|
||||
|
||||
@@ -381,6 +381,10 @@ async def submit_job(
|
||||
detail=f"Invalid job_type: {job_type}. Valid: {[t.value for t in JobType]}",
|
||||
)
|
||||
|
||||
if not job_queue.can_accept():
|
||||
raise HTTPException(status_code=429, detail="Job queue is busy. Retry shortly.")
|
||||
if not job_queue.can_accept():
|
||||
raise HTTPException(status_code=429, detail="Job queue is busy. Retry shortly.")
|
||||
job = job_queue.submit(jt, **params)
|
||||
return {"job_id": job.id, "status": job.status.value, "job_type": job_type}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""API routes for authentication — register, login, refresh, profile."""
|
||||
"""API routes for authentication — register, login, refresh, profile."""
|
||||
|
||||
import logging
|
||||
|
||||
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
# ── Request / Response models ─────────────────────────────────────────
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
@@ -57,7 +57,7 @@ class AuthResponse(BaseModel):
|
||||
tokens: TokenPair
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -86,7 +86,7 @@ async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
user = User(
|
||||
username=body.username,
|
||||
email=body.email,
|
||||
password_hash=hash_password(body.password),
|
||||
hashed_password=hash_password(body.password),
|
||||
display_name=body.display_name or body.username,
|
||||
role="analyst", # Default role
|
||||
)
|
||||
@@ -120,13 +120,13 @@ async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(User).where(User.username == body.username))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not user.password_hash:
|
||||
if not user or not user.hashed_password:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
)
|
||||
|
||||
if not verify_password(body.password, user.password_hash):
|
||||
if not verify_password(body.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid username or password",
|
||||
@@ -165,7 +165,7 @@ async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db)
|
||||
if token_data.type != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type — use refresh token",
|
||||
detail="Invalid token type — use refresh token",
|
||||
)
|
||||
|
||||
result = await db.execute(select(User).where(User.id == token_data.sub))
|
||||
@@ -195,3 +195,4 @@ async def get_profile(user: User = Depends(get_current_user)):
|
||||
is_active=user.is_active,
|
||||
created_at=user.created_at.isoformat() if hasattr(user.created_at, 'isoformat') else str(user.created_at),
|
||||
)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
from app.db.models import ProcessingTask
|
||||
from app.db.repositories.datasets import DatasetRepository
|
||||
from app.services.csv_parser import parse_csv_bytes, infer_column_types
|
||||
from app.services.normalizer import (
|
||||
@@ -18,15 +19,20 @@ from app.services.normalizer import (
|
||||
detect_ioc_columns,
|
||||
detect_time_range,
|
||||
)
|
||||
from app.services.artifact_classifier import classify_artifact, get_artifact_category
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from app.services.job_queue import job_queue, JobType
|
||||
from app.services.host_inventory import inventory_cache
|
||||
from app.services.scanner import keyword_scan_cache
|
||||
|
||||
router = APIRouter(prefix="/api/datasets", tags=["datasets"])
|
||||
|
||||
ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"}
|
||||
|
||||
|
||||
# ── Response models ───────────────────────────────────────────────────
|
||||
# -- Response models --
|
||||
|
||||
|
||||
class DatasetSummary(BaseModel):
|
||||
@@ -43,6 +49,8 @@ class DatasetSummary(BaseModel):
|
||||
delimiter: str | None = None
|
||||
time_range_start: str | None = None
|
||||
time_range_end: str | None = None
|
||||
artifact_type: str | None = None
|
||||
processing_status: str | None = None
|
||||
hunt_id: str | None = None
|
||||
created_at: str
|
||||
|
||||
@@ -67,10 +75,13 @@ class UploadResponse(BaseModel):
|
||||
column_types: dict
|
||||
normalized_columns: dict
|
||||
ioc_columns: dict
|
||||
artifact_type: str | None = None
|
||||
processing_status: str
|
||||
jobs_queued: list[str]
|
||||
message: str
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
# -- Routes --
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -78,7 +89,7 @@ class UploadResponse(BaseModel):
|
||||
response_model=UploadResponse,
|
||||
summary="Upload a CSV dataset",
|
||||
description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, "
|
||||
"IOCs auto-detected, and rows stored in the database.",
|
||||
"IOCs auto-detected, artifact type classified, and all processing jobs queued automatically.",
|
||||
)
|
||||
async def upload_dataset(
|
||||
file: UploadFile = File(...),
|
||||
@@ -87,7 +98,7 @@ async def upload_dataset(
|
||||
hunt_id: str | None = Query(None, description="Hunt ID to associate with"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Upload and parse a CSV dataset."""
|
||||
"""Upload and parse a CSV dataset, then trigger full processing pipeline."""
|
||||
# Validate file
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No filename provided")
|
||||
@@ -136,7 +147,12 @@ async def upload_dataset(
|
||||
# Detect time range
|
||||
time_start, time_end = detect_time_range(rows, column_mapping)
|
||||
|
||||
# Store in DB
|
||||
# Classify artifact type from column headers
|
||||
artifact_type = classify_artifact(columns)
|
||||
artifact_category = get_artifact_category(artifact_type)
|
||||
logger.info(f"Artifact classification: {artifact_type} (category: {artifact_category})")
|
||||
|
||||
# Store in DB with processing_status = "processing"
|
||||
repo = DatasetRepository(db)
|
||||
dataset = await repo.create_dataset(
|
||||
name=name or Path(file.filename).stem,
|
||||
@@ -152,6 +168,8 @@ async def upload_dataset(
|
||||
time_range_start=time_start,
|
||||
time_range_end=time_end,
|
||||
hunt_id=hunt_id,
|
||||
artifact_type=artifact_type,
|
||||
processing_status="processing",
|
||||
)
|
||||
|
||||
await repo.bulk_insert_rows(
|
||||
@@ -162,9 +180,88 @@ async def upload_dataset(
|
||||
|
||||
logger.info(
|
||||
f"Uploaded dataset '{dataset.name}': {len(rows)} rows, "
|
||||
f"{len(columns)} columns, {len(ioc_columns)} IOC columns detected"
|
||||
f"{len(columns)} columns, {len(ioc_columns)} IOC columns, "
|
||||
f"artifact={artifact_type}"
|
||||
)
|
||||
|
||||
# -- Queue full processing pipeline --
|
||||
jobs_queued = []
|
||||
|
||||
task_rows: list[ProcessingTask] = []
|
||||
|
||||
# 1. AI Triage (chains to HOST_PROFILE automatically on completion)
|
||||
triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
|
||||
jobs_queued.append("triage")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=triage_job.id,
|
||||
stage="triage",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 2. Anomaly detection (embedding-based outlier detection)
|
||||
anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
|
||||
jobs_queued.append("anomaly")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=anomaly_job.id,
|
||||
stage="anomaly",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 3. AUP keyword scan
|
||||
kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
|
||||
jobs_queued.append("keyword_scan")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=kw_job.id,
|
||||
stage="keyword_scan",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 4. IOC extraction
|
||||
ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
|
||||
jobs_queued.append("ioc_extract")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=ioc_job.id,
|
||||
stage="ioc_extract",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
# 5. Host inventory (network map) - requires hunt_id
|
||||
if hunt_id:
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
jobs_queued.append("host_inventory")
|
||||
task_rows.append(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=inv_job.id,
|
||||
stage="host_inventory",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
|
||||
if task_rows:
|
||||
db.add_all(task_rows)
|
||||
await db.flush()
|
||||
|
||||
logger.info(f"Queued {len(jobs_queued)} processing jobs for dataset {dataset.id}: {jobs_queued}")
|
||||
|
||||
return UploadResponse(
|
||||
id=dataset.id,
|
||||
name=dataset.name,
|
||||
@@ -173,7 +270,10 @@ async def upload_dataset(
|
||||
column_types=column_types,
|
||||
normalized_columns=column_mapping,
|
||||
ioc_columns=ioc_columns,
|
||||
message=f"Successfully uploaded {len(rows)} rows with {len(ioc_columns)} IOC columns detected",
|
||||
artifact_type=artifact_type,
|
||||
processing_status="processing",
|
||||
jobs_queued=jobs_queued,
|
||||
message=f"Successfully uploaded {len(rows)} rows. {len(jobs_queued)} processing jobs queued.",
|
||||
)
|
||||
|
||||
|
||||
@@ -208,6 +308,8 @@ async def list_datasets(
|
||||
delimiter=ds.delimiter,
|
||||
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
|
||||
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
|
||||
artifact_type=ds.artifact_type,
|
||||
processing_status=ds.processing_status,
|
||||
hunt_id=ds.hunt_id,
|
||||
created_at=ds.created_at.isoformat(),
|
||||
)
|
||||
@@ -244,6 +346,8 @@ async def get_dataset(
|
||||
delimiter=ds.delimiter,
|
||||
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
|
||||
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
|
||||
artifact_type=ds.artifact_type,
|
||||
processing_status=ds.processing_status,
|
||||
hunt_id=ds.hunt_id,
|
||||
created_at=ds.created_at.isoformat(),
|
||||
)
|
||||
@@ -292,4 +396,5 @@ async def delete_dataset(
|
||||
deleted = await repo.delete_dataset(dataset_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Dataset not found")
|
||||
keyword_scan_cache.invalidate_dataset(dataset_id)
|
||||
return {"message": "Dataset deleted", "id": dataset_id}
|
||||
|
||||
@@ -8,16 +8,15 @@ from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Hunt, Conversation, Message
|
||||
from app.db.models import Hunt, Dataset, ProcessingTask
|
||||
from app.services.job_queue import job_queue
|
||||
from app.services.host_inventory import inventory_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/hunts", tags=["hunts"])
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class HuntCreate(BaseModel):
|
||||
name: str = Field(..., max_length=256)
|
||||
description: str | None = None
|
||||
@@ -26,7 +25,7 @@ class HuntCreate(BaseModel):
|
||||
class HuntUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
status: str | None = None # active | closed | archived
|
||||
status: str | None = None
|
||||
|
||||
|
||||
class HuntResponse(BaseModel):
|
||||
@@ -46,7 +45,18 @@ class HuntListResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
class HuntProgressResponse(BaseModel):
|
||||
hunt_id: str
|
||||
status: str
|
||||
progress_percent: float
|
||||
dataset_total: int
|
||||
dataset_completed: int
|
||||
dataset_processing: int
|
||||
dataset_errors: int
|
||||
active_jobs: int
|
||||
queued_jobs: int
|
||||
network_status: str
|
||||
stages: dict
|
||||
|
||||
|
||||
@router.post("", response_model=HuntResponse, summary="Create a new hunt")
|
||||
@@ -122,6 +132,125 @@ async def get_hunt(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{hunt_id}/progress", response_model=HuntProgressResponse, summary="Get hunt processing progress")
|
||||
async def get_hunt_progress(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
hunt = await db.get(Hunt, hunt_id)
|
||||
if not hunt:
|
||||
raise HTTPException(status_code=404, detail="Hunt not found")
|
||||
|
||||
ds_rows = await db.execute(
|
||||
select(Dataset.id, Dataset.processing_status)
|
||||
.where(Dataset.hunt_id == hunt_id)
|
||||
)
|
||||
datasets = ds_rows.all()
|
||||
dataset_ids = {row[0] for row in datasets}
|
||||
|
||||
dataset_total = len(datasets)
|
||||
dataset_completed = sum(1 for _, st in datasets if st == "completed")
|
||||
dataset_errors = sum(1 for _, st in datasets if st == "completed_with_errors")
|
||||
dataset_processing = max(0, dataset_total - dataset_completed - dataset_errors)
|
||||
|
||||
jobs = job_queue.list_jobs(limit=5000)
|
||||
relevant_jobs = [
|
||||
j for j in jobs
|
||||
if j.get("params", {}).get("hunt_id") == hunt_id
|
||||
or j.get("params", {}).get("dataset_id") in dataset_ids
|
||||
]
|
||||
active_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "running")
|
||||
queued_jobs_mem = sum(1 for j in relevant_jobs if j.get("status") == "queued")
|
||||
|
||||
task_rows = await db.execute(
|
||||
select(ProcessingTask.stage, ProcessingTask.status, ProcessingTask.progress)
|
||||
.where(ProcessingTask.hunt_id == hunt_id)
|
||||
)
|
||||
tasks = task_rows.all()
|
||||
|
||||
task_total = len(tasks)
|
||||
task_done = sum(1 for _, st, _ in tasks if st in ("completed", "failed", "cancelled"))
|
||||
task_running = sum(1 for _, st, _ in tasks if st == "running")
|
||||
task_queued = sum(1 for _, st, _ in tasks if st == "queued")
|
||||
task_ratio = (task_done / task_total) if task_total > 0 else None
|
||||
|
||||
active_jobs = max(active_jobs_mem, task_running)
|
||||
queued_jobs = max(queued_jobs_mem, task_queued)
|
||||
|
||||
stage_rollup: dict[str, dict] = {}
|
||||
for stage, status, progress in tasks:
|
||||
bucket = stage_rollup.setdefault(stage, {"total": 0, "done": 0, "running": 0, "queued": 0, "progress_sum": 0.0})
|
||||
bucket["total"] += 1
|
||||
if status in ("completed", "failed", "cancelled"):
|
||||
bucket["done"] += 1
|
||||
elif status == "running":
|
||||
bucket["running"] += 1
|
||||
elif status == "queued":
|
||||
bucket["queued"] += 1
|
||||
bucket["progress_sum"] += float(progress or 0.0)
|
||||
|
||||
for stage_name, bucket in stage_rollup.items():
|
||||
total = max(1, bucket["total"])
|
||||
bucket["percent"] = round(bucket["progress_sum"] / total, 1)
|
||||
|
||||
if inventory_cache.get(hunt_id) is not None:
|
||||
network_status = "ready"
|
||||
network_ratio = 1.0
|
||||
elif inventory_cache.is_building(hunt_id):
|
||||
network_status = "building"
|
||||
network_ratio = 0.5
|
||||
else:
|
||||
network_status = "none"
|
||||
network_ratio = 0.0
|
||||
|
||||
dataset_ratio = ((dataset_completed + dataset_errors) / dataset_total) if dataset_total > 0 else 1.0
|
||||
if task_ratio is None:
|
||||
overall_ratio = min(1.0, (dataset_ratio * 0.85) + (network_ratio * 0.15))
|
||||
else:
|
||||
overall_ratio = min(1.0, (dataset_ratio * 0.50) + (task_ratio * 0.35) + (network_ratio * 0.15))
|
||||
progress_percent = round(overall_ratio * 100.0, 1)
|
||||
|
||||
status = "ready"
|
||||
if dataset_total == 0:
|
||||
status = "idle"
|
||||
elif progress_percent < 100:
|
||||
status = "processing"
|
||||
|
||||
stages = {
|
||||
"datasets": {
|
||||
"total": dataset_total,
|
||||
"completed": dataset_completed,
|
||||
"processing": dataset_processing,
|
||||
"errors": dataset_errors,
|
||||
"percent": round(dataset_ratio * 100.0, 1),
|
||||
},
|
||||
"network": {
|
||||
"status": network_status,
|
||||
"percent": round(network_ratio * 100.0, 1),
|
||||
},
|
||||
"jobs": {
|
||||
"active": active_jobs,
|
||||
"queued": queued_jobs,
|
||||
"total_seen": len(relevant_jobs),
|
||||
"task_total": task_total,
|
||||
"task_done": task_done,
|
||||
"task_percent": round((task_ratio or 0.0) * 100.0, 1) if task_total else None,
|
||||
},
|
||||
"task_stages": stage_rollup,
|
||||
}
|
||||
|
||||
return HuntProgressResponse(
|
||||
hunt_id=hunt_id,
|
||||
status=status,
|
||||
progress_percent=progress_percent,
|
||||
dataset_total=dataset_total,
|
||||
dataset_completed=dataset_completed,
|
||||
dataset_processing=dataset_processing,
|
||||
dataset_errors=dataset_errors,
|
||||
active_jobs=active_jobs,
|
||||
queued_jobs=queued_jobs,
|
||||
network_status=network_status,
|
||||
stages=stages,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{hunt_id}", response_model=HuntResponse, summary="Update a hunt")
|
||||
async def update_hunt(
|
||||
hunt_id: str, body: HuntUpdate, db: AsyncSession = Depends(get_db)
|
||||
|
||||
@@ -1,25 +1,21 @@
|
||||
"""API routes for AUP keyword themes, keyword CRUD, and scanning."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select, func, delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import KeywordTheme, Keyword
|
||||
from app.services.scanner import KeywordScanner
|
||||
from app.services.scanner import KeywordScanner, keyword_scan_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/keywords", tags=["keywords"])
|
||||
|
||||
|
||||
# ── Pydantic schemas ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ThemeCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=128)
|
||||
color: str = Field(default="#9e9e9e", max_length=16)
|
||||
@@ -67,23 +63,27 @@ class KeywordBulkCreate(BaseModel):
|
||||
|
||||
|
||||
class ScanRequest(BaseModel):
|
||||
dataset_ids: list[str] | None = None # None → all datasets
|
||||
theme_ids: list[str] | None = None # None → all enabled themes
|
||||
scan_hunts: bool = True
|
||||
scan_annotations: bool = True
|
||||
scan_messages: bool = True
|
||||
dataset_ids: list[str] | None = None
|
||||
theme_ids: list[str] | None = None
|
||||
scan_hunts: bool = False
|
||||
scan_annotations: bool = False
|
||||
scan_messages: bool = False
|
||||
prefer_cache: bool = True
|
||||
force_rescan: bool = False
|
||||
|
||||
|
||||
class ScanHit(BaseModel):
|
||||
theme_name: str
|
||||
theme_color: str
|
||||
keyword: str
|
||||
source_type: str # dataset_row | hunt | annotation | message
|
||||
source_type: str
|
||||
source_id: str | int
|
||||
field: str
|
||||
matched_value: str
|
||||
row_index: int | None = None
|
||||
dataset_name: str | None = None
|
||||
hostname: str | None = None
|
||||
username: str | None = None
|
||||
|
||||
|
||||
class ScanResponse(BaseModel):
|
||||
@@ -92,9 +92,9 @@ class ScanResponse(BaseModel):
|
||||
themes_scanned: int
|
||||
keywords_scanned: int
|
||||
rows_scanned: int
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
cache_used: bool = False
|
||||
cache_status: str = "miss"
|
||||
cached_at: str | None = None
|
||||
|
||||
|
||||
def _theme_to_out(t: KeywordTheme) -> ThemeOut:
|
||||
@@ -119,49 +119,58 @@ def _theme_to_out(t: KeywordTheme) -> ThemeOut:
|
||||
)
|
||||
|
||||
|
||||
# ── Theme CRUD ────────────────────────────────────────────────────────
|
||||
def _merge_cached_results(entries: list[dict], allowed_theme_names: set[str] | None = None) -> dict:
|
||||
hits: list[dict] = []
|
||||
total_rows = 0
|
||||
cached_at: str | None = None
|
||||
|
||||
for entry in entries:
|
||||
result = entry["result"]
|
||||
total_rows += int(result.get("rows_scanned", 0) or 0)
|
||||
if entry.get("built_at"):
|
||||
if not cached_at or entry["built_at"] > cached_at:
|
||||
cached_at = entry["built_at"]
|
||||
for h in result.get("hits", []):
|
||||
if allowed_theme_names is not None and h.get("theme_name") not in allowed_theme_names:
|
||||
continue
|
||||
hits.append(h)
|
||||
|
||||
return {
|
||||
"total_hits": len(hits),
|
||||
"hits": hits,
|
||||
"rows_scanned": total_rows,
|
||||
"cached_at": cached_at,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/themes", response_model=ThemeListResponse)
|
||||
async def list_themes(db: AsyncSession = Depends(get_db)):
|
||||
"""List all keyword themes with their keywords."""
|
||||
result = await db.execute(
|
||||
select(KeywordTheme).order_by(KeywordTheme.name)
|
||||
)
|
||||
result = await db.execute(select(KeywordTheme).order_by(KeywordTheme.name))
|
||||
themes = result.scalars().all()
|
||||
return ThemeListResponse(
|
||||
themes=[_theme_to_out(t) for t in themes],
|
||||
total=len(themes),
|
||||
)
|
||||
return ThemeListResponse(themes=[_theme_to_out(t) for t in themes], total=len(themes))
|
||||
|
||||
|
||||
@router.post("/themes", response_model=ThemeOut, status_code=201)
|
||||
async def create_theme(body: ThemeCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""Create a new keyword theme."""
|
||||
exists = await db.scalar(
|
||||
select(KeywordTheme.id).where(KeywordTheme.name == body.name)
|
||||
)
|
||||
exists = await db.scalar(select(KeywordTheme.id).where(KeywordTheme.name == body.name))
|
||||
if exists:
|
||||
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
||||
theme = KeywordTheme(name=body.name, color=body.color, enabled=body.enabled)
|
||||
db.add(theme)
|
||||
await db.flush()
|
||||
await db.refresh(theme)
|
||||
keyword_scan_cache.clear()
|
||||
return _theme_to_out(theme)
|
||||
|
||||
|
||||
@router.put("/themes/{theme_id}", response_model=ThemeOut)
|
||||
async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depends(get_db)):
|
||||
"""Update theme name, color, or enabled status."""
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
if body.name is not None:
|
||||
# check uniqueness
|
||||
dup = await db.scalar(
|
||||
select(KeywordTheme.id).where(
|
||||
KeywordTheme.name == body.name, KeywordTheme.id != theme_id
|
||||
)
|
||||
select(KeywordTheme.id).where(KeywordTheme.name == body.name, KeywordTheme.id != theme_id)
|
||||
)
|
||||
if dup:
|
||||
raise HTTPException(409, f"Theme '{body.name}' already exists")
|
||||
@@ -172,24 +181,21 @@ async def update_theme(theme_id: str, body: ThemeUpdate, db: AsyncSession = Depe
|
||||
theme.enabled = body.enabled
|
||||
await db.flush()
|
||||
await db.refresh(theme)
|
||||
keyword_scan_cache.clear()
|
||||
return _theme_to_out(theme)
|
||||
|
||||
|
||||
@router.delete("/themes/{theme_id}", status_code=204)
|
||||
async def delete_theme(theme_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete a theme and all its keywords."""
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
await db.delete(theme)
|
||||
|
||||
|
||||
# ── Keyword CRUD ──────────────────────────────────────────────────────
|
||||
keyword_scan_cache.clear()
|
||||
|
||||
|
||||
@router.post("/themes/{theme_id}/keywords", response_model=KeywordOut, status_code=201)
|
||||
async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""Add a single keyword to a theme."""
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
@@ -197,6 +203,7 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
|
||||
db.add(kw)
|
||||
await db.flush()
|
||||
await db.refresh(kw)
|
||||
keyword_scan_cache.clear()
|
||||
return KeywordOut(
|
||||
id=kw.id, theme_id=kw.theme_id, value=kw.value,
|
||||
is_regex=kw.is_regex, created_at=kw.created_at.isoformat(),
|
||||
@@ -205,7 +212,6 @@ async def add_keyword(theme_id: str, body: KeywordCreate, db: AsyncSession = Dep
|
||||
|
||||
@router.post("/themes/{theme_id}/keywords/bulk", response_model=dict, status_code=201)
|
||||
async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSession = Depends(get_db)):
|
||||
"""Add multiple keywords to a theme at once."""
|
||||
theme = await db.get(KeywordTheme, theme_id)
|
||||
if not theme:
|
||||
raise HTTPException(404, "Theme not found")
|
||||
@@ -217,25 +223,88 @@ async def add_keywords_bulk(theme_id: str, body: KeywordBulkCreate, db: AsyncSes
|
||||
db.add(Keyword(theme_id=theme_id, value=val, is_regex=body.is_regex))
|
||||
added += 1
|
||||
await db.flush()
|
||||
keyword_scan_cache.clear()
|
||||
return {"added": added, "theme_id": theme_id}
|
||||
|
||||
|
||||
@router.delete("/keywords/{keyword_id}", status_code=204)
|
||||
async def delete_keyword(keyword_id: int, db: AsyncSession = Depends(get_db)):
|
||||
"""Delete a single keyword."""
|
||||
kw = await db.get(Keyword, keyword_id)
|
||||
if not kw:
|
||||
raise HTTPException(404, "Keyword not found")
|
||||
await db.delete(kw)
|
||||
|
||||
|
||||
# ── Scan endpoints ────────────────────────────────────────────────────
|
||||
keyword_scan_cache.clear()
|
||||
|
||||
|
||||
@router.post("/scan", response_model=ScanResponse)
|
||||
async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
||||
"""Run AUP keyword scan across selected data sources."""
|
||||
scanner = KeywordScanner(db)
|
||||
|
||||
if not body.dataset_ids and not body.scan_hunts and not body.scan_annotations and not body.scan_messages:
|
||||
return {
|
||||
"total_hits": 0,
|
||||
"hits": [],
|
||||
"themes_scanned": 0,
|
||||
"keywords_scanned": 0,
|
||||
"rows_scanned": 0,
|
||||
"cache_used": False,
|
||||
"cache_status": "miss",
|
||||
"cached_at": None,
|
||||
}
|
||||
|
||||
can_use_cache = (
|
||||
body.prefer_cache
|
||||
and not body.force_rescan
|
||||
and bool(body.dataset_ids)
|
||||
and not body.scan_hunts
|
||||
and not body.scan_annotations
|
||||
and not body.scan_messages
|
||||
)
|
||||
|
||||
if can_use_cache:
|
||||
themes = await scanner._load_themes(body.theme_ids)
|
||||
allowed_theme_names = {t.name for t in themes}
|
||||
keywords_scanned = sum(len(theme.keywords) for theme in themes)
|
||||
|
||||
cached_entries: list[dict] = []
|
||||
missing: list[str] = []
|
||||
for dataset_id in (body.dataset_ids or []):
|
||||
entry = keyword_scan_cache.get(dataset_id)
|
||||
if not entry:
|
||||
missing.append(dataset_id)
|
||||
continue
|
||||
cached_entries.append({"result": entry.result, "built_at": entry.built_at})
|
||||
|
||||
if not missing and cached_entries:
|
||||
merged = _merge_cached_results(cached_entries, allowed_theme_names if body.theme_ids else None)
|
||||
return {
|
||||
"total_hits": merged["total_hits"],
|
||||
"hits": merged["hits"],
|
||||
"themes_scanned": len(themes),
|
||||
"keywords_scanned": keywords_scanned,
|
||||
"rows_scanned": merged["rows_scanned"],
|
||||
"cache_used": True,
|
||||
"cache_status": "hit",
|
||||
"cached_at": merged["cached_at"],
|
||||
}
|
||||
|
||||
if missing:
|
||||
partial = await scanner.scan(dataset_ids=missing, theme_ids=body.theme_ids)
|
||||
merged = _merge_cached_results(
|
||||
cached_entries + [{"result": partial, "built_at": None}],
|
||||
allowed_theme_names if body.theme_ids else None,
|
||||
)
|
||||
return {
|
||||
"total_hits": merged["total_hits"],
|
||||
"hits": merged["hits"],
|
||||
"themes_scanned": len(themes),
|
||||
"keywords_scanned": keywords_scanned,
|
||||
"rows_scanned": merged["rows_scanned"],
|
||||
"cache_used": len(cached_entries) > 0,
|
||||
"cache_status": "partial" if cached_entries else "miss",
|
||||
"cached_at": merged["cached_at"],
|
||||
}
|
||||
|
||||
result = await scanner.scan(
|
||||
dataset_ids=body.dataset_ids,
|
||||
theme_ids=body.theme_ids,
|
||||
@@ -243,7 +312,13 @@ async def run_scan(body: ScanRequest, db: AsyncSession = Depends(get_db)):
|
||||
scan_annotations=body.scan_annotations,
|
||||
scan_messages=body.scan_messages,
|
||||
)
|
||||
return result
|
||||
|
||||
return {
|
||||
**result,
|
||||
"cache_used": False,
|
||||
"cache_status": "miss",
|
||||
"cached_at": None,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/scan/quick", response_model=ScanResponse)
|
||||
@@ -251,7 +326,22 @@ async def quick_scan(
|
||||
dataset_id: str = Query(..., description="Dataset to scan"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Quick scan a single dataset with all enabled themes."""
|
||||
entry = keyword_scan_cache.get(dataset_id)
|
||||
if entry is not None:
|
||||
result = entry.result
|
||||
return {
|
||||
**result,
|
||||
"cache_used": True,
|
||||
"cache_status": "hit",
|
||||
"cached_at": entry.built_at,
|
||||
}
|
||||
|
||||
scanner = KeywordScanner(db)
|
||||
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||
return result
|
||||
keyword_scan_cache.put(dataset_id, result)
|
||||
return {
|
||||
**result,
|
||||
"cache_used": False,
|
||||
"cache_status": "miss",
|
||||
"cached_at": None,
|
||||
}
|
||||
|
||||
146
backend/app/api/routes/mitre.py
Normal file
146
backend/app/api/routes/mitre.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""API routes for MITRE ATT&CK coverage visualization."""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import (
|
||||
TriageResult, HostProfile, Hypothesis, HuntReport, Dataset, Hunt
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/mitre", tags=["mitre"])
|
||||
|
||||
# Canonical MITRE ATT&CK tactics in kill-chain order
|
||||
TACTICS = [
|
||||
"Reconnaissance", "Resource Development", "Initial Access",
|
||||
"Execution", "Persistence", "Privilege Escalation",
|
||||
"Defense Evasion", "Credential Access", "Discovery",
|
||||
"Lateral Movement", "Collection", "Command and Control",
|
||||
"Exfiltration", "Impact",
|
||||
]
|
||||
|
||||
# Simplified technique-to-tactic mapping (top techniques)
|
||||
TECHNIQUE_TACTIC: dict[str, str] = {
|
||||
"T1059": "Execution", "T1059.001": "Execution", "T1059.003": "Execution",
|
||||
"T1059.005": "Execution", "T1059.006": "Execution", "T1059.007": "Execution",
|
||||
"T1053": "Persistence", "T1053.005": "Persistence",
|
||||
"T1547": "Persistence", "T1547.001": "Persistence",
|
||||
"T1543": "Persistence", "T1543.003": "Persistence",
|
||||
"T1078": "Privilege Escalation", "T1078.001": "Privilege Escalation",
|
||||
"T1078.002": "Privilege Escalation", "T1078.003": "Privilege Escalation",
|
||||
"T1055": "Privilege Escalation", "T1055.001": "Privilege Escalation",
|
||||
"T1548": "Privilege Escalation", "T1548.002": "Privilege Escalation",
|
||||
"T1070": "Defense Evasion", "T1070.001": "Defense Evasion",
|
||||
"T1070.004": "Defense Evasion",
|
||||
"T1036": "Defense Evasion", "T1036.005": "Defense Evasion",
|
||||
"T1027": "Defense Evasion", "T1140": "Defense Evasion",
|
||||
"T1218": "Defense Evasion", "T1218.011": "Defense Evasion",
|
||||
"T1003": "Credential Access", "T1003.001": "Credential Access",
|
||||
"T1110": "Credential Access", "T1558": "Credential Access",
|
||||
"T1087": "Discovery", "T1087.001": "Discovery", "T1087.002": "Discovery",
|
||||
"T1082": "Discovery", "T1083": "Discovery", "T1057": "Discovery",
|
||||
"T1018": "Discovery", "T1049": "Discovery", "T1016": "Discovery",
|
||||
"T1021": "Lateral Movement", "T1021.001": "Lateral Movement",
|
||||
"T1021.002": "Lateral Movement", "T1021.006": "Lateral Movement",
|
||||
"T1570": "Lateral Movement",
|
||||
"T1560": "Collection", "T1074": "Collection", "T1005": "Collection",
|
||||
"T1071": "Command and Control", "T1071.001": "Command and Control",
|
||||
"T1105": "Command and Control", "T1572": "Command and Control",
|
||||
"T1095": "Command and Control",
|
||||
"T1048": "Exfiltration", "T1041": "Exfiltration",
|
||||
"T1486": "Impact", "T1490": "Impact", "T1489": "Impact",
|
||||
"T1566": "Initial Access", "T1566.001": "Initial Access",
|
||||
"T1566.002": "Initial Access",
|
||||
"T1190": "Initial Access", "T1133": "Initial Access",
|
||||
"T1195": "Initial Access", "T1195.002": "Initial Access",
|
||||
}
|
||||
|
||||
|
||||
def _get_tactic(technique_id: str) -> str:
|
||||
"""Map a technique ID to its tactic."""
|
||||
tech = technique_id.strip().upper()
|
||||
if tech in TECHNIQUE_TACTIC:
|
||||
return TECHNIQUE_TACTIC[tech]
|
||||
# Try parent technique
|
||||
if "." in tech:
|
||||
parent = tech.split(".")[0]
|
||||
if parent in TECHNIQUE_TACTIC:
|
||||
return TECHNIQUE_TACTIC[parent]
|
||||
return "Unknown"
|
||||
|
||||
|
||||
@router.get("/coverage")
|
||||
async def get_mitre_coverage(
|
||||
hunt_id: str | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Aggregate all MITRE techniques from triage, host profiles, hypotheses, and reports."""
|
||||
|
||||
techniques: dict[str, dict] = {}
|
||||
|
||||
# Collect from triage results
|
||||
triage_q = select(TriageResult)
|
||||
if hunt_id:
|
||||
triage_q = triage_q.join(Dataset).where(Dataset.hunt_id == hunt_id)
|
||||
result = await db.execute(triage_q.limit(500))
|
||||
for t in result.scalars().all():
|
||||
for tech in (t.mitre_techniques or []):
|
||||
if tech not in techniques:
|
||||
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
|
||||
techniques[tech]["count"] += 1
|
||||
techniques[tech]["sources"].append({"type": "triage", "risk_score": t.risk_score})
|
||||
|
||||
# Collect from host profiles
|
||||
profile_q = select(HostProfile)
|
||||
if hunt_id:
|
||||
profile_q = profile_q.where(HostProfile.hunt_id == hunt_id)
|
||||
result = await db.execute(profile_q.limit(200))
|
||||
for p in result.scalars().all():
|
||||
for tech in (p.mitre_techniques or []):
|
||||
if tech not in techniques:
|
||||
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
|
||||
techniques[tech]["count"] += 1
|
||||
techniques[tech]["sources"].append({"type": "host_profile", "hostname": p.hostname})
|
||||
|
||||
# Collect from hypotheses
|
||||
hyp_q = select(Hypothesis)
|
||||
if hunt_id:
|
||||
hyp_q = hyp_q.where(Hypothesis.hunt_id == hunt_id)
|
||||
result = await db.execute(hyp_q.limit(200))
|
||||
for h in result.scalars().all():
|
||||
tech = h.mitre_technique
|
||||
if tech:
|
||||
if tech not in techniques:
|
||||
techniques[tech] = {"id": tech, "tactic": _get_tactic(tech), "sources": [], "count": 0}
|
||||
techniques[tech]["count"] += 1
|
||||
techniques[tech]["sources"].append({"type": "hypothesis", "title": h.title})
|
||||
|
||||
# Build tactic-grouped response
|
||||
tactic_groups: dict[str, list] = {t: [] for t in TACTICS}
|
||||
tactic_groups["Unknown"] = []
|
||||
for tech in techniques.values():
|
||||
tactic = tech["tactic"]
|
||||
if tactic not in tactic_groups:
|
||||
tactic_groups[tactic] = []
|
||||
tactic_groups[tactic].append(tech)
|
||||
|
||||
total_techniques = len(techniques)
|
||||
total_detections = sum(t["count"] for t in techniques.values())
|
||||
|
||||
return {
|
||||
"tactics": TACTICS,
|
||||
"technique_count": total_techniques,
|
||||
"detection_count": total_detections,
|
||||
"tactic_coverage": {
|
||||
t: {"techniques": techs, "count": len(techs)}
|
||||
for t, techs in tactic_groups.items()
|
||||
if techs
|
||||
},
|
||||
"all_techniques": list(techniques.values()),
|
||||
}
|
||||
@@ -1,12 +1,15 @@
|
||||
"""Network topology API - host inventory endpoint."""
|
||||
"""Network topology API - host inventory endpoint with background caching."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.db import get_db
|
||||
from app.services.host_inventory import build_host_inventory
|
||||
from app.services.host_inventory import build_host_inventory, inventory_cache
|
||||
from app.services.job_queue import job_queue, JobType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/network", tags=["network"])
|
||||
@@ -15,14 +18,158 @@ router = APIRouter(prefix="/api/network", tags=["network"])
|
||||
@router.get("/host-inventory")
|
||||
async def get_host_inventory(
|
||||
hunt_id: str = Query(..., description="Hunt ID to build inventory for"),
|
||||
force: bool = Query(False, description="Force rebuild, ignoring cache"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Build a deduplicated host inventory from all datasets in a hunt.
|
||||
"""Return a deduplicated host inventory for the hunt.
|
||||
|
||||
Returns unique hosts with hostname, IPs, OS, logged-in users, and
|
||||
network connections derived from netstat/connection data.
|
||||
Returns instantly from cache if available (pre-built after upload or on startup).
|
||||
If cache is cold, triggers a background build and returns 202 so the
|
||||
frontend can poll /inventory-status and re-request when ready.
|
||||
"""
|
||||
result = await build_host_inventory(hunt_id, db)
|
||||
if result["stats"]["total_hosts"] == 0:
|
||||
return result
|
||||
return result
|
||||
# Force rebuild: invalidate cache, queue background job, return 202
|
||||
if force:
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "building", "message": "Rebuild queued"},
|
||||
)
|
||||
|
||||
# Try cache first
|
||||
cached = inventory_cache.get(hunt_id)
|
||||
if cached is not None:
|
||||
logger.info(f"Serving cached host inventory for {hunt_id}")
|
||||
return cached
|
||||
|
||||
# Cache miss: trigger background build instead of blocking for 90+ seconds
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
logger.info(f"Cache miss for {hunt_id}, triggering background build")
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(status_code=202, content={"status": "deferred", "message": "Queue busy, retry shortly"})
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "building", "message": "Inventory is being built in the background"},
|
||||
)
|
||||
|
||||
|
||||
def _build_summary(inv: dict, top_n: int = 20) -> dict:
|
||||
hosts = inv.get("hosts", [])
|
||||
conns = inv.get("connections", [])
|
||||
top_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:top_n]
|
||||
top_edges = sorted(conns, key=lambda c: c.get("count", 0), reverse=True)[:top_n]
|
||||
return {
|
||||
"stats": inv.get("stats", {}),
|
||||
"top_hosts": [
|
||||
{
|
||||
"id": h.get("id"),
|
||||
"hostname": h.get("hostname"),
|
||||
"row_count": h.get("row_count", 0),
|
||||
"ip_count": len(h.get("ips", [])),
|
||||
"user_count": len(h.get("users", [])),
|
||||
}
|
||||
for h in top_hosts
|
||||
],
|
||||
"top_edges": top_edges,
|
||||
}
|
||||
|
||||
|
||||
def _build_subgraph(inv: dict, node_id: str | None, max_hosts: int, max_edges: int) -> dict:
|
||||
hosts = inv.get("hosts", [])
|
||||
conns = inv.get("connections", [])
|
||||
|
||||
max_hosts = max(1, min(max_hosts, settings.NETWORK_SUBGRAPH_MAX_HOSTS))
|
||||
max_edges = max(1, min(max_edges, settings.NETWORK_SUBGRAPH_MAX_EDGES))
|
||||
|
||||
if node_id:
|
||||
rel_edges = [c for c in conns if c.get("source") == node_id or c.get("target") == node_id]
|
||||
rel_edges = sorted(rel_edges, key=lambda c: c.get("count", 0), reverse=True)[:max_edges]
|
||||
ids = {node_id}
|
||||
for c in rel_edges:
|
||||
ids.add(c.get("source"))
|
||||
ids.add(c.get("target"))
|
||||
rel_hosts = [h for h in hosts if h.get("id") in ids][:max_hosts]
|
||||
else:
|
||||
rel_hosts = sorted(hosts, key=lambda h: h.get("row_count", 0), reverse=True)[:max_hosts]
|
||||
allowed = {h.get("id") for h in rel_hosts}
|
||||
rel_edges = [
|
||||
c for c in sorted(conns, key=lambda c: c.get("count", 0), reverse=True)
|
||||
if c.get("source") in allowed and c.get("target") in allowed
|
||||
][:max_edges]
|
||||
|
||||
return {
|
||||
"hosts": rel_hosts,
|
||||
"connections": rel_edges,
|
||||
"stats": {
|
||||
**inv.get("stats", {}),
|
||||
"subgraph_hosts": len(rel_hosts),
|
||||
"subgraph_connections": len(rel_edges),
|
||||
"truncated": len(rel_hosts) < len(hosts) or len(rel_edges) < len(conns),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_inventory_summary(
|
||||
hunt_id: str = Query(..., description="Hunt ID"),
|
||||
top_n: int = Query(20, ge=1, le=200),
|
||||
):
|
||||
"""Return a lightweight summary view for large hunts."""
|
||||
cached = inventory_cache.get(hunt_id)
|
||||
if cached is None:
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "deferred", "message": "Queue busy, retry shortly"},
|
||||
)
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return JSONResponse(status_code=202, content={"status": "building"})
|
||||
return _build_summary(cached, top_n=top_n)
|
||||
|
||||
|
||||
@router.get("/subgraph")
|
||||
async def get_inventory_subgraph(
|
||||
hunt_id: str = Query(..., description="Hunt ID"),
|
||||
node_id: str | None = Query(None, description="Optional focal node"),
|
||||
max_hosts: int = Query(200, ge=1, le=5000),
|
||||
max_edges: int = Query(1500, ge=1, le=20000),
|
||||
):
|
||||
"""Return a bounded subgraph for scale-safe rendering."""
|
||||
cached = inventory_cache.get(hunt_id)
|
||||
if cached is None:
|
||||
if not inventory_cache.is_building(hunt_id):
|
||||
if job_queue.is_backlogged():
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={"status": "deferred", "message": "Queue busy, retry shortly"},
|
||||
)
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return JSONResponse(status_code=202, content={"status": "building"})
|
||||
return _build_subgraph(cached, node_id=node_id, max_hosts=max_hosts, max_edges=max_edges)
|
||||
|
||||
|
||||
@router.get("/inventory-status")
|
||||
async def get_inventory_status(
|
||||
hunt_id: str = Query(..., description="Hunt ID to check"),
|
||||
):
|
||||
"""Check whether pre-computed host inventory is ready for a hunt.
|
||||
|
||||
Returns: { status: "ready" | "building" | "none" }
|
||||
"""
|
||||
return {"hunt_id": hunt_id, "status": inventory_cache.status(hunt_id)}
|
||||
|
||||
|
||||
@router.post("/rebuild-inventory")
|
||||
async def trigger_rebuild(
|
||||
hunt_id: str = Query(..., description="Hunt to rebuild inventory for"),
|
||||
):
|
||||
"""Trigger a background rebuild of the host inventory cache."""
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
|
||||
return {"job_id": job.id, "status": "queued"}
|
||||
|
||||
217
backend/app/api/routes/playbooks.py
Normal file
217
backend/app/api/routes/playbooks.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""API routes for investigation playbooks."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Playbook, PlaybookStep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/playbooks", tags=["playbooks"])
|
||||
|
||||
|
||||
# -- Request / Response schemas ---
|
||||
|
||||
class StepCreate(BaseModel):
|
||||
title: str
|
||||
description: str | None = None
|
||||
step_type: str = "manual"
|
||||
target_route: str | None = None
|
||||
|
||||
class PlaybookCreate(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
hunt_id: str | None = None
|
||||
is_template: bool = False
|
||||
steps: list[StepCreate] = []
|
||||
|
||||
class PlaybookUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
status: str | None = None
|
||||
|
||||
class StepUpdate(BaseModel):
|
||||
is_completed: bool | None = None
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
# -- Default investigation templates ---
|
||||
|
||||
DEFAULT_TEMPLATES = [
|
||||
{
|
||||
"name": "Standard Threat Hunt",
|
||||
"description": "Step-by-step investigation workflow for a typical threat hunting engagement.",
|
||||
"steps": [
|
||||
{"title": "Upload Artifacts", "description": "Import CSV exports from Velociraptor or other tools", "step_type": "upload", "target_route": "/upload"},
|
||||
{"title": "Create Hunt", "description": "Create a new hunt and associate uploaded datasets", "step_type": "action", "target_route": "/hunts"},
|
||||
{"title": "AUP Keyword Scan", "description": "Run AUP keyword scanner for policy violations", "step_type": "analysis", "target_route": "/aup"},
|
||||
{"title": "Auto-Triage", "description": "Trigger AI triage on all datasets", "step_type": "analysis", "target_route": "/analysis"},
|
||||
{"title": "Review Triage Results", "description": "Review flagged rows and risk scores", "step_type": "review", "target_route": "/analysis"},
|
||||
{"title": "Enrich IOCs", "description": "Enrich flagged IPs, hashes, and domains via external sources", "step_type": "analysis", "target_route": "/enrichment"},
|
||||
{"title": "Host Profiling", "description": "Generate deep host profiles for suspicious hosts", "step_type": "analysis", "target_route": "/analysis"},
|
||||
{"title": "Cross-Hunt Correlation", "description": "Identify shared IOCs and patterns across hunts", "step_type": "analysis", "target_route": "/correlation"},
|
||||
{"title": "Document Hypotheses", "description": "Record investigation hypotheses with MITRE mappings", "step_type": "action", "target_route": "/hypotheses"},
|
||||
{"title": "Generate Report", "description": "Generate final AI-assisted hunt report", "step_type": "action", "target_route": "/analysis"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Incident Response Triage",
|
||||
"description": "Fast-track workflow for active incident response.",
|
||||
"steps": [
|
||||
{"title": "Upload Artifacts", "description": "Import forensic data from affected hosts", "step_type": "upload", "target_route": "/upload"},
|
||||
{"title": "Auto-Triage", "description": "Immediate AI triage for threat indicators", "step_type": "analysis", "target_route": "/analysis"},
|
||||
{"title": "IOC Extraction", "description": "Extract all IOCs from flagged data", "step_type": "analysis", "target_route": "/analysis"},
|
||||
{"title": "Enrich Critical IOCs", "description": "Priority enrichment of high-risk indicators", "step_type": "analysis", "target_route": "/enrichment"},
|
||||
{"title": "Network Map", "description": "Visualize host connections and lateral movement", "step_type": "review", "target_route": "/network"},
|
||||
{"title": "Generate Situation Report", "description": "Create executive summary for incident command", "step_type": "action", "target_route": "/analysis"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# -- Routes ---
|
||||
|
||||
@router.get("")
|
||||
async def list_playbooks(
|
||||
include_templates: bool = True,
|
||||
hunt_id: str | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
q = select(Playbook)
|
||||
if hunt_id:
|
||||
q = q.where(Playbook.hunt_id == hunt_id)
|
||||
if not include_templates:
|
||||
q = q.where(Playbook.is_template == False)
|
||||
q = q.order_by(Playbook.created_at.desc())
|
||||
result = await db.execute(q.limit(100))
|
||||
playbooks = result.scalars().all()
|
||||
|
||||
return {"playbooks": [
|
||||
{
|
||||
"id": p.id, "name": p.name, "description": p.description,
|
||||
"is_template": p.is_template, "hunt_id": p.hunt_id,
|
||||
"status": p.status,
|
||||
"total_steps": len(p.steps),
|
||||
"completed_steps": sum(1 for s in p.steps if s.is_completed),
|
||||
"created_at": p.created_at.isoformat() if p.created_at else None,
|
||||
}
|
||||
for p in playbooks
|
||||
]}
|
||||
|
||||
|
||||
@router.get("/templates")
|
||||
async def get_templates():
|
||||
"""Return built-in investigation templates."""
|
||||
return {"templates": DEFAULT_TEMPLATES}
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
async def create_playbook(body: PlaybookCreate, db: AsyncSession = Depends(get_db)):
|
||||
pb = Playbook(
|
||||
name=body.name,
|
||||
description=body.description,
|
||||
hunt_id=body.hunt_id,
|
||||
is_template=body.is_template,
|
||||
)
|
||||
db.add(pb)
|
||||
await db.flush()
|
||||
|
||||
created_steps = []
|
||||
for i, step in enumerate(body.steps):
|
||||
s = PlaybookStep(
|
||||
playbook_id=pb.id,
|
||||
order_index=i,
|
||||
title=step.title,
|
||||
description=step.description,
|
||||
step_type=step.step_type,
|
||||
target_route=step.target_route,
|
||||
)
|
||||
db.add(s)
|
||||
created_steps.append(s)
|
||||
|
||||
await db.flush()
|
||||
|
||||
return {
|
||||
"id": pb.id, "name": pb.name, "description": pb.description,
|
||||
"hunt_id": pb.hunt_id, "is_template": pb.is_template,
|
||||
"steps": [
|
||||
{"id": s.id, "order_index": s.order_index, "title": s.title,
|
||||
"description": s.description, "step_type": s.step_type,
|
||||
"target_route": s.target_route, "is_completed": False}
|
||||
for s in created_steps
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{playbook_id}")
|
||||
async def get_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
|
||||
pb = result.scalar_one_or_none()
|
||||
if not pb:
|
||||
raise HTTPException(status_code=404, detail="Playbook not found")
|
||||
|
||||
return {
|
||||
"id": pb.id, "name": pb.name, "description": pb.description,
|
||||
"is_template": pb.is_template, "hunt_id": pb.hunt_id,
|
||||
"status": pb.status,
|
||||
"created_at": pb.created_at.isoformat() if pb.created_at else None,
|
||||
"steps": [
|
||||
{
|
||||
"id": s.id, "order_index": s.order_index, "title": s.title,
|
||||
"description": s.description, "step_type": s.step_type,
|
||||
"target_route": s.target_route,
|
||||
"is_completed": s.is_completed,
|
||||
"completed_at": s.completed_at.isoformat() if s.completed_at else None,
|
||||
"notes": s.notes,
|
||||
}
|
||||
for s in sorted(pb.steps, key=lambda x: x.order_index)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{playbook_id}")
|
||||
async def update_playbook(playbook_id: str, body: PlaybookUpdate, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
|
||||
pb = result.scalar_one_or_none()
|
||||
if not pb:
|
||||
raise HTTPException(status_code=404, detail="Playbook not found")
|
||||
|
||||
if body.name is not None:
|
||||
pb.name = body.name
|
||||
if body.description is not None:
|
||||
pb.description = body.description
|
||||
if body.status is not None:
|
||||
pb.status = body.status
|
||||
return {"status": "updated"}
|
||||
|
||||
|
||||
@router.delete("/{playbook_id}")
|
||||
async def delete_playbook(playbook_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Playbook).where(Playbook.id == playbook_id))
|
||||
pb = result.scalar_one_or_none()
|
||||
if not pb:
|
||||
raise HTTPException(status_code=404, detail="Playbook not found")
|
||||
await db.delete(pb)
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.put("/steps/{step_id}")
|
||||
async def update_step(step_id: int, body: StepUpdate, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(PlaybookStep).where(PlaybookStep.id == step_id))
|
||||
step = result.scalar_one_or_none()
|
||||
if not step:
|
||||
raise HTTPException(status_code=404, detail="Step not found")
|
||||
|
||||
if body.is_completed is not None:
|
||||
step.is_completed = body.is_completed
|
||||
step.completed_at = datetime.now(timezone.utc) if body.is_completed else None
|
||||
if body.notes is not None:
|
||||
step.notes = body.notes
|
||||
return {"status": "updated", "is_completed": step.is_completed}
|
||||
|
||||
164
backend/app/api/routes/saved_searches.py
Normal file
164
backend/app/api/routes/saved_searches.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""API routes for saved searches and bookmarked queries."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import SavedSearch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/searches", tags=["saved-searches"])
|
||||
|
||||
|
||||
class SearchCreate(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
search_type: str # "nlp_query", "ioc_search", "keyword_scan", "correlation"
|
||||
query_params: dict
|
||||
threshold: float | None = None
|
||||
|
||||
|
||||
class SearchUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
query_params: dict | None = None
|
||||
threshold: float | None = None
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_searches(
|
||||
search_type: str | None = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
q = select(SavedSearch).order_by(SavedSearch.created_at.desc())
|
||||
if search_type:
|
||||
q = q.where(SavedSearch.search_type == search_type)
|
||||
result = await db.execute(q.limit(100))
|
||||
searches = result.scalars().all()
|
||||
return {"searches": [
|
||||
{
|
||||
"id": s.id, "name": s.name, "description": s.description,
|
||||
"search_type": s.search_type, "query_params": s.query_params,
|
||||
"threshold": s.threshold,
|
||||
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
|
||||
"last_result_count": s.last_result_count,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
for s in searches
|
||||
]}
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
async def create_search(body: SearchCreate, db: AsyncSession = Depends(get_db)):
|
||||
s = SavedSearch(
|
||||
name=body.name,
|
||||
description=body.description,
|
||||
search_type=body.search_type,
|
||||
query_params=body.query_params,
|
||||
threshold=body.threshold,
|
||||
)
|
||||
db.add(s)
|
||||
await db.flush()
|
||||
return {
|
||||
"id": s.id, "name": s.name, "search_type": s.search_type,
|
||||
"query_params": s.query_params, "threshold": s.threshold,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{search_id}")
|
||||
async def get_search(search_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
|
||||
s = result.scalar_one_or_none()
|
||||
if not s:
|
||||
raise HTTPException(status_code=404, detail="Saved search not found")
|
||||
return {
|
||||
"id": s.id, "name": s.name, "description": s.description,
|
||||
"search_type": s.search_type, "query_params": s.query_params,
|
||||
"threshold": s.threshold,
|
||||
"last_run_at": s.last_run_at.isoformat() if s.last_run_at else None,
|
||||
"last_result_count": s.last_result_count,
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{search_id}")
|
||||
async def update_search(search_id: str, body: SearchUpdate, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
|
||||
s = result.scalar_one_or_none()
|
||||
if not s:
|
||||
raise HTTPException(status_code=404, detail="Saved search not found")
|
||||
if body.name is not None:
|
||||
s.name = body.name
|
||||
if body.description is not None:
|
||||
s.description = body.description
|
||||
if body.query_params is not None:
|
||||
s.query_params = body.query_params
|
||||
if body.threshold is not None:
|
||||
s.threshold = body.threshold
|
||||
return {"status": "updated"}
|
||||
|
||||
|
||||
@router.delete("/{search_id}")
|
||||
async def delete_search(search_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
|
||||
s = result.scalar_one_or_none()
|
||||
if not s:
|
||||
raise HTTPException(status_code=404, detail="Saved search not found")
|
||||
await db.delete(s)
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@router.post("/{search_id}/run")
|
||||
async def run_saved_search(search_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Execute a saved search and return results with delta from last run."""
|
||||
result = await db.execute(select(SavedSearch).where(SavedSearch.id == search_id))
|
||||
s = result.scalar_one_or_none()
|
||||
if not s:
|
||||
raise HTTPException(status_code=404, detail="Saved search not found")
|
||||
|
||||
previous_count = s.last_result_count or 0
|
||||
results = []
|
||||
count = 0
|
||||
|
||||
if s.search_type == "ioc_search":
|
||||
from app.db.models import EnrichmentResult
|
||||
ioc_value = s.query_params.get("ioc_value", "")
|
||||
if ioc_value:
|
||||
q = select(EnrichmentResult).where(
|
||||
EnrichmentResult.ioc_value.contains(ioc_value)
|
||||
)
|
||||
res = await db.execute(q.limit(100))
|
||||
for er in res.scalars().all():
|
||||
results.append({
|
||||
"ioc_value": er.ioc_value, "ioc_type": er.ioc_type,
|
||||
"source": er.source, "verdict": er.verdict,
|
||||
})
|
||||
count = len(results)
|
||||
|
||||
elif s.search_type == "keyword_scan":
|
||||
from app.db.models import KeywordTheme
|
||||
res = await db.execute(select(KeywordTheme).where(KeywordTheme.enabled == True))
|
||||
themes = res.scalars().all()
|
||||
count = sum(len(t.keywords) for t in themes)
|
||||
results = [{"theme": t.name, "keyword_count": len(t.keywords)} for t in themes]
|
||||
|
||||
# Update last run metadata
|
||||
s.last_run_at = datetime.now(timezone.utc)
|
||||
s.last_result_count = count
|
||||
|
||||
delta = count - previous_count
|
||||
|
||||
return {
|
||||
"search_id": s.id, "search_name": s.name,
|
||||
"search_type": s.search_type,
|
||||
"result_count": count,
|
||||
"previous_count": previous_count,
|
||||
"delta": delta,
|
||||
"results": results[:50],
|
||||
}
|
||||
184
backend/app/api/routes/stix_export.py
Normal file
184
backend/app/api/routes/stix_export.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""STIX 2.1 export endpoint.
|
||||
|
||||
Aggregates hunt data (IOCs, techniques, host profiles, hypotheses) into a
|
||||
STIX 2.1 Bundle JSON download. No external dependencies required we
|
||||
build the JSON directly following the OASIS STIX 2.1 spec.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import Response
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import (
|
||||
Hunt, Dataset, Hypothesis, TriageResult, HostProfile,
|
||||
EnrichmentResult, HuntReport,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/export", tags=["export"])
|
||||
|
||||
STIX_SPEC_VERSION = "2.1"
|
||||
|
||||
|
||||
def _stix_id(stype: str) -> str:
|
||||
return f"{stype}--{uuid.uuid4()}"
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
|
||||
|
||||
def _build_identity(hunt_name: str) -> dict:
|
||||
return {
|
||||
"type": "identity",
|
||||
"spec_version": STIX_SPEC_VERSION,
|
||||
"id": _stix_id("identity"),
|
||||
"created": _now_iso(),
|
||||
"modified": _now_iso(),
|
||||
"name": f"ThreatHunt - {hunt_name}",
|
||||
"identity_class": "system",
|
||||
}
|
||||
|
||||
|
||||
def _ioc_to_indicator(ioc_value: str, ioc_type: str, identity_id: str, verdict: str = None) -> dict:
|
||||
pattern_map = {
|
||||
"ipv4": f"[ipv4-addr:value = '{ioc_value}']",
|
||||
"ipv6": f"[ipv6-addr:value = '{ioc_value}']",
|
||||
"domain": f"[domain-name:value = '{ioc_value}']",
|
||||
"url": f"[url:value = '{ioc_value}']",
|
||||
"hash_md5": f"[file:hashes.'MD5' = '{ioc_value}']",
|
||||
"hash_sha1": f"[file:hashes.'SHA-1' = '{ioc_value}']",
|
||||
"hash_sha256": f"[file:hashes.'SHA-256' = '{ioc_value}']",
|
||||
"email": f"[email-addr:value = '{ioc_value}']",
|
||||
}
|
||||
pattern = pattern_map.get(ioc_type, f"[artifact:payload_bin = '{ioc_value}']")
|
||||
now = _now_iso()
|
||||
return {
|
||||
"type": "indicator",
|
||||
"spec_version": STIX_SPEC_VERSION,
|
||||
"id": _stix_id("indicator"),
|
||||
"created": now,
|
||||
"modified": now,
|
||||
"name": f"{ioc_type}: {ioc_value}",
|
||||
"pattern": pattern,
|
||||
"pattern_type": "stix",
|
||||
"valid_from": now,
|
||||
"created_by_ref": identity_id,
|
||||
"labels": [verdict or "suspicious"],
|
||||
}
|
||||
|
||||
|
||||
def _technique_to_attack_pattern(technique_id: str, identity_id: str) -> dict:
|
||||
now = _now_iso()
|
||||
return {
|
||||
"type": "attack-pattern",
|
||||
"spec_version": STIX_SPEC_VERSION,
|
||||
"id": _stix_id("attack-pattern"),
|
||||
"created": now,
|
||||
"modified": now,
|
||||
"name": technique_id,
|
||||
"created_by_ref": identity_id,
|
||||
"external_references": [
|
||||
{
|
||||
"source_name": "mitre-attack",
|
||||
"external_id": technique_id,
|
||||
"url": f"https://attack.mitre.org/techniques/{technique_id.replace('.', '/')}/",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _hypothesis_to_report(hyp, identity_id: str) -> dict:
|
||||
now = _now_iso()
|
||||
return {
|
||||
"type": "report",
|
||||
"spec_version": STIX_SPEC_VERSION,
|
||||
"id": _stix_id("report"),
|
||||
"created": now,
|
||||
"modified": now,
|
||||
"name": hyp.title,
|
||||
"description": hyp.description or "",
|
||||
"published": now,
|
||||
"created_by_ref": identity_id,
|
||||
"labels": ["threat-hunt-hypothesis"],
|
||||
"object_refs": [],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/stix/{hunt_id}")
|
||||
async def export_stix(hunt_id: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Export hunt data as a STIX 2.1 Bundle JSON file."""
|
||||
# Fetch hunt
|
||||
hunt = (await db.execute(select(Hunt).where(Hunt.id == hunt_id))).scalar_one_or_none()
|
||||
if not hunt:
|
||||
raise HTTPException(404, "Hunt not found")
|
||||
|
||||
identity = _build_identity(hunt.name)
|
||||
objects: list[dict] = [identity]
|
||||
seen_techniques: set[str] = set()
|
||||
seen_iocs: set[str] = set()
|
||||
|
||||
# Gather IOCs from enrichment results for hunt's datasets
|
||||
datasets_q = await db.execute(select(Dataset.id).where(Dataset.hunt_id == hunt_id))
|
||||
ds_ids = [r[0] for r in datasets_q.all()]
|
||||
|
||||
if ds_ids:
|
||||
enrichments = (await db.execute(
|
||||
select(EnrichmentResult).where(EnrichmentResult.dataset_id.in_(ds_ids))
|
||||
)).scalars().all()
|
||||
for e in enrichments:
|
||||
key = f"{e.ioc_type}:{e.ioc_value}"
|
||||
if key not in seen_iocs:
|
||||
seen_iocs.add(key)
|
||||
objects.append(_ioc_to_indicator(e.ioc_value, e.ioc_type, identity["id"], e.verdict))
|
||||
|
||||
# Gather techniques from triage results
|
||||
triages = (await db.execute(
|
||||
select(TriageResult).where(TriageResult.dataset_id.in_(ds_ids))
|
||||
)).scalars().all()
|
||||
for t in triages:
|
||||
for tech in (t.mitre_techniques or []):
|
||||
tid = tech if isinstance(tech, str) else tech.get("technique_id", str(tech))
|
||||
if tid not in seen_techniques:
|
||||
seen_techniques.add(tid)
|
||||
objects.append(_technique_to_attack_pattern(tid, identity["id"]))
|
||||
|
||||
# Gather techniques from host profiles
|
||||
profiles = (await db.execute(
|
||||
select(HostProfile).where(HostProfile.hunt_id == hunt_id)
|
||||
)).scalars().all()
|
||||
for p in profiles:
|
||||
for tech in (p.mitre_techniques or []):
|
||||
tid = tech if isinstance(tech, str) else tech.get("technique_id", str(tech))
|
||||
if tid not in seen_techniques:
|
||||
seen_techniques.add(tid)
|
||||
objects.append(_technique_to_attack_pattern(tid, identity["id"]))
|
||||
|
||||
# Gather hypotheses
|
||||
hypos = (await db.execute(
|
||||
select(Hypothesis).where(Hypothesis.hunt_id == hunt_id)
|
||||
)).scalars().all()
|
||||
for h in hypos:
|
||||
objects.append(_hypothesis_to_report(h, identity["id"]))
|
||||
if h.mitre_technique and h.mitre_technique not in seen_techniques:
|
||||
seen_techniques.add(h.mitre_technique)
|
||||
objects.append(_technique_to_attack_pattern(h.mitre_technique, identity["id"]))
|
||||
|
||||
bundle = {
|
||||
"type": "bundle",
|
||||
"id": _stix_id("bundle"),
|
||||
"objects": objects,
|
||||
}
|
||||
|
||||
filename = f"threathunt-{hunt.name.replace(' ', '_')}-stix.json"
|
||||
return Response(
|
||||
content=json.dumps(bundle, indent=2),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
128
backend/app/api/routes/timeline.py
Normal file
128
backend/app/api/routes/timeline.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""API routes for forensic timeline visualization."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.db.models import Dataset, DatasetRow, Hunt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/timeline", tags=["timeline"])
|
||||
|
||||
|
||||
def _parse_timestamp(val: str | None) -> str | None:
|
||||
"""Try to parse a timestamp string, return ISO format or None."""
|
||||
if not val:
|
||||
return None
|
||||
val = str(val).strip()
|
||||
if not val:
|
||||
return None
|
||||
# Try common formats
|
||||
for fmt in [
|
||||
"%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ",
|
||||
"%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M:%S", "%m/%d/%Y %H:%M:%S",
|
||||
]:
|
||||
try:
|
||||
return datetime.strptime(val, fmt).isoformat() + "Z"
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
# Columns likely to contain timestamps
|
||||
TIME_COLUMNS = {
|
||||
"timestamp", "time", "datetime", "date", "created", "modified",
|
||||
"eventtime", "event_time", "start_time", "end_time",
|
||||
"lastmodified", "last_modified", "created_at", "updated_at",
|
||||
"mtime", "atime", "ctime", "btime",
|
||||
"timecreated", "timegenerated", "sourcetime",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/hunt/{hunt_id}")
|
||||
async def get_hunt_timeline(
|
||||
hunt_id: str,
|
||||
limit: int = 2000,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Build a timeline of events across all datasets in a hunt."""
|
||||
result = await db.execute(select(Hunt).where(Hunt.id == hunt_id))
|
||||
hunt = result.scalar_one_or_none()
|
||||
if not hunt:
|
||||
raise HTTPException(status_code=404, detail="Hunt not found")
|
||||
|
||||
result = await db.execute(select(Dataset).where(Dataset.hunt_id == hunt_id))
|
||||
datasets = result.scalars().all()
|
||||
if not datasets:
|
||||
return {"hunt_id": hunt_id, "events": [], "datasets": []}
|
||||
|
||||
events = []
|
||||
dataset_info = []
|
||||
|
||||
for ds in datasets:
|
||||
artifact_type = getattr(ds, "artifact_type", None) or "Unknown"
|
||||
dataset_info.append({
|
||||
"id": ds.id, "name": ds.name, "artifact_type": artifact_type,
|
||||
"row_count": ds.row_count,
|
||||
})
|
||||
|
||||
# Find time columns for this dataset
|
||||
schema = ds.column_schema or {}
|
||||
time_cols = []
|
||||
for col in (ds.normalized_columns or {}).values():
|
||||
if col.lower() in TIME_COLUMNS:
|
||||
time_cols.append(col)
|
||||
if not time_cols:
|
||||
for col in schema:
|
||||
if col.lower() in TIME_COLUMNS or "time" in col.lower() or "date" in col.lower():
|
||||
time_cols.append(col)
|
||||
if not time_cols:
|
||||
continue
|
||||
|
||||
# Fetch rows
|
||||
rows_result = await db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == ds.id)
|
||||
.order_by(DatasetRow.row_index)
|
||||
.limit(limit // max(len(datasets), 1))
|
||||
)
|
||||
for r in rows_result.scalars().all():
|
||||
data = r.normalized_data or r.data
|
||||
ts = None
|
||||
for tc in time_cols:
|
||||
ts = _parse_timestamp(data.get(tc))
|
||||
if ts:
|
||||
break
|
||||
if ts:
|
||||
hostname = data.get("hostname") or data.get("Hostname") or data.get("Fqdn") or ""
|
||||
process = data.get("process_name") or data.get("Name") or data.get("ProcessName") or ""
|
||||
summary = data.get("command_line") or data.get("CommandLine") or data.get("Details") or ""
|
||||
events.append({
|
||||
"timestamp": ts,
|
||||
"dataset_id": ds.id,
|
||||
"dataset_name": ds.name,
|
||||
"artifact_type": artifact_type,
|
||||
"row_index": r.row_index,
|
||||
"hostname": str(hostname)[:128],
|
||||
"process": str(process)[:128],
|
||||
"summary": str(summary)[:256],
|
||||
"data": {k: str(v)[:100] for k, v in list(data.items())[:8]},
|
||||
})
|
||||
|
||||
# Sort by timestamp
|
||||
events.sort(key=lambda e: e["timestamp"])
|
||||
|
||||
return {
|
||||
"hunt_id": hunt_id,
|
||||
"hunt_name": hunt.name,
|
||||
"event_count": len(events),
|
||||
"datasets": dataset_info,
|
||||
"events": events[:limit],
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Application configuration — single source of truth for all settings.
|
||||
"""Application configuration - single source of truth for all settings.
|
||||
|
||||
Loads from environment variables with sensible defaults for local dev.
|
||||
"""
|
||||
@@ -13,12 +13,12 @@ from pydantic import Field
|
||||
class AppConfig(BaseSettings):
|
||||
"""Central configuration for the entire ThreatHunt application."""
|
||||
|
||||
# ── General ────────────────────────────────────────────────────────
|
||||
# -- General --------------------------------------------------------
|
||||
APP_NAME: str = "ThreatHunt"
|
||||
APP_VERSION: str = "0.3.0"
|
||||
DEBUG: bool = Field(default=False, description="Enable debug mode")
|
||||
|
||||
# ── Database ───────────────────────────────────────────────────────
|
||||
# -- Database -------------------------------------------------------
|
||||
DATABASE_URL: str = Field(
|
||||
default="sqlite+aiosqlite:///./threathunt.db",
|
||||
description="Async SQLAlchemy database URL. "
|
||||
@@ -26,17 +26,17 @@ class AppConfig(BaseSettings):
|
||||
"postgresql+asyncpg://user:pass@host/db for production.",
|
||||
)
|
||||
|
||||
# ── CORS ───────────────────────────────────────────────────────────
|
||||
# -- CORS -----------------------------------------------------------
|
||||
ALLOWED_ORIGINS: str = Field(
|
||||
default="http://localhost:3000,http://localhost:8000",
|
||||
description="Comma-separated list of allowed CORS origins",
|
||||
)
|
||||
|
||||
# ── File uploads ───────────────────────────────────────────────────
|
||||
# -- File uploads ---------------------------------------------------
|
||||
MAX_UPLOAD_SIZE_MB: int = Field(default=500, description="Max CSV upload in MB")
|
||||
UPLOAD_DIR: str = Field(default="./uploads", description="Directory for uploaded files")
|
||||
|
||||
# ── LLM Cluster — Wile & Roadrunner ────────────────────────────────
|
||||
# -- LLM Cluster - Wile & Roadrunner --------------------------------
|
||||
OPENWEBUI_URL: str = Field(
|
||||
default="https://ai.guapo613.beer",
|
||||
description="Open WebUI cluster endpoint (OpenAI-compatible API)",
|
||||
@@ -58,7 +58,7 @@ class AppConfig(BaseSettings):
|
||||
default=11434, description="Ollama port on Roadrunner"
|
||||
)
|
||||
|
||||
# ── LLM Routing defaults ──────────────────────────────────────────
|
||||
# -- LLM Routing defaults ------------------------------------------
|
||||
DEFAULT_FAST_MODEL: str = Field(
|
||||
default="llama3.1:latest",
|
||||
description="Default model for quick chat / simple queries",
|
||||
@@ -80,18 +80,18 @@ class AppConfig(BaseSettings):
|
||||
description="Default embedding model",
|
||||
)
|
||||
|
||||
# ── Agent behaviour ───────────────────────────────────────────────
|
||||
# -- Agent behaviour ------------------------------------------------
|
||||
AGENT_MAX_TOKENS: int = Field(default=2048, description="Max tokens per agent response")
|
||||
AGENT_TEMPERATURE: float = Field(default=0.3, description="LLM temperature for guidance")
|
||||
AGENT_HISTORY_LENGTH: int = Field(default=10, description="Messages to keep in context")
|
||||
FILTER_SENSITIVE_DATA: bool = Field(default=True, description="Redact sensitive patterns")
|
||||
|
||||
# ── Enrichment API keys ───────────────────────────────────────────
|
||||
# -- Enrichment API keys --------------------------------------------
|
||||
VIRUSTOTAL_API_KEY: str = Field(default="", description="VirusTotal API key")
|
||||
ABUSEIPDB_API_KEY: str = Field(default="", description="AbuseIPDB API key")
|
||||
SHODAN_API_KEY: str = Field(default="", description="Shodan API key")
|
||||
|
||||
# ── Auth ──────────────────────────────────────────────────────────
|
||||
# -- Auth -----------------------------------------------------------
|
||||
JWT_SECRET: str = Field(
|
||||
default="CHANGE-ME-IN-PRODUCTION-USE-A-REAL-SECRET",
|
||||
description="Secret for JWT signing",
|
||||
@@ -99,6 +99,73 @@ class AppConfig(BaseSettings):
|
||||
JWT_ACCESS_TOKEN_MINUTES: int = Field(default=60, description="Access token lifetime")
|
||||
JWT_REFRESH_TOKEN_DAYS: int = Field(default=7, description="Refresh token lifetime")
|
||||
|
||||
# -- Triage settings ------------------------------------------------
|
||||
TRIAGE_BATCH_SIZE: int = Field(default=25, description="Rows per triage LLM batch")
|
||||
TRIAGE_MAX_SUSPICIOUS_ROWS: int = Field(
|
||||
default=200, description="Stop triage after this many suspicious rows"
|
||||
)
|
||||
TRIAGE_ESCALATION_THRESHOLD: float = Field(
|
||||
default=5.0, description="Risk score threshold for escalation counting"
|
||||
)
|
||||
|
||||
# -- Host profiler settings -----------------------------------------
|
||||
HOST_PROFILE_CONCURRENCY: int = Field(
|
||||
default=3, description="Max concurrent host profile LLM calls"
|
||||
)
|
||||
|
||||
# -- Scanner settings -----------------------------------------------
|
||||
SCANNER_BATCH_SIZE: int = Field(default=500, description="Rows per scanner batch")
|
||||
SCANNER_MAX_ROWS_PER_SCAN: int = Field(
|
||||
default=120000,
|
||||
description="Global row budget for a single AUP scan request (0 = unlimited)",
|
||||
)
|
||||
|
||||
# -- Job queue settings ----------------------------------------------
|
||||
JOB_QUEUE_MAX_BACKLOG: int = Field(
|
||||
default=2000, description="Soft cap for queued background jobs"
|
||||
)
|
||||
JOB_QUEUE_RETAIN_COMPLETED: int = Field(
|
||||
default=3000, description="Maximum completed/failed jobs to retain in memory"
|
||||
)
|
||||
JOB_QUEUE_CLEANUP_INTERVAL_SECONDS: int = Field(
|
||||
default=60, description="How often to run in-memory job cleanup"
|
||||
)
|
||||
JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS: int = Field(
|
||||
default=3600, description="Age threshold for in-memory completed job cleanup"
|
||||
)
|
||||
|
||||
# -- Startup throttling ------------------------------------------------
|
||||
STARTUP_WARMUP_MAX_HUNTS: int = Field(
|
||||
default=5, description="Max hunts to warm inventory cache for at startup"
|
||||
)
|
||||
STARTUP_REPROCESS_MAX_DATASETS: int = Field(
|
||||
default=25, description="Max unprocessed datasets to enqueue at startup"
|
||||
)
|
||||
STARTUP_RECONCILE_STALE_TASKS: bool = Field(
|
||||
default=True,
|
||||
description="Mark stale queued/running processing tasks as failed on startup",
|
||||
)
|
||||
|
||||
# -- Network API scale guards -----------------------------------------
|
||||
NETWORK_SUBGRAPH_MAX_HOSTS: int = Field(
|
||||
default=400, description="Hard cap for hosts returned by network subgraph endpoint"
|
||||
)
|
||||
NETWORK_SUBGRAPH_MAX_EDGES: int = Field(
|
||||
default=3000, description="Hard cap for edges returned by network subgraph endpoint"
|
||||
)
|
||||
NETWORK_INVENTORY_MAX_ROWS_PER_DATASET: int = Field(
|
||||
default=5000,
|
||||
description="Row budget per dataset when building host inventory (0 = unlimited)",
|
||||
)
|
||||
NETWORK_INVENTORY_MAX_TOTAL_ROWS: int = Field(
|
||||
default=120000,
|
||||
description="Global row budget across all datasets for host inventory build (0 = unlimited)",
|
||||
)
|
||||
NETWORK_INVENTORY_MAX_CONNECTIONS: int = Field(
|
||||
default=120000,
|
||||
description="Max unique connection tuples retained during host inventory build",
|
||||
)
|
||||
|
||||
model_config = {"env_prefix": "TH_", "env_file": ".env", "extra": "ignore"}
|
||||
|
||||
@property
|
||||
@@ -119,3 +186,4 @@ class AppConfig(BaseSettings):
|
||||
|
||||
|
||||
settings = AppConfig()
|
||||
|
||||
|
||||
@@ -21,9 +21,14 @@ _engine_kwargs: dict = dict(
|
||||
)
|
||||
|
||||
if _is_sqlite:
|
||||
_engine_kwargs["connect_args"] = {"timeout": 30}
|
||||
_engine_kwargs["pool_size"] = 1
|
||||
_engine_kwargs["max_overflow"] = 0
|
||||
_engine_kwargs["connect_args"] = {"timeout": 60, "check_same_thread": False}
|
||||
# NullPool: each session gets its own connection.
|
||||
# Combined with WAL mode, this allows concurrent reads while a write is in progress.
|
||||
from sqlalchemy.pool import NullPool
|
||||
_engine_kwargs["poolclass"] = NullPool
|
||||
else:
|
||||
_engine_kwargs["pool_size"] = 5
|
||||
_engine_kwargs["max_overflow"] = 10
|
||||
|
||||
engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)
|
||||
|
||||
@@ -34,7 +39,7 @@ def _set_sqlite_pragmas(dbapi_conn, connection_record):
|
||||
if _is_sqlite:
|
||||
cursor = dbapi_conn.cursor()
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute("PRAGMA busy_timeout=5000")
|
||||
cursor.execute("PRAGMA busy_timeout=30000")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.close()
|
||||
|
||||
@@ -46,6 +51,10 @@ async_session_factory = async_sessionmaker(
|
||||
)
|
||||
|
||||
|
||||
# Alias expected by other modules
|
||||
async_session = async_session_factory
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all ORM models."""
|
||||
pass
|
||||
@@ -71,5 +80,5 @@ async def init_db() -> None:
|
||||
|
||||
|
||||
async def dispose_db() -> None:
|
||||
"""Dispose of the engine connection pool."""
|
||||
await engine.dispose()
|
||||
"""Dispose of the engine on shutdown."""
|
||||
await engine.dispose()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""SQLAlchemy ORM models for ThreatHunt.
|
||||
"""SQLAlchemy ORM models for ThreatHunt.
|
||||
|
||||
All persistent entities: datasets, hunts, conversations, annotations,
|
||||
hypotheses, enrichment results, users, and AI analysis tables.
|
||||
@@ -43,6 +43,7 @@ class User(Base):
|
||||
hashed_password: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
role: Mapped[str] = mapped_column(String(16), default="analyst")
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
display_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
hunts: Mapped[list["Hunt"]] = relationship(back_populates="owner", lazy="selectin")
|
||||
@@ -399,4 +400,108 @@ class AnomalyResult(Base):
|
||||
cluster_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
is_outlier: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
explanation: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
|
||||
# -- Persistent Processing Tasks (Phase 2) ---
|
||||
|
||||
class ProcessingTask(Base):
|
||||
__tablename__ = "processing_tasks"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
hunt_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("hunts.id", ondelete="CASCADE"), nullable=True, index=True
|
||||
)
|
||||
dataset_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("datasets.id", ondelete="CASCADE"), nullable=True, index=True
|
||||
)
|
||||
job_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
|
||||
stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="queued", index=True)
|
||||
progress: Mapped[float] = mapped_column(Float, default=0.0)
|
||||
message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_processing_tasks_hunt_stage", "hunt_id", "stage"),
|
||||
Index("ix_processing_tasks_dataset_stage", "dataset_id", "stage"),
|
||||
)
|
||||
|
||||
|
||||
# -- Playbook / Investigation Templates (Feature 3) ---
|
||||
|
||||
class Playbook(Base):
|
||||
__tablename__ = "playbooks"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_by: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("users.id"), nullable=True
|
||||
)
|
||||
is_template: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
hunt_id: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("hunts.id"), nullable=True
|
||||
)
|
||||
status: Mapped[str] = mapped_column(String(20), default="active")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=_utcnow, onupdate=_utcnow
|
||||
)
|
||||
|
||||
steps: Mapped[list["PlaybookStep"]] = relationship(
|
||||
back_populates="playbook", lazy="selectin", cascade="all, delete-orphan",
|
||||
order_by="PlaybookStep.order_index",
|
||||
)
|
||||
|
||||
|
||||
class PlaybookStep(Base):
|
||||
__tablename__ = "playbook_steps"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
playbook_id: Mapped[str] = mapped_column(
|
||||
String(32), ForeignKey("playbooks.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
order_index: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
title: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
step_type: Mapped[str] = mapped_column(String(32), default="manual")
|
||||
target_route: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
|
||||
is_completed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
playbook: Mapped["Playbook"] = relationship(back_populates="steps")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_playbook_steps_playbook", "playbook_id"),
|
||||
)
|
||||
|
||||
|
||||
# -- Saved Searches (Feature 5) ---
|
||||
|
||||
class SavedSearch(Base):
|
||||
__tablename__ = "saved_searches"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(32), primary_key=True, default=_new_id)
|
||||
name: Mapped[str] = mapped_column(String(256), nullable=False, index=True)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
search_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
query_params: Mapped[dict] = mapped_column(JSON, nullable=False)
|
||||
threshold: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
created_by: Mapped[Optional[str]] = mapped_column(
|
||||
String(32), ForeignKey("users.id"), nullable=True
|
||||
)
|
||||
last_run_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
last_result_count: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=_utcnow)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_saved_searches_type", "search_type"),
|
||||
)
|
||||
|
||||
@@ -25,6 +25,11 @@ from app.api.routes.auth import router as auth_router
|
||||
from app.api.routes.keywords import router as keywords_router
|
||||
from app.api.routes.analysis import router as analysis_router
|
||||
from app.api.routes.network import router as network_router
|
||||
from app.api.routes.mitre import router as mitre_router
|
||||
from app.api.routes.timeline import router as timeline_router
|
||||
from app.api.routes.playbooks import router as playbooks_router
|
||||
from app.api.routes.saved_searches import router as searches_router
|
||||
from app.api.routes.stix_export import router as stix_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -47,13 +52,80 @@ async def lifespan(app: FastAPI):
|
||||
await seed_defaults(seed_db)
|
||||
logger.info("AUP keyword defaults checked")
|
||||
|
||||
# Start job queue (Phase 10)
|
||||
from app.services.job_queue import job_queue, register_all_handlers
|
||||
# Start job queue
|
||||
from app.services.job_queue import (
|
||||
job_queue,
|
||||
register_all_handlers,
|
||||
reconcile_stale_processing_tasks,
|
||||
JobType,
|
||||
)
|
||||
|
||||
if settings.STARTUP_RECONCILE_STALE_TASKS:
|
||||
reconciled = await reconcile_stale_processing_tasks()
|
||||
if reconciled:
|
||||
logger.info("Startup reconciliation marked %d stale tasks", reconciled)
|
||||
|
||||
register_all_handlers()
|
||||
await job_queue.start()
|
||||
logger.info("Job queue started (%d workers)", job_queue._max_workers)
|
||||
|
||||
# Start load balancer health loop (Phase 10)
|
||||
# Pre-warm host inventory cache for existing hunts
|
||||
from app.services.host_inventory import inventory_cache
|
||||
async with async_session_factory() as warm_db:
|
||||
from sqlalchemy import select, func
|
||||
from app.db.models import Hunt, Dataset
|
||||
stmt = (
|
||||
select(Hunt.id)
|
||||
.join(Dataset, Dataset.hunt_id == Hunt.id)
|
||||
.group_by(Hunt.id)
|
||||
.having(func.count(Dataset.id) > 0)
|
||||
)
|
||||
result = await warm_db.execute(stmt)
|
||||
hunt_ids = [row[0] for row in result.all()]
|
||||
warm_hunts = hunt_ids[: settings.STARTUP_WARMUP_MAX_HUNTS]
|
||||
for hid in warm_hunts:
|
||||
job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hid)
|
||||
if warm_hunts:
|
||||
logger.info(f"Queued host inventory warm-up for {len(warm_hunts)} hunts (total hunts with data: {len(hunt_ids)})")
|
||||
|
||||
# Check which datasets still need processing
|
||||
# (no anomaly results = never fully processed)
|
||||
async with async_session_factory() as reprocess_db:
|
||||
from sqlalchemy import select, exists
|
||||
from app.db.models import Dataset, AnomalyResult
|
||||
# Find datasets that have zero anomaly results (pipeline never ran or failed)
|
||||
has_anomaly = (
|
||||
select(AnomalyResult.id)
|
||||
.where(AnomalyResult.dataset_id == Dataset.id)
|
||||
.limit(1)
|
||||
.correlate(Dataset)
|
||||
.exists()
|
||||
)
|
||||
stmt = select(Dataset.id).where(~has_anomaly)
|
||||
result = await reprocess_db.execute(stmt)
|
||||
unprocessed_ids = [row[0] for row in result.all()]
|
||||
|
||||
if unprocessed_ids:
|
||||
to_reprocess = unprocessed_ids[: settings.STARTUP_REPROCESS_MAX_DATASETS]
|
||||
for ds_id in to_reprocess:
|
||||
job_queue.submit(JobType.TRIAGE, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.ANOMALY, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=ds_id)
|
||||
job_queue.submit(JobType.IOC_EXTRACT, dataset_id=ds_id)
|
||||
logger.info(f"Queued processing pipeline for {len(to_reprocess)} datasets at startup (unprocessed total: {len(unprocessed_ids)})")
|
||||
async with async_session_factory() as update_db:
|
||||
from sqlalchemy import update
|
||||
from app.db.models import Dataset
|
||||
await update_db.execute(
|
||||
update(Dataset)
|
||||
.where(Dataset.id.in_(to_reprocess))
|
||||
.values(processing_status="processing")
|
||||
)
|
||||
await update_db.commit()
|
||||
else:
|
||||
logger.info("All datasets already processed - skipping startup pipeline")
|
||||
|
||||
# Start load balancer health loop
|
||||
from app.services.load_balancer import lb
|
||||
await lb.start_health_loop(interval=30.0)
|
||||
logger.info("Load balancer health loop started")
|
||||
@@ -61,12 +133,10 @@ async def lifespan(app: FastAPI):
|
||||
yield
|
||||
|
||||
logger.info("Shutting down ...")
|
||||
# Stop job queue
|
||||
from app.services.job_queue import job_queue as jq
|
||||
await jq.stop()
|
||||
logger.info("Job queue stopped")
|
||||
|
||||
# Stop load balancer
|
||||
from app.services.load_balancer import lb as _lb
|
||||
await _lb.stop_health_loop()
|
||||
logger.info("Load balancer stopped")
|
||||
@@ -106,6 +176,11 @@ app.include_router(reports_router)
|
||||
app.include_router(keywords_router)
|
||||
app.include_router(analysis_router)
|
||||
app.include_router(network_router)
|
||||
app.include_router(mitre_router)
|
||||
app.include_router(timeline_router)
|
||||
app.include_router(playbooks_router)
|
||||
app.include_router(searches_router)
|
||||
app.include_router(stix_router)
|
||||
|
||||
|
||||
@app.get("/", tags=["health"])
|
||||
@@ -120,4 +195,4 @@ async def root():
|
||||
"roadrunner": settings.roadrunner_url,
|
||||
"openwebui": settings.OPENWEBUI_URL,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import Dataset, DatasetRow
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -79,6 +80,55 @@ def _extract_username(raw: str) -> str:
|
||||
return name or ''
|
||||
|
||||
|
||||
|
||||
|
||||
# In-memory host inventory cache
|
||||
# Pre-computed results stored per hunt_id, built in background after upload.
|
||||
|
||||
import time as _time
|
||||
|
||||
class _InventoryCache:
|
||||
"""Simple in-memory cache for pre-computed host inventories."""
|
||||
|
||||
def __init__(self):
|
||||
self._data: dict[str, dict] = {} # hunt_id -> result dict
|
||||
self._timestamps: dict[str, float] = {} # hunt_id -> epoch
|
||||
self._building: set[str] = set() # hunt_ids currently being built
|
||||
|
||||
def get(self, hunt_id: str) -> dict | None:
|
||||
"""Return cached result if present. Never expires; only invalidated on new upload."""
|
||||
return self._data.get(hunt_id)
|
||||
|
||||
def put(self, hunt_id: str, result: dict):
|
||||
self._data[hunt_id] = result
|
||||
self._timestamps[hunt_id] = _time.time()
|
||||
self._building.discard(hunt_id)
|
||||
logger.info(f"Cached host inventory for hunt {hunt_id} "
|
||||
f"({result['stats']['total_hosts']} hosts)")
|
||||
|
||||
def invalidate(self, hunt_id: str):
|
||||
self._data.pop(hunt_id, None)
|
||||
self._timestamps.pop(hunt_id, None)
|
||||
|
||||
def is_building(self, hunt_id: str) -> bool:
|
||||
return hunt_id in self._building
|
||||
|
||||
def set_building(self, hunt_id: str):
|
||||
self._building.add(hunt_id)
|
||||
|
||||
def clear_building(self, hunt_id: str):
|
||||
self._building.discard(hunt_id)
|
||||
|
||||
def status(self, hunt_id: str) -> str:
|
||||
if hunt_id in self._building:
|
||||
return "building"
|
||||
if hunt_id in self._data:
|
||||
return "ready"
|
||||
return "none"
|
||||
|
||||
|
||||
inventory_cache = _InventoryCache()
|
||||
|
||||
def _infer_os(fqdn: str) -> str:
|
||||
u = fqdn.upper()
|
||||
if 'W10-' in u or 'WIN10' in u:
|
||||
@@ -151,33 +201,61 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
||||
}}
|
||||
|
||||
hosts: dict[str, dict] = {} # fqdn -> host record
|
||||
ip_to_host: dict[str, str] = {} # local-ip -> fqdn
|
||||
ip_to_host: dict[str, str] = {} # local-ip -> fqdn
|
||||
connections: dict[tuple, int] = defaultdict(int)
|
||||
total_rows = 0
|
||||
ds_with_hosts = 0
|
||||
sampled_dataset_count = 0
|
||||
total_row_budget = max(0, int(settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS))
|
||||
max_connections = max(0, int(settings.NETWORK_INVENTORY_MAX_CONNECTIONS))
|
||||
global_budget_reached = False
|
||||
dropped_connections = 0
|
||||
|
||||
for ds in all_datasets:
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
global_budget_reached = True
|
||||
break
|
||||
|
||||
cols = _identify_columns(ds)
|
||||
if not cols['fqdn'] and not cols['host_id']:
|
||||
continue
|
||||
ds_with_hosts += 1
|
||||
|
||||
batch_size = 5000
|
||||
offset = 0
|
||||
max_rows_per_dataset = max(0, int(settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET))
|
||||
rows_scanned_this_dataset = 0
|
||||
sampled_dataset = False
|
||||
last_row_index = -1
|
||||
|
||||
while True:
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
sampled_dataset = True
|
||||
global_budget_reached = True
|
||||
break
|
||||
|
||||
rr = await db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == ds.id)
|
||||
.where(DatasetRow.row_index > last_row_index)
|
||||
.order_by(DatasetRow.row_index)
|
||||
.offset(offset).limit(batch_size)
|
||||
.limit(batch_size)
|
||||
)
|
||||
rows = rr.scalars().all()
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for ro in rows:
|
||||
if max_rows_per_dataset and rows_scanned_this_dataset >= max_rows_per_dataset:
|
||||
sampled_dataset = True
|
||||
break
|
||||
if total_row_budget and total_rows >= total_row_budget:
|
||||
sampled_dataset = True
|
||||
global_budget_reached = True
|
||||
break
|
||||
|
||||
data = ro.data or {}
|
||||
total_rows += 1
|
||||
rows_scanned_this_dataset += 1
|
||||
|
||||
fqdn = ''
|
||||
for c in cols['fqdn']:
|
||||
@@ -239,12 +317,33 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
||||
rport = _clean(data.get(pc))
|
||||
if rport:
|
||||
break
|
||||
connections[(host_key, rip, rport)] += 1
|
||||
conn_key = (host_key, rip, rport)
|
||||
if max_connections and len(connections) >= max_connections and conn_key not in connections:
|
||||
dropped_connections += 1
|
||||
continue
|
||||
connections[conn_key] += 1
|
||||
|
||||
offset += batch_size
|
||||
if sampled_dataset:
|
||||
sampled_dataset_count += 1
|
||||
logger.info(
|
||||
"Host inventory sampling for dataset %s (%d rows scanned)",
|
||||
ds.id,
|
||||
rows_scanned_this_dataset,
|
||||
)
|
||||
break
|
||||
|
||||
last_row_index = rows[-1].row_index
|
||||
if len(rows) < batch_size:
|
||||
break
|
||||
|
||||
if global_budget_reached:
|
||||
logger.info(
|
||||
"Host inventory global row budget reached for hunt %s at %d rows",
|
||||
hunt_id,
|
||||
total_rows,
|
||||
)
|
||||
break
|
||||
|
||||
# Post-process hosts
|
||||
for h in hosts.values():
|
||||
if not h['os'] and h['fqdn']:
|
||||
@@ -286,5 +385,12 @@ async def build_host_inventory(hunt_id: str, db: AsyncSession) -> dict:
|
||||
"total_rows_scanned": total_rows,
|
||||
"hosts_with_ips": sum(1 for h in host_list if h['ips']),
|
||||
"hosts_with_users": sum(1 for h in host_list if h['users']),
|
||||
"row_budget_per_dataset": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET,
|
||||
"row_budget_total": settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS,
|
||||
"connection_budget": settings.NETWORK_INVENTORY_MAX_CONNECTIONS,
|
||||
"sampled_mode": settings.NETWORK_INVENTORY_MAX_ROWS_PER_DATASET > 0 or settings.NETWORK_INVENTORY_MAX_TOTAL_ROWS > 0,
|
||||
"sampled_datasets": sampled_dataset_count,
|
||||
"global_budget_reached": global_budget_reached,
|
||||
"dropped_connections": dropped_connections,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
|
||||
@@ -18,6 +19,9 @@ logger = logging.getLogger(__name__)
|
||||
HEAVY_MODEL = settings.DEFAULT_HEAVY_MODEL
|
||||
WILE_URL = f"{settings.wile_url}/api/generate"
|
||||
|
||||
# Velociraptor client IDs (C.hex) are not real hostnames
|
||||
CLIENTID_RE = re.compile(r"^C\.[0-9a-fA-F]{8,}$")
|
||||
|
||||
|
||||
async def _get_triage_summary(db, dataset_id: str) -> str:
|
||||
result = await db.execute(
|
||||
@@ -154,7 +158,7 @@ async def profile_host(
|
||||
logger.info("Host profile %s: risk=%.1f level=%s", hostname, profile.risk_score, profile.risk_level)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to profile host %s: %s", hostname, e)
|
||||
logger.error("Failed to profile host %s: %r", hostname, e)
|
||||
profile = HostProfile(
|
||||
hunt_id=hunt_id, hostname=hostname, fqdn=fqdn,
|
||||
risk_score=0.0, risk_level="unknown",
|
||||
@@ -185,6 +189,13 @@ async def profile_all_hosts(hunt_id: str) -> None:
|
||||
if h not in hostnames:
|
||||
hostnames[h] = data.get("fqdn") or data.get("Fqdn")
|
||||
|
||||
# Filter out Velociraptor client IDs - not real hostnames
|
||||
real_hosts = {h: f for h, f in hostnames.items() if not CLIENTID_RE.match(h)}
|
||||
skipped = len(hostnames) - len(real_hosts)
|
||||
if skipped:
|
||||
logger.info("Skipped %d Velociraptor client IDs", skipped)
|
||||
hostnames = real_hosts
|
||||
|
||||
logger.info("Discovered %d unique hosts in hunt %s", len(hostnames), hunt_id)
|
||||
|
||||
semaphore = asyncio.Semaphore(settings.HOST_PROFILE_CONCURRENCY)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Async job queue for background AI tasks.
|
||||
|
||||
Manages triage, profiling, report generation, anomaly detection,
|
||||
and data queries as trackable jobs with status, progress, and
|
||||
cancellation support.
|
||||
keyword scanning, IOC extraction, and data queries as trackable
|
||||
jobs with status, progress, and cancellation support.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -15,6 +15,8 @@ from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Coroutine, Optional
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -32,6 +34,18 @@ class JobType(str, Enum):
|
||||
REPORT = "report"
|
||||
ANOMALY = "anomaly"
|
||||
QUERY = "query"
|
||||
HOST_INVENTORY = "host_inventory"
|
||||
KEYWORD_SCAN = "keyword_scan"
|
||||
IOC_EXTRACT = "ioc_extract"
|
||||
|
||||
|
||||
# Job types that form the automatic upload pipeline
|
||||
PIPELINE_JOB_TYPES = frozenset({
|
||||
JobType.TRIAGE,
|
||||
JobType.ANOMALY,
|
||||
JobType.KEYWORD_SCAN,
|
||||
JobType.IOC_EXTRACT,
|
||||
})
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -82,11 +96,7 @@ class Job:
|
||||
|
||||
|
||||
class JobQueue:
|
||||
"""In-memory async job queue with concurrency control.
|
||||
|
||||
Jobs are tracked by ID and can be listed, polled, or cancelled.
|
||||
A configurable number of workers process jobs from the queue.
|
||||
"""
|
||||
"""In-memory async job queue with concurrency control."""
|
||||
|
||||
def __init__(self, max_workers: int = 3):
|
||||
self._jobs: dict[str, Job] = {}
|
||||
@@ -95,47 +105,56 @@ class JobQueue:
|
||||
self._workers: list[asyncio.Task] = []
|
||||
self._handlers: dict[JobType, Callable] = {}
|
||||
self._started = False
|
||||
self._completion_callbacks: list[Callable[[Job], Coroutine]] = []
|
||||
self._cleanup_task: asyncio.Task | None = None
|
||||
|
||||
def register_handler(
|
||||
self,
|
||||
job_type: JobType,
|
||||
handler: Callable[[Job], Coroutine],
|
||||
):
|
||||
"""Register an async handler for a job type.
|
||||
|
||||
Handler signature: async def handler(job: Job) -> Any
|
||||
The handler can update job.progress and job.message during execution.
|
||||
It should check job.is_cancelled periodically and return early.
|
||||
"""
|
||||
def register_handler(self, job_type: JobType, handler: Callable[[Job], Coroutine]):
|
||||
self._handlers[job_type] = handler
|
||||
logger.info(f"Registered handler for {job_type.value}")
|
||||
|
||||
def on_completion(self, callback: Callable[[Job], Coroutine]):
|
||||
"""Register a callback invoked after any job completes or fails."""
|
||||
self._completion_callbacks.append(callback)
|
||||
|
||||
async def start(self):
|
||||
"""Start worker tasks."""
|
||||
if self._started:
|
||||
return
|
||||
self._started = True
|
||||
for i in range(self._max_workers):
|
||||
task = asyncio.create_task(self._worker(i))
|
||||
self._workers.append(task)
|
||||
if not self._cleanup_task or self._cleanup_task.done():
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info(f"Job queue started with {self._max_workers} workers")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop all workers."""
|
||||
self._started = False
|
||||
for w in self._workers:
|
||||
w.cancel()
|
||||
await asyncio.gather(*self._workers, return_exceptions=True)
|
||||
self._workers.clear()
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
await asyncio.gather(self._cleanup_task, return_exceptions=True)
|
||||
self._cleanup_task = None
|
||||
logger.info("Job queue stopped")
|
||||
|
||||
def submit(self, job_type: JobType, **params) -> Job:
|
||||
"""Submit a new job. Returns the Job object immediately."""
|
||||
job = Job(
|
||||
id=str(uuid.uuid4()),
|
||||
job_type=job_type,
|
||||
params=params,
|
||||
)
|
||||
# Soft backpressure: prefer dedupe over queue amplification
|
||||
dedupe_job = self._find_active_duplicate(job_type, params)
|
||||
if dedupe_job is not None:
|
||||
logger.info(
|
||||
f"Job deduped: reusing {dedupe_job.id} ({job_type.value}) params={params}"
|
||||
)
|
||||
return dedupe_job
|
||||
|
||||
if self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG:
|
||||
logger.warning(
|
||||
"Job queue backlog high (%d >= %d). Accepting job but system may be degraded.",
|
||||
self._queue.qsize(), settings.JOB_QUEUE_MAX_BACKLOG,
|
||||
)
|
||||
|
||||
job = Job(id=str(uuid.uuid4()), job_type=job_type, params=params)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put_nowait(job.id)
|
||||
logger.info(f"Job submitted: {job.id} ({job_type.value}) params={params}")
|
||||
@@ -144,6 +163,22 @@ class JobQueue:
|
||||
def get_job(self, job_id: str) -> Job | None:
|
||||
return self._jobs.get(job_id)
|
||||
|
||||
def _find_active_duplicate(self, job_type: JobType, params: dict) -> Job | None:
|
||||
"""Return queued/running job with same key workload to prevent duplicate storms."""
|
||||
key_fields = ["dataset_id", "hunt_id", "hostname", "question", "mode"]
|
||||
sig = tuple((k, params.get(k)) for k in key_fields if params.get(k) is not None)
|
||||
if not sig:
|
||||
return None
|
||||
for j in self._jobs.values():
|
||||
if j.job_type != job_type:
|
||||
continue
|
||||
if j.status not in (JobStatus.QUEUED, JobStatus.RUNNING):
|
||||
continue
|
||||
other_sig = tuple((k, j.params.get(k)) for k in key_fields if j.params.get(k) is not None)
|
||||
if sig == other_sig:
|
||||
return j
|
||||
return None
|
||||
|
||||
def cancel_job(self, job_id: str) -> bool:
|
||||
job = self._jobs.get(job_id)
|
||||
if not job:
|
||||
@@ -153,13 +188,7 @@ class JobQueue:
|
||||
job.cancel()
|
||||
return True
|
||||
|
||||
def list_jobs(
|
||||
self,
|
||||
status: JobStatus | None = None,
|
||||
job_type: JobType | None = None,
|
||||
limit: int = 50,
|
||||
) -> list[dict]:
|
||||
"""List jobs, newest first."""
|
||||
def list_jobs(self, status=None, job_type=None, limit=50) -> list[dict]:
|
||||
jobs = sorted(self._jobs.values(), key=lambda j: j.created_at, reverse=True)
|
||||
if status:
|
||||
jobs = [j for j in jobs if j.status == status]
|
||||
@@ -168,7 +197,6 @@ class JobQueue:
|
||||
return [j.to_dict() for j in jobs[:limit]]
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get queue statistics."""
|
||||
by_status = {}
|
||||
for j in self._jobs.values():
|
||||
by_status[j.status.value] = by_status.get(j.status.value, 0) + 1
|
||||
@@ -177,26 +205,58 @@ class JobQueue:
|
||||
"queued": self._queue.qsize(),
|
||||
"by_status": by_status,
|
||||
"workers": self._max_workers,
|
||||
"active_workers": sum(
|
||||
1 for j in self._jobs.values() if j.status == JobStatus.RUNNING
|
||||
),
|
||||
"active_workers": sum(1 for j in self._jobs.values() if j.status == JobStatus.RUNNING),
|
||||
}
|
||||
|
||||
def is_backlogged(self) -> bool:
|
||||
return self._queue.qsize() >= settings.JOB_QUEUE_MAX_BACKLOG
|
||||
|
||||
def can_accept(self, reserve: int = 0) -> bool:
|
||||
return (self._queue.qsize() + max(0, reserve)) < settings.JOB_QUEUE_MAX_BACKLOG
|
||||
|
||||
def cleanup(self, max_age_seconds: float = 3600):
|
||||
"""Remove old completed/failed/cancelled jobs."""
|
||||
now = time.time()
|
||||
terminal_states = (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||
to_remove = [
|
||||
jid for jid, j in self._jobs.items()
|
||||
if j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||
and (now - j.created_at) > max_age_seconds
|
||||
if j.status in terminal_states and (now - j.created_at) > max_age_seconds
|
||||
]
|
||||
|
||||
# Also cap retained terminal jobs to avoid unbounded memory growth
|
||||
terminal_jobs = sorted(
|
||||
[j for j in self._jobs.values() if j.status in terminal_states],
|
||||
key=lambda j: j.created_at,
|
||||
reverse=True,
|
||||
)
|
||||
overflow = terminal_jobs[settings.JOB_QUEUE_RETAIN_COMPLETED :]
|
||||
to_remove.extend([j.id for j in overflow])
|
||||
|
||||
removed = 0
|
||||
for jid in set(to_remove):
|
||||
if jid in self._jobs:
|
||||
del self._jobs[jid]
|
||||
removed += 1
|
||||
if removed:
|
||||
logger.info(f"Cleaned up {removed} old jobs")
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
interval = max(10, settings.JOB_QUEUE_CLEANUP_INTERVAL_SECONDS)
|
||||
while self._started:
|
||||
try:
|
||||
self.cleanup(max_age_seconds=settings.JOB_QUEUE_CLEANUP_MAX_AGE_SECONDS)
|
||||
except Exception as e:
|
||||
logger.warning(f"Job queue cleanup loop error: {e}")
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
def find_pipeline_jobs(self, dataset_id: str) -> list[Job]:
|
||||
"""Find all pipeline jobs for a given dataset_id."""
|
||||
return [
|
||||
j for j in self._jobs.values()
|
||||
if j.job_type in PIPELINE_JOB_TYPES
|
||||
and j.params.get("dataset_id") == dataset_id
|
||||
]
|
||||
for jid in to_remove:
|
||||
del self._jobs[jid]
|
||||
if to_remove:
|
||||
logger.info(f"Cleaned up {len(to_remove)} old jobs")
|
||||
|
||||
async def _worker(self, worker_id: int):
|
||||
"""Worker loop: pull jobs from queue and execute handlers."""
|
||||
logger.info(f"Worker {worker_id} started")
|
||||
while self._started:
|
||||
try:
|
||||
@@ -220,7 +280,10 @@ class JobQueue:
|
||||
|
||||
job.status = JobStatus.RUNNING
|
||||
job.started_at = time.time()
|
||||
if job.progress <= 0:
|
||||
job.progress = 5.0
|
||||
job.message = "Running..."
|
||||
await _sync_processing_task(job)
|
||||
logger.info(f"Worker {worker_id}: executing {job.id} ({job.job_type.value})")
|
||||
|
||||
try:
|
||||
@@ -231,38 +294,111 @@ class JobQueue:
|
||||
job.result = result
|
||||
job.message = "Completed"
|
||||
job.completed_at = time.time()
|
||||
logger.info(
|
||||
f"Worker {worker_id}: completed {job.id} "
|
||||
f"in {job.elapsed_ms}ms"
|
||||
)
|
||||
logger.info(f"Worker {worker_id}: completed {job.id} in {job.elapsed_ms}ms")
|
||||
except Exception as e:
|
||||
if not job.is_cancelled:
|
||||
job.status = JobStatus.FAILED
|
||||
job.error = str(e)
|
||||
job.message = f"Failed: {e}"
|
||||
job.completed_at = time.time()
|
||||
logger.error(
|
||||
f"Worker {worker_id}: failed {job.id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.error(f"Worker {worker_id}: failed {job.id}: {e}", exc_info=True)
|
||||
|
||||
if job.is_cancelled and not job.completed_at:
|
||||
job.completed_at = time.time()
|
||||
|
||||
await _sync_processing_task(job)
|
||||
|
||||
# Fire completion callbacks
|
||||
for cb in self._completion_callbacks:
|
||||
try:
|
||||
await cb(job)
|
||||
except Exception as cb_err:
|
||||
logger.error(f"Completion callback error: {cb_err}", exc_info=True)
|
||||
|
||||
|
||||
# Singleton + job handlers
|
||||
async def _sync_processing_task(job: Job):
|
||||
"""Persist latest job state into processing_tasks (if linked by job_id)."""
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import update
|
||||
|
||||
job_queue = JobQueue(max_workers=3)
|
||||
try:
|
||||
from app.db import async_session_factory
|
||||
from app.db.models import ProcessingTask
|
||||
|
||||
values = {
|
||||
"status": job.status.value,
|
||||
"progress": float(job.progress),
|
||||
"message": job.message,
|
||||
"error": job.error,
|
||||
}
|
||||
if job.started_at:
|
||||
values["started_at"] = datetime.fromtimestamp(job.started_at, tz=timezone.utc)
|
||||
if job.completed_at:
|
||||
values["completed_at"] = datetime.fromtimestamp(job.completed_at, tz=timezone.utc)
|
||||
|
||||
async with async_session_factory() as db:
|
||||
await db.execute(
|
||||
update(ProcessingTask)
|
||||
.where(ProcessingTask.job_id == job.id)
|
||||
.values(**values)
|
||||
)
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to sync processing task for job {job.id}: {e}")
|
||||
|
||||
|
||||
# -- Singleton + job handlers --
|
||||
|
||||
job_queue = JobQueue(max_workers=5)
|
||||
|
||||
|
||||
async def _handle_triage(job: Job):
|
||||
"""Triage handler."""
|
||||
"""Triage handler - chains HOST_PROFILE after completion."""
|
||||
from app.services.triage import triage_dataset
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
job.message = f"Triaging dataset {dataset_id}"
|
||||
results = await triage_dataset(dataset_id)
|
||||
return {"count": len(results) if results else 0}
|
||||
await triage_dataset(dataset_id)
|
||||
|
||||
# Chain: trigger host profiling now that triage results exist
|
||||
from app.db import async_session_factory
|
||||
from app.db.models import Dataset
|
||||
from sqlalchemy import select
|
||||
try:
|
||||
async with async_session_factory() as db:
|
||||
ds = await db.execute(select(Dataset.hunt_id).where(Dataset.id == dataset_id))
|
||||
row = ds.first()
|
||||
hunt_id = row[0] if row else None
|
||||
if hunt_id:
|
||||
hp_job = job_queue.submit(JobType.HOST_PROFILE, hunt_id=hunt_id)
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
from app.db.models import ProcessingTask
|
||||
async with async_session_factory() as db:
|
||||
existing = await db.execute(
|
||||
select(ProcessingTask.id).where(ProcessingTask.job_id == hp_job.id)
|
||||
)
|
||||
if existing.first() is None:
|
||||
db.add(ProcessingTask(
|
||||
hunt_id=hunt_id,
|
||||
dataset_id=dataset_id,
|
||||
job_id=hp_job.id,
|
||||
stage="host_profile",
|
||||
status="queued",
|
||||
progress=0.0,
|
||||
message="Queued",
|
||||
))
|
||||
await db.commit()
|
||||
except Exception as persist_err:
|
||||
logger.warning(f"Failed to persist chained HOST_PROFILE task: {persist_err}")
|
||||
|
||||
logger.info(f"Triage done for {dataset_id} - chained HOST_PROFILE for hunt {hunt_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to chain host profile after triage: {e}")
|
||||
|
||||
return {"dataset_id": dataset_id}
|
||||
|
||||
|
||||
async def _handle_host_profile(job: Job):
|
||||
"""Host profiling handler."""
|
||||
from app.services.host_profiler import profile_all_hosts, profile_host
|
||||
hunt_id = job.params.get("hunt_id")
|
||||
hostname = job.params.get("hostname")
|
||||
@@ -277,7 +413,6 @@ async def _handle_host_profile(job: Job):
|
||||
|
||||
|
||||
async def _handle_report(job: Job):
|
||||
"""Report generation handler."""
|
||||
from app.services.report_generator import generate_report
|
||||
hunt_id = job.params.get("hunt_id")
|
||||
job.message = f"Generating report for hunt {hunt_id}"
|
||||
@@ -286,7 +421,6 @@ async def _handle_report(job: Job):
|
||||
|
||||
|
||||
async def _handle_anomaly(job: Job):
|
||||
"""Anomaly detection handler."""
|
||||
from app.services.anomaly_detector import detect_anomalies
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
k = job.params.get("k", 3)
|
||||
@@ -297,7 +431,6 @@ async def _handle_anomaly(job: Job):
|
||||
|
||||
|
||||
async def _handle_query(job: Job):
|
||||
"""Data query handler (non-streaming)."""
|
||||
from app.services.data_query import query_dataset
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
question = job.params.get("question", "")
|
||||
@@ -307,10 +440,152 @@ async def _handle_query(job: Job):
|
||||
return {"answer": answer}
|
||||
|
||||
|
||||
async def _handle_host_inventory(job: Job):
|
||||
from app.db import async_session_factory
|
||||
from app.services.host_inventory import build_host_inventory, inventory_cache
|
||||
|
||||
hunt_id = job.params.get("hunt_id")
|
||||
if not hunt_id:
|
||||
raise ValueError("hunt_id required")
|
||||
|
||||
inventory_cache.set_building(hunt_id)
|
||||
job.message = f"Building host inventory for hunt {hunt_id}"
|
||||
|
||||
try:
|
||||
async with async_session_factory() as db:
|
||||
result = await build_host_inventory(hunt_id, db)
|
||||
inventory_cache.put(hunt_id, result)
|
||||
job.message = f"Built inventory: {result['stats']['total_hosts']} hosts"
|
||||
return {"hunt_id": hunt_id, "total_hosts": result["stats"]["total_hosts"]}
|
||||
except Exception:
|
||||
inventory_cache.clear_building(hunt_id)
|
||||
raise
|
||||
|
||||
|
||||
async def _handle_keyword_scan(job: Job):
|
||||
"""AUP keyword scan handler."""
|
||||
from app.db import async_session_factory
|
||||
from app.services.scanner import KeywordScanner, keyword_scan_cache
|
||||
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
job.message = f"Running AUP keyword scan on dataset {dataset_id}"
|
||||
|
||||
async with async_session_factory() as db:
|
||||
scanner = KeywordScanner(db)
|
||||
result = await scanner.scan(dataset_ids=[dataset_id])
|
||||
|
||||
# Cache dataset-only result for fast API reuse
|
||||
if dataset_id:
|
||||
keyword_scan_cache.put(dataset_id, result)
|
||||
|
||||
hits = result.get("total_hits", 0)
|
||||
job.message = f"Keyword scan complete: {hits} hits"
|
||||
logger.info(f"Keyword scan for {dataset_id}: {hits} hits across {result.get('rows_scanned', 0)} rows")
|
||||
return {"dataset_id": dataset_id, "total_hits": hits, "rows_scanned": result.get("rows_scanned", 0)}
|
||||
|
||||
|
||||
async def _handle_ioc_extract(job: Job):
|
||||
"""IOC extraction handler."""
|
||||
from app.db import async_session_factory
|
||||
from app.services.ioc_extractor import extract_iocs_from_dataset
|
||||
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
job.message = f"Extracting IOCs from dataset {dataset_id}"
|
||||
|
||||
async with async_session_factory() as db:
|
||||
iocs = await extract_iocs_from_dataset(dataset_id, db)
|
||||
|
||||
total = sum(len(v) for v in iocs.values())
|
||||
job.message = f"IOC extraction complete: {total} IOCs found"
|
||||
logger.info(f"IOC extract for {dataset_id}: {total} IOCs")
|
||||
return {"dataset_id": dataset_id, "total_iocs": total, "breakdown": {k: len(v) for k, v in iocs.items()}}
|
||||
|
||||
|
||||
async def _on_pipeline_job_complete(job: Job):
|
||||
"""Update Dataset.processing_status when all pipeline jobs finish."""
|
||||
if job.job_type not in PIPELINE_JOB_TYPES:
|
||||
return
|
||||
|
||||
dataset_id = job.params.get("dataset_id")
|
||||
if not dataset_id:
|
||||
return
|
||||
|
||||
pipeline_jobs = job_queue.find_pipeline_jobs(dataset_id)
|
||||
if not pipeline_jobs:
|
||||
return
|
||||
|
||||
all_done = all(
|
||||
j.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED)
|
||||
for j in pipeline_jobs
|
||||
)
|
||||
if not all_done:
|
||||
return
|
||||
|
||||
any_failed = any(j.status == JobStatus.FAILED for j in pipeline_jobs)
|
||||
new_status = "completed_with_errors" if any_failed else "completed"
|
||||
|
||||
try:
|
||||
from app.db import async_session_factory
|
||||
from app.db.models import Dataset
|
||||
from sqlalchemy import update
|
||||
|
||||
async with async_session_factory() as db:
|
||||
await db.execute(
|
||||
update(Dataset)
|
||||
.where(Dataset.id == dataset_id)
|
||||
.values(processing_status=new_status)
|
||||
)
|
||||
await db.commit()
|
||||
logger.info(f"Dataset {dataset_id} processing_status -> {new_status}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update processing_status for {dataset_id}: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
async def reconcile_stale_processing_tasks() -> int:
|
||||
"""Mark queued/running processing tasks from prior runs as failed."""
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import update
|
||||
|
||||
try:
|
||||
from app.db import async_session_factory
|
||||
from app.db.models import ProcessingTask
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
async with async_session_factory() as db:
|
||||
result = await db.execute(
|
||||
update(ProcessingTask)
|
||||
.where(ProcessingTask.status.in_(["queued", "running"]))
|
||||
.values(
|
||||
status="failed",
|
||||
error="Recovered after service restart before task completion",
|
||||
message="Recovered stale task after restart",
|
||||
completed_at=now,
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
updated = int(result.rowcount or 0)
|
||||
|
||||
if updated:
|
||||
logger.warning(
|
||||
"Reconciled %d stale processing tasks (queued/running -> failed) during startup",
|
||||
updated,
|
||||
)
|
||||
return updated
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reconcile stale processing tasks: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def register_all_handlers():
|
||||
"""Register all job handlers."""
|
||||
"""Register all job handlers and completion callbacks."""
|
||||
job_queue.register_handler(JobType.TRIAGE, _handle_triage)
|
||||
job_queue.register_handler(JobType.HOST_PROFILE, _handle_host_profile)
|
||||
job_queue.register_handler(JobType.REPORT, _handle_report)
|
||||
job_queue.register_handler(JobType.ANOMALY, _handle_anomaly)
|
||||
job_queue.register_handler(JobType.QUERY, _handle_query)
|
||||
job_queue.register_handler(JobType.QUERY, _handle_query)
|
||||
job_queue.register_handler(JobType.HOST_INVENTORY, _handle_host_inventory)
|
||||
job_queue.register_handler(JobType.KEYWORD_SCAN, _handle_keyword_scan)
|
||||
job_queue.register_handler(JobType.IOC_EXTRACT, _handle_ioc_extract)
|
||||
job_queue.on_completion(_on_pipeline_job_complete)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""AUP Keyword Scanner — searches dataset rows, hunts, annotations, and
|
||||
"""AUP Keyword Scanner searches dataset rows, hunts, annotations, and
|
||||
messages for keyword matches.
|
||||
|
||||
Scanning is done in Python (not SQL LIKE on JSON columns) for portability
|
||||
@@ -8,24 +8,49 @@ across SQLite / PostgreSQL and to provide per-cell match context.
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
|
||||
from app.db.models import (
|
||||
KeywordTheme,
|
||||
Keyword,
|
||||
DatasetRow,
|
||||
Dataset,
|
||||
Hunt,
|
||||
Annotation,
|
||||
Message,
|
||||
Conversation,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BATCH_SIZE = 500
|
||||
BATCH_SIZE = 200
|
||||
|
||||
|
||||
def _infer_hostname_and_user(data: dict) -> tuple[str | None, str | None]:
|
||||
"""Best-effort extraction of hostname and user from a dataset row."""
|
||||
if not data:
|
||||
return None, None
|
||||
|
||||
host_keys = (
|
||||
'hostname', 'host_name', 'host', 'computer_name', 'computer',
|
||||
'fqdn', 'client_id', 'agent_id', 'endpoint_id',
|
||||
)
|
||||
user_keys = (
|
||||
'username', 'user_name', 'user', 'account_name',
|
||||
'logged_in_user', 'samaccountname', 'sam_account_name',
|
||||
)
|
||||
|
||||
def pick(keys):
|
||||
for k in keys:
|
||||
for actual_key, v in data.items():
|
||||
if actual_key.lower() == k and v not in (None, ''):
|
||||
return str(v)
|
||||
return None
|
||||
|
||||
return pick(host_keys), pick(user_keys)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -39,6 +64,8 @@ class ScanHit:
|
||||
matched_value: str
|
||||
row_index: int | None = None
|
||||
dataset_name: str | None = None
|
||||
hostname: str | None = None
|
||||
username: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -50,21 +77,54 @@ class ScanResult:
|
||||
rows_scanned: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeywordScanCacheEntry:
|
||||
dataset_id: str
|
||||
result: dict
|
||||
built_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
|
||||
|
||||
class KeywordScanCache:
|
||||
"""In-memory per-dataset cache for dataset-only keyword scans.
|
||||
|
||||
This enables fast-path reads when users run AUP scans against datasets that
|
||||
were already scanned during upload pipeline processing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._entries: dict[str, KeywordScanCacheEntry] = {}
|
||||
|
||||
def put(self, dataset_id: str, result: dict):
|
||||
self._entries[dataset_id] = KeywordScanCacheEntry(dataset_id=dataset_id, result=result)
|
||||
|
||||
def get(self, dataset_id: str) -> KeywordScanCacheEntry | None:
|
||||
return self._entries.get(dataset_id)
|
||||
|
||||
def invalidate_dataset(self, dataset_id: str):
|
||||
self._entries.pop(dataset_id, None)
|
||||
|
||||
def clear(self):
|
||||
self._entries.clear()
|
||||
|
||||
|
||||
keyword_scan_cache = KeywordScanCache()
|
||||
|
||||
|
||||
class KeywordScanner:
|
||||
"""Scans multiple data sources for keyword/regex matches."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────
|
||||
# Public API
|
||||
|
||||
async def scan(
|
||||
self,
|
||||
dataset_ids: list[str] | None = None,
|
||||
theme_ids: list[str] | None = None,
|
||||
scan_hunts: bool = True,
|
||||
scan_annotations: bool = True,
|
||||
scan_messages: bool = True,
|
||||
scan_hunts: bool = False,
|
||||
scan_annotations: bool = False,
|
||||
scan_messages: bool = False,
|
||||
) -> dict:
|
||||
"""Run a full AUP scan and return dict matching ScanResponse."""
|
||||
# Load themes + keywords
|
||||
@@ -103,7 +163,7 @@ class KeywordScanner:
|
||||
"rows_scanned": result.rows_scanned,
|
||||
}
|
||||
|
||||
# ── Internal ──────────────────────────────────────────────────────
|
||||
# Internal
|
||||
|
||||
async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
|
||||
q = select(KeywordTheme).where(KeywordTheme.enabled == True) # noqa: E712
|
||||
@@ -143,6 +203,8 @@ class KeywordScanner:
|
||||
hits: list[ScanHit],
|
||||
row_index: int | None = None,
|
||||
dataset_name: str | None = None,
|
||||
hostname: str | None = None,
|
||||
username: str | None = None,
|
||||
) -> None:
|
||||
"""Check text against all compiled patterns, append hits."""
|
||||
if not text:
|
||||
@@ -150,8 +212,7 @@ class KeywordScanner:
|
||||
for (theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
|
||||
for kw_value, pat in keyword_patterns:
|
||||
if pat.search(text):
|
||||
# Truncate matched_value for display
|
||||
matched_preview = text[:200] + ("…" if len(text) > 200 else "")
|
||||
matched_preview = text[:200] + ("" if len(text) > 200 else "")
|
||||
hits.append(ScanHit(
|
||||
theme_name=theme_name,
|
||||
theme_color=theme_color,
|
||||
@@ -162,13 +223,14 @@ class KeywordScanner:
|
||||
matched_value=matched_preview,
|
||||
row_index=row_index,
|
||||
dataset_name=dataset_name,
|
||||
hostname=hostname,
|
||||
username=username,
|
||||
))
|
||||
|
||||
async def _scan_datasets(
|
||||
self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
|
||||
) -> None:
|
||||
"""Scan dataset rows in batches."""
|
||||
# Build dataset name lookup
|
||||
"""Scan dataset rows in batches using keyset pagination (no OFFSET)."""
|
||||
ds_q = select(Dataset.id, Dataset.name)
|
||||
if dataset_ids:
|
||||
ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
|
||||
@@ -178,37 +240,66 @@ class KeywordScanner:
|
||||
if not ds_map:
|
||||
return
|
||||
|
||||
# Iterate rows in batches
|
||||
offset = 0
|
||||
row_q_base = select(DatasetRow).where(
|
||||
DatasetRow.dataset_id.in_(list(ds_map.keys()))
|
||||
).order_by(DatasetRow.id)
|
||||
import asyncio
|
||||
|
||||
while True:
|
||||
rows_result = await self.db.execute(
|
||||
row_q_base.offset(offset).limit(BATCH_SIZE)
|
||||
max_rows = max(0, int(settings.SCANNER_MAX_ROWS_PER_SCAN))
|
||||
budget_reached = False
|
||||
|
||||
for ds_id, ds_name in ds_map.items():
|
||||
if max_rows and result.rows_scanned >= max_rows:
|
||||
budget_reached = True
|
||||
break
|
||||
|
||||
last_id = 0
|
||||
while True:
|
||||
if max_rows and result.rows_scanned >= max_rows:
|
||||
budget_reached = True
|
||||
break
|
||||
rows_result = await self.db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == ds_id)
|
||||
.where(DatasetRow.id > last_id)
|
||||
.order_by(DatasetRow.id)
|
||||
.limit(BATCH_SIZE)
|
||||
)
|
||||
rows = rows_result.scalars().all()
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for row in rows:
|
||||
result.rows_scanned += 1
|
||||
data = row.data or {}
|
||||
hostname, username = _infer_hostname_and_user(data)
|
||||
for col_name, cell_value in data.items():
|
||||
if cell_value is None:
|
||||
continue
|
||||
text = str(cell_value)
|
||||
self._match_text(
|
||||
text,
|
||||
patterns,
|
||||
"dataset_row",
|
||||
row.id,
|
||||
col_name,
|
||||
result.hits,
|
||||
row_index=row.row_index,
|
||||
dataset_name=ds_name,
|
||||
hostname=hostname,
|
||||
username=username,
|
||||
)
|
||||
|
||||
last_id = rows[-1].id
|
||||
await asyncio.sleep(0)
|
||||
if len(rows) < BATCH_SIZE:
|
||||
break
|
||||
|
||||
if budget_reached:
|
||||
break
|
||||
|
||||
if budget_reached:
|
||||
logger.warning(
|
||||
"AUP scan row budget reached (%d rows). Returning partial results.",
|
||||
result.rows_scanned,
|
||||
)
|
||||
rows = rows_result.scalars().all()
|
||||
if not rows:
|
||||
break
|
||||
|
||||
for row in rows:
|
||||
result.rows_scanned += 1
|
||||
data = row.data or {}
|
||||
for col_name, cell_value in data.items():
|
||||
if cell_value is None:
|
||||
continue
|
||||
text = str(cell_value)
|
||||
self._match_text(
|
||||
text, patterns, "dataset_row", row.id,
|
||||
col_name, result.hits,
|
||||
row_index=row.row_index,
|
||||
dataset_name=ds_map.get(row.dataset_id),
|
||||
)
|
||||
|
||||
offset += BATCH_SIZE
|
||||
if len(rows) < BATCH_SIZE:
|
||||
break
|
||||
|
||||
async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
|
||||
"""Scan hunt names and descriptions."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""
|
||||
"""Auto-triage service - fast LLM analysis of dataset batches via Roadrunner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -15,7 +15,7 @@ from app.db.models import Dataset, DatasetRow, TriageResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_FAST_MODEL = "qwen2.5-coder:7b-instruct-q4_K_M"
|
||||
DEFAULT_FAST_MODEL = settings.DEFAULT_FAST_MODEL
|
||||
ROADRUNNER_URL = f"{settings.roadrunner_url}/api/generate"
|
||||
|
||||
ARTIFACT_FOCUS = {
|
||||
@@ -80,7 +80,7 @@ async def triage_dataset(dataset_id: str) -> None:
|
||||
rows_result = await db.execute(
|
||||
select(DatasetRow)
|
||||
.where(DatasetRow.dataset_id == dataset_id)
|
||||
.order_by(DatasetRow.row_number)
|
||||
.order_by(DatasetRow.row_index)
|
||||
.offset(offset)
|
||||
.limit(batch_size)
|
||||
)
|
||||
@@ -167,4 +167,4 @@ Be precise. Only flag genuinely suspicious items. Respond with valid JSON only."
|
||||
|
||||
offset += batch_size
|
||||
|
||||
logger.info("Triage complete for dataset %s", dataset_id)
|
||||
logger.info("Triage complete for dataset %s", dataset_id)
|
||||
|
||||
124
backend/tests/test_agent_policy_execution.py
Normal file
124
backend/tests/test_agent_policy_execution.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Tests for execution-mode behavior in /api/agent/assist."""
|
||||
|
||||
import io
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_assist_policy_query_executes_scan(client):
|
||||
# 1) Create hunt
|
||||
h = await client.post("/api/hunts", json={"name": "Policy Hunt"})
|
||||
assert h.status_code == 200
|
||||
hunt_id = h.json()["id"]
|
||||
|
||||
# 2) Upload browser-history-like CSV
|
||||
csv_bytes = (
|
||||
b"User,visited_url,title,ClientId,Fqdn\n"
|
||||
b"Alice,https://www.pornhub.com/view_video.php,site,HOST-A,host-a.local\n"
|
||||
b"Bob,https://news.example.org/article,news,HOST-B,host-b.local\n"
|
||||
)
|
||||
files = {"file": ("web_history.csv", io.BytesIO(csv_bytes), "text/csv")}
|
||||
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
|
||||
assert up.status_code == 200
|
||||
|
||||
# 3) Ensure policy theme/keyword exists
|
||||
t = await client.post(
|
||||
"/api/keywords/themes",
|
||||
json={
|
||||
"name": "Adult Content",
|
||||
"color": "#e91e63",
|
||||
"enabled": True,
|
||||
},
|
||||
)
|
||||
assert t.status_code in (201, 409)
|
||||
|
||||
themes = await client.get("/api/keywords/themes")
|
||||
assert themes.status_code == 200
|
||||
adult = next(x for x in themes.json()["themes"] if x["name"] == "Adult Content")
|
||||
|
||||
k = await client.post(
|
||||
f"/api/keywords/themes/{adult['id']}/keywords",
|
||||
json={"value": "pornhub", "is_regex": False},
|
||||
)
|
||||
assert k.status_code in (201, 409)
|
||||
|
||||
# 4) Execution-mode query
|
||||
q = await client.post(
|
||||
"/api/agent/assist",
|
||||
json={
|
||||
"query": "Analyze browser history for policy-violating domains and summarize by user and host.",
|
||||
"hunt_id": hunt_id,
|
||||
},
|
||||
)
|
||||
assert q.status_code == 200
|
||||
body = q.json()
|
||||
|
||||
assert body["model_used"] == "execution:keyword_scanner"
|
||||
assert body["execution"] is not None
|
||||
assert body["execution"]["policy_hits"] >= 1
|
||||
assert len(body["execution"]["top_user_hosts"]) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_assist_execution_preference_off_stays_advisory(client):
|
||||
h = await client.post("/api/hunts", json={"name": "No Exec Hunt"})
|
||||
assert h.status_code == 200
|
||||
hunt_id = h.json()["id"]
|
||||
|
||||
q = await client.post(
|
||||
"/api/agent/assist",
|
||||
json={
|
||||
"query": "Analyze browser history for policy-violating domains and summarize by user and host.",
|
||||
"hunt_id": hunt_id,
|
||||
"execution_preference": "off",
|
||||
},
|
||||
)
|
||||
assert q.status_code == 200
|
||||
body = q.json()
|
||||
assert body["model_used"] != "execution:keyword_scanner"
|
||||
assert body["execution"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_assist_execution_preference_force_executes(client):
|
||||
# Create hunt + dataset even when the query text is not policy-specific
|
||||
h = await client.post("/api/hunts", json={"name": "Force Exec Hunt"})
|
||||
assert h.status_code == 200
|
||||
hunt_id = h.json()["id"]
|
||||
|
||||
csv_bytes = (
|
||||
b"User,visited_url,title,ClientId,Fqdn\n"
|
||||
b"Alice,https://www.pornhub.com/view_video.php,site,HOST-A,host-a.local\n"
|
||||
)
|
||||
files = {"file": ("web_history.csv", io.BytesIO(csv_bytes), "text/csv")}
|
||||
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
|
||||
assert up.status_code == 200
|
||||
|
||||
t = await client.post(
|
||||
"/api/keywords/themes",
|
||||
json={"name": "Adult Content", "color": "#e91e63", "enabled": True},
|
||||
)
|
||||
assert t.status_code in (201, 409)
|
||||
|
||||
themes = await client.get("/api/keywords/themes")
|
||||
assert themes.status_code == 200
|
||||
adult = next(x for x in themes.json()["themes"] if x["name"] == "Adult Content")
|
||||
k = await client.post(
|
||||
f"/api/keywords/themes/{adult['id']}/keywords",
|
||||
json={"value": "pornhub", "is_regex": False},
|
||||
)
|
||||
assert k.status_code in (201, 409)
|
||||
|
||||
q = await client.post(
|
||||
"/api/agent/assist",
|
||||
json={
|
||||
"query": "Summarize notable activity in this hunt.",
|
||||
"hunt_id": hunt_id,
|
||||
"execution_preference": "force",
|
||||
},
|
||||
)
|
||||
assert q.status_code == 200
|
||||
body = q.json()
|
||||
assert body["model_used"] == "execution:keyword_scanner"
|
||||
assert body["execution"] is not None
|
||||
@@ -77,6 +77,26 @@ class TestHuntEndpoints:
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
async def test_hunt_progress(self, client):
|
||||
create = await client.post("/api/hunts", json={"name": "Progress Hunt"})
|
||||
hunt_id = create.json()["id"]
|
||||
|
||||
# attach one dataset so progress has scope
|
||||
from tests.conftest import SAMPLE_CSV
|
||||
import io
|
||||
files = {"file": ("progress.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
|
||||
up = await client.post(f"/api/datasets/upload?hunt_id={hunt_id}", files=files)
|
||||
assert up.status_code == 200
|
||||
|
||||
res = await client.get(f"/api/hunts/{hunt_id}/progress")
|
||||
assert res.status_code == 200
|
||||
body = res.json()
|
||||
assert body["hunt_id"] == hunt_id
|
||||
assert "progress_percent" in body
|
||||
assert "dataset_total" in body
|
||||
assert "network_status" in body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestDatasetEndpoints:
|
||||
"""Test dataset upload and retrieval."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Tests for CSV parser and normalizer services."""
|
||||
"""Tests for CSV parser and normalizer services."""
|
||||
|
||||
import pytest
|
||||
from app.services.csv_parser import parse_csv_bytes, detect_encoding, detect_delimiter, infer_column_types
|
||||
@@ -43,8 +43,9 @@ class TestCSVParser:
|
||||
assert len(rows) == 2
|
||||
|
||||
def test_parse_empty_file(self):
|
||||
with pytest.raises(Exception):
|
||||
parse_csv_bytes(b"")
|
||||
rows, meta = parse_csv_bytes(b"")
|
||||
assert len(rows) == 0
|
||||
assert meta["row_count"] == 0
|
||||
|
||||
def test_detect_encoding_utf8(self):
|
||||
enc = detect_encoding(SAMPLE_CSV)
|
||||
@@ -53,17 +54,15 @@ class TestCSVParser:
|
||||
|
||||
def test_infer_column_types(self):
|
||||
types = infer_column_types(
|
||||
["192.168.1.1", "10.0.0.1", "8.8.8.8"],
|
||||
"src_ip",
|
||||
[{"src_ip": "192.168.1.1"}, {"src_ip": "10.0.0.1"}, {"src_ip": "8.8.8.8"}],
|
||||
)
|
||||
assert types == "ip"
|
||||
assert types["src_ip"] == "ip"
|
||||
|
||||
def test_infer_column_types_hash(self):
|
||||
types = infer_column_types(
|
||||
["d41d8cd98f00b204e9800998ecf8427e"],
|
||||
"hash",
|
||||
[{"hash": "d41d8cd98f00b204e9800998ecf8427e"}],
|
||||
)
|
||||
assert types == "hash_md5"
|
||||
assert types["hash"] == "hash_md5"
|
||||
|
||||
|
||||
class TestNormalizer:
|
||||
@@ -94,7 +93,7 @@ class TestNormalizer:
|
||||
start, end = detect_time_range(rows, column_mapping)
|
||||
# Should detect time range from timestamp column
|
||||
if start:
|
||||
assert "2025" in start
|
||||
assert "2025" in str(start)
|
||||
|
||||
def test_normalize_rows(self):
|
||||
rows = [{"SourceAddr": "10.0.0.1", "ProcessName": "cmd.exe"}]
|
||||
@@ -102,3 +101,6 @@ class TestNormalizer:
|
||||
normalized = normalize_rows(rows, mapping)
|
||||
assert len(normalized) == 1
|
||||
assert normalized[0].get("src_ip") == "10.0.0.1"
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -197,3 +197,27 @@ async def test_quick_scan(client: AsyncClient):
|
||||
assert "total_hits" in data
|
||||
# powershell should match at least one row
|
||||
assert data["total_hits"] > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quick_scan_cache_hit(client: AsyncClient):
|
||||
"""Second quick scan should return cache hit metadata."""
|
||||
theme_res = await client.post("/api/keywords/themes", json={"name": "Quick Cache Theme", "color": "#00aa00"})
|
||||
tid = theme_res.json()["id"]
|
||||
await client.post(f"/api/keywords/themes/{tid}/keywords", json={"value": "chrome.exe"})
|
||||
|
||||
from tests.conftest import SAMPLE_CSV
|
||||
import io
|
||||
files = {"file": ("cache_quick.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
|
||||
upload = await client.post("/api/datasets/upload", files=files)
|
||||
ds_id = upload.json()["id"]
|
||||
|
||||
first = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
|
||||
assert first.status_code == 200
|
||||
assert first.json().get("cache_status") in ("miss", "hit")
|
||||
|
||||
second = await client.get(f"/api/keywords/scan/quick?dataset_id={ds_id}")
|
||||
assert second.status_code == 200
|
||||
body = second.json()
|
||||
assert body.get("cache_used") is True
|
||||
assert body.get("cache_status") == "hit"
|
||||
|
||||
84
backend/tests/test_network.py
Normal file
84
backend/tests/test_network.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Tests for network inventory endpoints and cache/polling behavior."""
|
||||
|
||||
import io
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.host_inventory import inventory_cache
|
||||
from tests.conftest import SAMPLE_CSV
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inventory_status_none_for_unknown_hunt(client):
|
||||
hunt_id = "hunt-does-not-exist"
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
inventory_cache.clear_building(hunt_id)
|
||||
|
||||
res = await client.get(f"/api/network/inventory-status?hunt_id={hunt_id}")
|
||||
assert res.status_code == 200
|
||||
body = res.json()
|
||||
assert body["hunt_id"] == hunt_id
|
||||
assert body["status"] == "none"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_host_inventory_cold_cache_returns_202(client):
|
||||
# Create hunt and upload dataset linked to that hunt
|
||||
hunt = await client.post("/api/hunts", json={"name": "Net Hunt"})
|
||||
hunt_id = hunt.json()["id"]
|
||||
|
||||
files = {"file": ("network.csv", io.BytesIO(SAMPLE_CSV), "text/csv")}
|
||||
up = await client.post("/api/datasets/upload", files=files, params={"hunt_id": hunt_id})
|
||||
assert up.status_code == 200
|
||||
|
||||
# Ensure cache is cold for this hunt
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
inventory_cache.clear_building(hunt_id)
|
||||
|
||||
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
|
||||
assert res.status_code == 202
|
||||
body = res.json()
|
||||
assert body["status"] == "building"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_host_inventory_ready_cache_returns_200(client):
|
||||
hunt = await client.post("/api/hunts", json={"name": "Ready Hunt"})
|
||||
hunt_id = hunt.json()["id"]
|
||||
|
||||
mock_inventory = {
|
||||
"hosts": [
|
||||
{
|
||||
"id": "host-1",
|
||||
"hostname": "HOST-1",
|
||||
"fqdn": "HOST-1.local",
|
||||
"client_id": "C.1234abcd",
|
||||
"ips": ["10.0.0.10"],
|
||||
"os": "Windows 10",
|
||||
"users": ["alice"],
|
||||
"datasets": ["test"],
|
||||
"row_count": 5,
|
||||
}
|
||||
],
|
||||
"connections": [],
|
||||
"stats": {
|
||||
"total_hosts": 1,
|
||||
"hosts_with_ips": 1,
|
||||
"hosts_with_users": 1,
|
||||
"total_datasets_scanned": 1,
|
||||
"total_rows_scanned": 5,
|
||||
},
|
||||
}
|
||||
|
||||
inventory_cache.put(hunt_id, mock_inventory)
|
||||
|
||||
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
|
||||
assert res.status_code == 200
|
||||
body = res.json()
|
||||
assert body["stats"]["total_hosts"] == 1
|
||||
assert len(body["hosts"]) == 1
|
||||
assert body["hosts"][0]["hostname"] == "HOST-1"
|
||||
|
||||
status_res = await client.get(f"/api/network/inventory-status?hunt_id={hunt_id}")
|
||||
assert status_res.status_code == 200
|
||||
assert status_res.json()["status"] == "ready"
|
||||
82
backend/tests/test_network_scale.py
Normal file
82
backend/tests/test_network_scale.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Scale-oriented network endpoint tests (summary/subgraph/backpressure)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.config import settings
|
||||
from app.services.host_inventory import inventory_cache
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_summary_from_cache(client):
|
||||
hunt_id = "scale-hunt-summary"
|
||||
inv = {
|
||||
"hosts": [
|
||||
{"id": "h1", "hostname": "H1", "ips": ["10.0.0.1"], "users": ["a"], "row_count": 50},
|
||||
{"id": "h2", "hostname": "H2", "ips": [], "users": [], "row_count": 10},
|
||||
],
|
||||
"connections": [
|
||||
{"source": "h1", "target": "8.8.8.8", "count": 7},
|
||||
{"source": "h1", "target": "h2", "count": 3},
|
||||
],
|
||||
"stats": {"total_hosts": 2, "total_rows_scanned": 60},
|
||||
}
|
||||
inventory_cache.put(hunt_id, inv)
|
||||
|
||||
res = await client.get(f"/api/network/summary?hunt_id={hunt_id}&top_n=1")
|
||||
assert res.status_code == 200
|
||||
body = res.json()
|
||||
assert body["stats"]["total_hosts"] == 2
|
||||
assert len(body["top_hosts"]) == 1
|
||||
assert body["top_hosts"][0]["id"] == "h1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_subgraph_truncates(client):
|
||||
hunt_id = "scale-hunt-subgraph"
|
||||
inv = {
|
||||
"hosts": [
|
||||
{"id": f"h{i}", "hostname": f"H{i}", "ips": [], "users": [], "row_count": 100 - i}
|
||||
for i in range(1, 8)
|
||||
],
|
||||
"connections": [
|
||||
{"source": "h1", "target": "h2", "count": 20},
|
||||
{"source": "h1", "target": "h3", "count": 15},
|
||||
{"source": "h2", "target": "h4", "count": 5},
|
||||
{"source": "h3", "target": "h5", "count": 4},
|
||||
],
|
||||
"stats": {"total_hosts": 7, "total_rows_scanned": 999},
|
||||
}
|
||||
inventory_cache.put(hunt_id, inv)
|
||||
|
||||
res = await client.get(f"/api/network/subgraph?hunt_id={hunt_id}&max_hosts=3&max_edges=2")
|
||||
assert res.status_code == 200
|
||||
body = res.json()
|
||||
assert len(body["hosts"]) <= 3
|
||||
assert len(body["connections"]) <= 2
|
||||
assert body["stats"]["truncated"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manual_job_submit_backpressure_returns_429(client):
|
||||
old = settings.JOB_QUEUE_MAX_BACKLOG
|
||||
settings.JOB_QUEUE_MAX_BACKLOG = 0
|
||||
try:
|
||||
res = await client.post("/api/analysis/jobs/submit/triage", json={"params": {"dataset_id": "abc"}})
|
||||
assert res.status_code == 429
|
||||
finally:
|
||||
settings.JOB_QUEUE_MAX_BACKLOG = old
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_host_inventory_deferred_when_queue_backlogged(client):
|
||||
hunt_id = "deferred-hunt"
|
||||
inventory_cache.invalidate(hunt_id)
|
||||
inventory_cache.clear_building(hunt_id)
|
||||
|
||||
old = settings.JOB_QUEUE_MAX_BACKLOG
|
||||
settings.JOB_QUEUE_MAX_BACKLOG = 0
|
||||
try:
|
||||
res = await client.get(f"/api/network/host-inventory?hunt_id={hunt_id}")
|
||||
assert res.status_code == 202
|
||||
body = res.json()
|
||||
assert body["status"] == "deferred"
|
||||
finally:
|
||||
settings.JOB_QUEUE_MAX_BACKLOG = old
|
||||
203
backend/tests/test_new_features.py
Normal file
203
backend/tests/test_new_features.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Tests for new feature API routes: MITRE, Timeline, Playbooks, Saved Searches."""
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
|
||||
class TestMitreRoutes:
|
||||
"""Tests for /api/mitre endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mitre_coverage_empty(self, client):
|
||||
resp = await client.get("/api/mitre/coverage")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "tactics" in data
|
||||
assert "technique_count" in data
|
||||
assert data["technique_count"] == 0
|
||||
assert len(data["tactics"]) == 14 # 14 MITRE tactics
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mitre_coverage_with_hunt_filter(self, client):
|
||||
resp = await client.get("/api/mitre/coverage?hunt_id=nonexistent")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["technique_count"] == 0
|
||||
|
||||
|
||||
class TestTimelineRoutes:
|
||||
"""Tests for /api/timeline endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeline_hunt_not_found(self, client):
|
||||
resp = await client.get("/api/timeline/hunt/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeline_with_hunt(self, client):
|
||||
# Create a hunt first
|
||||
hunt_resp = await client.post("/api/hunts", json={"name": "Timeline Test"})
|
||||
assert hunt_resp.status_code in (200, 201)
|
||||
hunt_id = hunt_resp.json()["id"]
|
||||
|
||||
resp = await client.get(f"/api/timeline/hunt/{hunt_id}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["hunt_id"] == hunt_id
|
||||
assert "events" in data
|
||||
assert "datasets" in data
|
||||
|
||||
|
||||
class TestPlaybookRoutes:
|
||||
"""Tests for /api/playbooks endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_playbooks_empty(self, client):
|
||||
resp = await client.get("/api/playbooks")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["playbooks"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_templates(self, client):
|
||||
resp = await client.get("/api/playbooks/templates")
|
||||
assert resp.status_code == 200
|
||||
templates = resp.json()["templates"]
|
||||
assert len(templates) >= 2
|
||||
assert templates[0]["name"] == "Standard Threat Hunt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_playbook(self, client):
|
||||
resp = await client.post("/api/playbooks", json={
|
||||
"name": "My Investigation",
|
||||
"description": "Test playbook",
|
||||
"steps": [
|
||||
{"title": "Step 1", "description": "Upload data", "step_type": "upload", "target_route": "/upload"},
|
||||
{"title": "Step 2", "description": "Triage", "step_type": "analysis", "target_route": "/analysis"},
|
||||
],
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["name"] == "My Investigation"
|
||||
assert len(data["steps"]) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_playbook_crud(self, client):
|
||||
# Create
|
||||
resp = await client.post("/api/playbooks", json={
|
||||
"name": "CRUD Test",
|
||||
"steps": [{"title": "Do something"}],
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
pb_id = resp.json()["id"]
|
||||
|
||||
# Get
|
||||
resp = await client.get(f"/api/playbooks/{pb_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "CRUD Test"
|
||||
assert len(resp.json()["steps"]) == 1
|
||||
|
||||
# Update
|
||||
resp = await client.put(f"/api/playbooks/{pb_id}", json={"name": "Updated"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Delete
|
||||
resp = await client.delete(f"/api/playbooks/{pb_id}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_playbook_step_completion(self, client):
|
||||
# Create with step
|
||||
resp = await client.post("/api/playbooks", json={
|
||||
"name": "Step Test",
|
||||
"steps": [{"title": "Task 1"}],
|
||||
})
|
||||
pb_id = resp.json()["id"]
|
||||
|
||||
# Get to find step ID
|
||||
resp = await client.get(f"/api/playbooks/{pb_id}")
|
||||
steps = resp.json()["steps"]
|
||||
step_id = steps[0]["id"]
|
||||
assert steps[0]["is_completed"] is False
|
||||
|
||||
# Mark complete
|
||||
resp = await client.put(f"/api/playbooks/steps/{step_id}", json={"is_completed": True, "notes": "Done!"})
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["is_completed"] is True
|
||||
|
||||
|
||||
class TestSavedSearchRoutes:
|
||||
"""Tests for /api/searches endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_empty(self, client):
|
||||
resp = await client.get("/api/searches")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["searches"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_saved_search(self, client):
|
||||
resp = await client.post("/api/searches", json={
|
||||
"name": "Suspicious IPs",
|
||||
"search_type": "ioc_search",
|
||||
"query_params": {"ioc_value": "203.0.113"},
|
||||
})
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["name"] == "Suspicious IPs"
|
||||
assert data["search_type"] == "ioc_search"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_crud(self, client):
|
||||
# Create
|
||||
resp = await client.post("/api/searches", json={
|
||||
"name": "Test Query",
|
||||
"search_type": "keyword_scan",
|
||||
"query_params": {"theme": "malware"},
|
||||
})
|
||||
s_id = resp.json()["id"]
|
||||
|
||||
# Get
|
||||
resp = await client.get(f"/api/searches/{s_id}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Update
|
||||
resp = await client.put(f"/api/searches/{s_id}", json={"name": "Updated Query"})
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Run
|
||||
resp = await client.post(f"/api/searches/{s_id}/run")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "result_count" in data
|
||||
assert "delta" in data
|
||||
|
||||
# Delete
|
||||
resp = await client.delete(f"/api/searches/{s_id}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
|
||||
class TestStixExport:
|
||||
"""Tests for /api/export/stix endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stix_export_hunt_not_found(self, client):
|
||||
resp = await client.get("/api/export/stix/nonexistent-id")
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stix_export_empty_hunt(self, client):
|
||||
"""Export from a real hunt with no data returns valid but minimal bundle."""
|
||||
hunt_resp = await client.post("/api/hunts", json={"name": "STIX Test Hunt"})
|
||||
assert hunt_resp.status_code in (200, 201)
|
||||
hunt_id = hunt_resp.json()["id"]
|
||||
|
||||
resp = await client.get(f"/api/export/stix/{hunt_id}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["type"] == "bundle"
|
||||
assert data["objects"][0]["spec_version"] == "2.1" # spec_version is on objects, not bundle
|
||||
assert "objects" in data
|
||||
# At minimum should have the identity object
|
||||
types = [o["type"] for o in data["objects"]]
|
||||
assert "identity" in types
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user