"""AUP Keyword Scanner: searches dataset rows, hunts, annotations, and
messages for keyword matches.

Scanning is done in Python (not SQL LIKE on JSON columns) for portability
across SQLite / PostgreSQL and to provide per-cell match context.
"""

import asyncio
import logging
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.db.models import (
    KeywordTheme,
    DatasetRow,
    Dataset,
    Hunt,
    Annotation,
    Message,
)

logger = logging.getLogger(__name__)

# Number of dataset rows fetched per query while scanning.
BATCH_SIZE = 200

# Maximum number of characters of the matched text kept in a hit preview.
_PREVIEW_MAX_CHARS = 200

# Candidate column names, in priority order, compared case-insensitively.
_HOST_KEYS = (
    'hostname', 'host_name', 'host', 'computer_name', 'computer',
    'fqdn', 'client_id', 'agent_id', 'endpoint_id',
)
_USER_KEYS = (
    'username', 'user_name', 'user', 'account_name',
    'logged_in_user', 'samaccountname', 'sam_account_name',
)


def _infer_hostname_and_user(data: dict) -> tuple[str | None, str | None]:
    """Best-effort extraction of hostname and user from a dataset row.

    Column names are compared case-insensitively against ``_HOST_KEYS`` and
    ``_USER_KEYS``; the first candidate (in tuple order) that has a value
    other than ``None`` or ``''`` wins. Values are returned as ``str``.

    Args:
        data: Mapping of column name to cell value for one row.

    Returns:
        ``(hostname, username)``; either element is ``None`` when no
        candidate column holds a usable value.
    """
    if not data:
        return None, None

    # Build the lower-cased index once instead of rescanning the whole row
    # for every candidate key. ``setdefault`` keeps the first usable value
    # in row order when several columns differ only by case.
    usable: dict[str, object] = {}
    for actual_key, value in data.items():
        if value in (None, ''):
            continue
        usable.setdefault(str(actual_key).lower(), value)

    def pick(keys):
        for k in keys:
            if k in usable:
                return str(usable[k])
        return None

    return pick(_HOST_KEYS), pick(_USER_KEYS)


@dataclass
class ScanHit:
    """A single keyword match found in one scanned text field."""

    theme_name: str
    theme_color: str
    keyword: str
    source_type: str  # dataset_row | hunt | annotation | message
    source_id: str | int
    field: str
    matched_value: str
    row_index: int | None = None
    dataset_name: str | None = None
    hostname: str | None = None
    username: str | None = None


@dataclass
class ScanResult:
    """Accumulator for the hits and counters of one scan run."""

    total_hits: int = 0
    hits: list[ScanHit] = field(default_factory=list)
    themes_scanned: int = 0
    keywords_scanned: int = 0
    rows_scanned: int = 0


def _result_to_dict(result: ScanResult) -> dict:
    """Serialise a ``ScanResult`` (and its hits) into plain dicts."""
    return {
        "total_hits": result.total_hits,
        "hits": [h.__dict__ for h in result.hits],
        "themes_scanned": result.themes_scanned,
        "keywords_scanned": result.keywords_scanned,
        "rows_scanned": result.rows_scanned,
    }


@dataclass
class KeywordScanCacheEntry:
    """Cached scan result for one dataset, stamped with its build time."""

    dataset_id: str
    result: dict
    # ISO-8601 UTC timestamp taken when the entry is created.
    built_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())


class KeywordScanCache:
    """In-memory per-dataset cache for dataset-only keyword scans.

    This enables fast-path reads when users run AUP scans against datasets
    that were already scanned during upload pipeline processing.
    """

    def __init__(self):
        self._entries: dict[str, KeywordScanCacheEntry] = {}

    def put(self, dataset_id: str, result: dict):
        """Store ``result`` for ``dataset_id``, replacing any previous entry."""
        self._entries[dataset_id] = KeywordScanCacheEntry(dataset_id=dataset_id, result=result)

    def get(self, dataset_id: str) -> KeywordScanCacheEntry | None:
        """Return the cached entry for ``dataset_id``, or ``None`` if absent."""
        return self._entries.get(dataset_id)

    def invalidate_dataset(self, dataset_id: str):
        """Drop the entry for ``dataset_id``; a missing entry is not an error."""
        self._entries.pop(dataset_id, None)

    def clear(self):
        """Remove every cached entry."""
        self._entries.clear()


keyword_scan_cache = KeywordScanCache()


