Files
ThreatHunt/backend/app/api/routes/datasets.py

428 lines
13 KiB
Python

"""API routes for dataset upload, listing, and management."""
import logging
import os
from pathlib import Path
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.db import get_db
from app.db.models import ProcessingTask
from app.db.repositories.datasets import DatasetRepository
from app.services.csv_parser import parse_csv_bytes, infer_column_types
from app.services.normalizer import (
normalize_columns,
normalize_rows,
detect_ioc_columns,
detect_time_range,
)
from app.services.artifact_classifier import classify_artifact, get_artifact_category
logger = logging.getLogger(__name__)
from app.services.job_queue import job_queue, JobType
from app.services.host_inventory import inventory_cache
from app.services.scanner import keyword_scan_cache
router = APIRouter(prefix="/api/datasets", tags=["datasets"])
ALLOWED_EXTENSIONS = {".csv", ".tsv", ".txt"}
# -- Response models --
class DatasetSummary(BaseModel):
id: str
name: str
filename: str
source_tool: str | None = None
row_count: int
column_schema: dict | None = None
normalized_columns: dict | None = None
ioc_columns: dict | None = None
file_size_bytes: int
encoding: str | None = None
delimiter: str | None = None
time_range_start: str | None = None
time_range_end: str | None = None
artifact_type: str | None = None
processing_status: str | None = None
hunt_id: str | None = None
created_at: str
class DatasetListResponse(BaseModel):
datasets: list[DatasetSummary]
total: int
class RowsResponse(BaseModel):
rows: list[dict]
total: int
offset: int
limit: int
class UploadResponse(BaseModel):
id: str
name: str
row_count: int
columns: list[str]
column_types: dict
normalized_columns: dict
ioc_columns: dict
artifact_type: str | None = None
processing_status: str
jobs_queued: list[str]
message: str
# -- Routes --
@router.post(
"/upload",
response_model=UploadResponse,
summary="Upload a CSV dataset",
description="Upload a CSV/TSV file for analysis. The file is parsed, columns normalized, "
"IOCs auto-detected, artifact type classified, and all processing jobs queued automatically.",
)
async def upload_dataset(
file: UploadFile = File(...),
name: str | None = Query(None, description="Display name for the dataset"),
source_tool: str | None = Query(None, description="Source tool (e.g., velociraptor)"),
hunt_id: str | None = Query(None, description="Hunt ID to associate with"),
db: AsyncSession = Depends(get_db),
):
"""Upload and parse a CSV dataset, then trigger full processing pipeline."""
# Validate file
if not file.filename:
raise HTTPException(status_code=400, detail="No filename provided")
ext = Path(file.filename).suffix.lower()
if ext not in ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=400,
detail=f"File type '{ext}' not allowed. Accepted: {', '.join(ALLOWED_EXTENSIONS)}",
)
# Read file bytes
raw_bytes = await file.read()
if len(raw_bytes) == 0:
raise HTTPException(status_code=400, detail="File is empty")
if len(raw_bytes) > settings.max_upload_bytes:
raise HTTPException(
status_code=413,
detail=f"File too large. Max size: {settings.MAX_UPLOAD_SIZE_MB} MB",
)
# Parse CSV
try:
rows, metadata = parse_csv_bytes(raw_bytes)
except Exception as e:
logger.error(f"CSV parse error: {e}")
raise HTTPException(
status_code=422,
detail=f"Failed to parse CSV: {str(e)}. Check encoding and format.",
)
if not rows:
raise HTTPException(status_code=422, detail="CSV file contains no data rows")
columns: list[str] = metadata["columns"]
column_types: dict = metadata["column_types"]
# Normalize columns
column_mapping = normalize_columns(columns)
normalized = normalize_rows(rows, column_mapping)
# Detect IOCs
ioc_columns = detect_ioc_columns(columns, column_types, column_mapping)
# Detect time range
time_start, time_end = detect_time_range(rows, column_mapping)
# Classify artifact type from column headers
artifact_type = classify_artifact(columns)
artifact_category = get_artifact_category(artifact_type)
logger.info(f"Artifact classification: {artifact_type} (category: {artifact_category})")
# Store in DB with processing_status = "processing"
repo = DatasetRepository(db)
dataset = await repo.create_dataset(
name=name or Path(file.filename).stem,
filename=file.filename,
source_tool=source_tool,
row_count=len(rows),
column_schema=column_types,
normalized_columns=column_mapping,
ioc_columns=ioc_columns,
file_size_bytes=len(raw_bytes),
encoding=metadata["encoding"],
delimiter=metadata["delimiter"],
time_range_start=time_start,
time_range_end=time_end,
hunt_id=hunt_id,
artifact_type=artifact_type,
processing_status="processing",
)
await repo.bulk_insert_rows(
dataset_id=dataset.id,
rows=rows,
normalized_rows=normalized,
)
logger.info(
f"Uploaded dataset '{dataset.name}': {len(rows)} rows, "
f"{len(columns)} columns, {len(ioc_columns)} IOC columns, "
f"artifact={artifact_type}"
)
# -- Queue full processing pipeline --
jobs_queued = []
task_rows: list[ProcessingTask] = []
# 1. AI Triage (chains to HOST_PROFILE automatically on completion)
triage_job = job_queue.submit(JobType.TRIAGE, dataset_id=dataset.id)
jobs_queued.append("triage")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=triage_job.id,
stage="triage",
status="queued",
progress=0.0,
message="Queued",
))
# 2. Anomaly detection (embedding-based outlier detection)
anomaly_job = job_queue.submit(JobType.ANOMALY, dataset_id=dataset.id)
jobs_queued.append("anomaly")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=anomaly_job.id,
stage="anomaly",
status="queued",
progress=0.0,
message="Queued",
))
# 3. AUP keyword scan
kw_job = job_queue.submit(JobType.KEYWORD_SCAN, dataset_id=dataset.id)
jobs_queued.append("keyword_scan")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=kw_job.id,
stage="keyword_scan",
status="queued",
progress=0.0,
message="Queued",
))
# 4. IOC extraction
ioc_job = job_queue.submit(JobType.IOC_EXTRACT, dataset_id=dataset.id)
jobs_queued.append("ioc_extract")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=ioc_job.id,
stage="ioc_extract",
status="queued",
progress=0.0,
message="Queued",
))
# 5. Host inventory (network map) - requires hunt_id
if hunt_id:
inventory_cache.invalidate(hunt_id)
inv_job = job_queue.submit(JobType.HOST_INVENTORY, hunt_id=hunt_id)
jobs_queued.append("host_inventory")
task_rows.append(ProcessingTask(
hunt_id=hunt_id,
dataset_id=dataset.id,
job_id=inv_job.id,
stage="host_inventory",
status="queued",
progress=0.0,
message="Queued",
))
if task_rows:
db.add_all(task_rows)
await db.flush()
logger.info(f"Queued {len(jobs_queued)} processing jobs for dataset {dataset.id}: {jobs_queued}")
return UploadResponse(
id=dataset.id,
name=dataset.name,
row_count=len(rows),
columns=columns,
column_types=column_types,
normalized_columns=column_mapping,
ioc_columns=ioc_columns,
artifact_type=artifact_type,
processing_status="processing",
jobs_queued=jobs_queued,
message=f"Successfully uploaded {len(rows)} rows. {len(jobs_queued)} processing jobs queued.",
)
@router.get(
"",
response_model=DatasetListResponse,
summary="List datasets",
)
async def list_datasets(
hunt_id: str | None = Query(None),
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0),
db: AsyncSession = Depends(get_db),
):
repo = DatasetRepository(db)
datasets = await repo.list_datasets(hunt_id=hunt_id, limit=limit, offset=offset)
total = await repo.count_datasets(hunt_id=hunt_id)
return DatasetListResponse(
datasets=[
DatasetSummary(
id=ds.id,
name=ds.name,
filename=ds.filename,
source_tool=ds.source_tool,
row_count=ds.row_count,
column_schema=ds.column_schema,
normalized_columns=ds.normalized_columns,
ioc_columns=ds.ioc_columns,
file_size_bytes=ds.file_size_bytes,
encoding=ds.encoding,
delimiter=ds.delimiter,
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
artifact_type=ds.artifact_type,
processing_status=ds.processing_status,
hunt_id=ds.hunt_id,
created_at=ds.created_at.isoformat(),
)
for ds in datasets
],
total=total,
)
@router.get(
"/{dataset_id}",
response_model=DatasetSummary,
summary="Get dataset details",
)
async def get_dataset(
dataset_id: str,
db: AsyncSession = Depends(get_db),
):
repo = DatasetRepository(db)
ds = await repo.get_dataset(dataset_id)
if not ds:
raise HTTPException(status_code=404, detail="Dataset not found")
return DatasetSummary(
id=ds.id,
name=ds.name,
filename=ds.filename,
source_tool=ds.source_tool,
row_count=ds.row_count,
column_schema=ds.column_schema,
normalized_columns=ds.normalized_columns,
ioc_columns=ds.ioc_columns,
file_size_bytes=ds.file_size_bytes,
encoding=ds.encoding,
delimiter=ds.delimiter,
time_range_start=ds.time_range_start.isoformat() if ds.time_range_start else None,
time_range_end=ds.time_range_end.isoformat() if ds.time_range_end else None,
artifact_type=ds.artifact_type,
processing_status=ds.processing_status,
hunt_id=ds.hunt_id,
created_at=ds.created_at.isoformat(),
)
@router.get(
"/{dataset_id}/rows",
response_model=RowsResponse,
summary="Get dataset rows",
)
async def get_dataset_rows(
dataset_id: str,
limit: int = Query(1000, ge=1, le=10000),
offset: int = Query(0, ge=0),
normalized: bool = Query(False, description="Return normalized column names"),
db: AsyncSession = Depends(get_db),
):
repo = DatasetRepository(db)
ds = await repo.get_dataset(dataset_id)
if not ds:
raise HTTPException(status_code=404, detail="Dataset not found")
rows = await repo.get_rows(dataset_id, limit=limit, offset=offset)
total = await repo.count_rows(dataset_id)
return RowsResponse(
rows=[
(r.normalized_data if normalized and r.normalized_data else r.data)
for r in rows
],
total=total,
offset=offset,
limit=limit,
)
@router.delete(
"/{dataset_id}",
summary="Delete a dataset",
)
async def delete_dataset(
dataset_id: str,
db: AsyncSession = Depends(get_db),
):
repo = DatasetRepository(db)
deleted = await repo.delete_dataset(dataset_id)
if not deleted:
raise HTTPException(status_code=404, detail="Dataset not found")
keyword_scan_cache.invalidate_dataset(dataset_id)
return {"message": "Dataset deleted", "id": dataset_id}
@router.post(
"/rescan-ioc",
summary="Re-scan IOC columns for all datasets",
)
async def rescan_ioc_columns(
db: AsyncSession = Depends(get_db),
):
"""Re-run detect_ioc_columns on every dataset using current detection logic."""
repo = DatasetRepository(db)
all_ds = await repo.list_datasets(limit=10000)
updated = 0
for ds in all_ds:
columns = list((ds.column_schema or {}).keys())
if not columns:
continue
new_ioc = detect_ioc_columns(
columns,
ds.column_schema or {},
ds.normalized_columns or {},
)
if new_ioc != (ds.ioc_columns or {}):
ds.ioc_columns = new_ioc
updated += 1
await db.commit()
return {"message": f"Rescanned {len(all_ds)} datasets, updated {updated}"}