class KeywordScanner:
    """Scans multiple data sources for keyword/regex matches."""

    def __init__(self, db: AsyncSession):
        self.db = db

    # Public API

    async def scan(
        self,
        dataset_ids: list[str] | None = None,
        theme_ids: list[str] | None = None,
        scan_hunts: bool = False,
        scan_annotations: bool = False,
        scan_messages: bool = False,
    ) -> dict:
        """Run a full AUP scan and return dict matching ScanResponse.

        Args:
            dataset_ids: Datasets to scan. ``None`` or an empty list applies
                no filter, i.e. every dataset is scanned.
            theme_ids: Enabled themes to use. ``None`` or an empty list
                applies no filter, i.e. every enabled theme is used.
            scan_hunts: Also scan hunt names and descriptions.
            scan_annotations: Also scan annotation text.
            scan_messages: Also scan user messages.

        Returns:
            Dict with ``total_hits``, ``hits``, ``themes_scanned``,
            ``keywords_scanned`` and ``rows_scanned``. When no enabled theme
            is found the counters are zero and ``hits`` is empty.
        """
        # Load themes + keywords
        themes = await self._load_themes(theme_ids)
        if not themes:
            return _result_to_dict(ScanResult())

        # Pre-compile patterns per theme
        patterns = self._compile_patterns(themes)
        result = ScanResult(
            themes_scanned=len(themes),
            keywords_scanned=sum(len(kws) for kws in patterns.values()),
        )

        # Scan dataset rows
        await self._scan_datasets(patterns, result, dataset_ids)

        # Scan hunts
        if scan_hunts:
            await self._scan_hunts(patterns, result)

        # Scan annotations
        if scan_annotations:
            await self._scan_annotations(patterns, result)

        # Scan messages
        if scan_messages:
            await self._scan_messages(patterns, result)

        result.total_hits = len(result.hits)
        return _result_to_dict(result)

    # Internal

    async def _load_themes(self, theme_ids: list[str] | None) -> list[KeywordTheme]:
        """Load enabled themes, optionally restricted to ``theme_ids``."""
        q = select(KeywordTheme).where(KeywordTheme.enabled == True)  # noqa: E712
        if theme_ids:
            q = q.where(KeywordTheme.id.in_(theme_ids))
        result = await self.db.execute(q)
        return list(result.scalars().all())

    def _compile_patterns(
        self, themes: list[KeywordTheme]
    ) -> dict[tuple[str, str, str], list[tuple[str, re.Pattern]]]:
        """Returns {(theme_id, theme_name, theme_color): [(keyword_value, compiled_pattern), ...]}

        Plain keywords are escaped and matched literally; keywords flagged
        ``is_regex`` are compiled as-is. All patterns are case-insensitive.
        Keywords that are empty or fail to compile are logged and skipped.
        """
        # NOTE(review): reads ``theme.keywords`` synchronously; presumably the
        # relationship is eagerly loaded for the async session - confirm.
        patterns: dict[tuple[str, str, str], list[tuple[str, re.Pattern]]] = {}
        for theme in themes:
            key = (theme.id, theme.name, theme.color)
            compiled = []
            for kw in theme.keywords:
                # An empty pattern matches every non-empty cell and would
                # flood the results; a None value would raise TypeError
                # (not re.error) and abort the whole scan.
                if not kw.value:
                    logger.warning("Empty keyword in theme '%s', skipping", theme.name)
                    continue
                try:
                    if kw.is_regex:
                        pat = re.compile(kw.value, re.IGNORECASE)
                    else:
                        pat = re.compile(re.escape(kw.value), re.IGNORECASE)
                    compiled.append((kw.value, pat))
                except re.error:
                    logger.warning("Invalid regex pattern '%s' in theme '%s', skipping",
                                   kw.value, theme.name)
            patterns[key] = compiled
        return patterns

    def _match_text(
        self,
        text: str,
        patterns: dict,
        source_type: str,
        source_id: str | int,
        field_name: str,
        hits: list[ScanHit],
        row_index: int | None = None,
        dataset_name: str | None = None,
        hostname: str | None = None,
        username: str | None = None,
    ) -> None:
        """Check text against all compiled patterns, append hits.

        One ``ScanHit`` is appended to ``hits`` for every keyword whose
        pattern is found in ``text``. Empty or ``None`` text is ignored.
        ``matched_value`` holds the first ``_PREVIEW_MAX_CHARS`` characters
        of ``text``, followed by an ellipsis when the text was truncated.
        """
        if not text:
            return

        # The preview depends only on the text, so build it at most once.
        preview: str | None = None
        for (_theme_id, theme_name, theme_color), keyword_patterns in patterns.items():
            for kw_value, pat in keyword_patterns:
                if not pat.search(text):
                    continue
                if preview is None:
                    preview = text[:_PREVIEW_MAX_CHARS]
                    if len(text) > _PREVIEW_MAX_CHARS:
                        preview += "\u2026"
                hits.append(ScanHit(
                    theme_name=theme_name,
                    theme_color=theme_color,
                    keyword=kw_value,
                    source_type=source_type,
                    source_id=source_id,
                    field=field_name,
                    matched_value=preview,
                    row_index=row_index,
                    dataset_name=dataset_name,
                    hostname=hostname,
                    username=username,
                ))

    async def _scan_datasets(
        self, patterns: dict, result: ScanResult, dataset_ids: list[str] | None
    ) -> None:
        """Scan dataset rows in batches using keyset pagination (no OFFSET).

        Rows are fetched ordered by ``DatasetRow.id`` in batches of at most
        ``BATCH_SIZE``. When ``settings.SCANNER_MAX_ROWS_PER_SCAN`` is a
        positive number, no more than that many rows are scanned in total
        across all datasets; a warning is logged when the budget stops the
        scan. A value of zero (or below) disables the budget.
        """
        ds_q = select(Dataset.id, Dataset.name)
        if dataset_ids:
            ds_q = ds_q.where(Dataset.id.in_(dataset_ids))
        ds_result = await self.db.execute(ds_q)
        ds_map = {r[0]: r[1] for r in ds_result.fetchall()}

        if not ds_map:
            return

        max_rows = max(0, int(settings.SCANNER_MAX_ROWS_PER_SCAN))
        budget_reached = False

        for ds_id, ds_name in ds_map.items():
            if max_rows and result.rows_scanned >= max_rows:
                budget_reached = True
                break

            # Keyset cursor; assumes DatasetRow.id is a positive integer.
            last_id = 0
            while True:
                # Cap the batch to the remaining budget so the configured
                # maximum is never exceeded by a partially needed batch.
                batch_limit = BATCH_SIZE
                if max_rows:
                    remaining = max_rows - result.rows_scanned
                    if remaining <= 0:
                        budget_reached = True
                        break
                    batch_limit = min(BATCH_SIZE, remaining)

                rows_result = await self.db.execute(
                    select(DatasetRow)
                    .where(DatasetRow.dataset_id == ds_id)
                    .where(DatasetRow.id > last_id)
                    .order_by(DatasetRow.id)
                    .limit(batch_limit)
                )
                rows = rows_result.scalars().all()
                if not rows:
                    break

                for row in rows:
                    result.rows_scanned += 1
                    data = row.data or {}
                    hostname, username = _infer_hostname_and_user(data)
                    for col_name, cell_value in data.items():
                        if cell_value is None:
                            continue
                        text = str(cell_value)
                        self._match_text(
                            text,
                            patterns,
                            "dataset_row",
                            row.id,
                            col_name,
                            result.hits,
                            row_index=row.row_index,
                            dataset_name=ds_name,
                            hostname=hostname,
                            username=username,
                        )

                last_id = rows[-1].id
                # Yield to the event loop between batches so a long scan
                # does not starve other tasks.
                await asyncio.sleep(0)
                # A short batch means this dataset has no further rows.
                if len(rows) < batch_limit:
                    break

            if budget_reached:
                break

        if budget_reached:
            logger.warning(
                "AUP scan row budget reached (%d rows). Returning partial results.",
                result.rows_scanned,
            )

    async def _scan_hunts(self, patterns: dict, result: ScanResult) -> None:
        """Scan hunt names and descriptions."""
        hunts_result = await self.db.execute(select(Hunt))
        for hunt in hunts_result.scalars().all():
            # _match_text ignores empty/None text, so no extra guard is needed.
            self._match_text(hunt.name, patterns, "hunt", hunt.id, "name", result.hits)
            self._match_text(hunt.description, patterns, "hunt", hunt.id, "description", result.hits)

    async def _scan_annotations(self, patterns: dict, result: ScanResult) -> None:
        """Scan annotation text."""
        ann_result = await self.db.execute(select(Annotation))
        for ann in ann_result.scalars().all():
            self._match_text(ann.text, patterns, "annotation", ann.id, "text", result.hits)

    async def _scan_messages(self, patterns: dict, result: ScanResult) -> None:
        """Scan conversation messages (user messages only)."""
        msg_result = await self.db.execute(
            select(Message).where(Message.role == "user")
        )
        for msg in msg_result.scalars().all():
            self._match_text(msg.content, patterns, "message", msg.id, "content", result.hits)