# mirror of https://github.com/mblanke/ThreatHunt.git (synced 2026-03-01)
"""Network Picture — deduplicated host inventory built from dataset rows.
|
|
|
|
Scans all datasets in a hunt, extracts host-identifying fields from
|
|
normalized data, and groups by hostname (or src_ip fallback) to produce
|
|
a clean one-row-per-host inventory. Uses sets for deduplication —
|
|
if an IP appears 900 times, it shows once.
|
|
"""
import logging
from datetime import datetime
from typing import Any, Sequence

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import Dataset, DatasetRow

logger = logging.getLogger(__name__)
|
|
|
|
# Canonical column names we extract per row
|
|
_HOST_KEYS = ("hostname",)
|
|
_IP_KEYS = ("src_ip", "dst_ip", "ip_address")
|
|
_USER_KEYS = ("username",)
|
|
_OS_KEYS = ("os",)
|
|
_MAC_KEYS = ("mac_address",)
|
|
_PORT_SRC_KEYS = ("src_port",)
|
|
_PORT_DST_KEYS = ("dst_port",)
|
|
_PROTO_KEYS = ("protocol",)
|
|
_STATE_KEYS = ("connection_state",)
|
|
_TS_KEYS = ("timestamp",)
|
|
|
|
# Junk values to skip
|
|
_JUNK = frozenset({"", "-", "0.0.0.0", "::", "0", "127.0.0.1", "::1", "localhost", "unknown", "n/a", "none", "null"})
|
|
|
|
ROW_BATCH = 1000 # rows fetched per DB query
|
|
MAX_HOSTS = 1000 # hard cap on returned hosts
|
|
|
|
|
|
def _clean(val: Any) -> str:
    """Normalise a cell value to a clean string or empty.

    None becomes ""; non-strings are stringified.  Any value whose
    lower-cased form appears in _JUNK is treated as empty.
    """
    if val is None:
        text = ""
    elif isinstance(val, str):
        text = val.strip()
    else:
        text = str(val).strip()
    if text.lower() in _JUNK:
        return ""
    return text

def _try_parse_ts(val: str) -> datetime | None:
|
|
"""Best-effort timestamp parse (subset of common formats)."""
|
|
for fmt in (
|
|
"%Y-%m-%dT%H:%M:%S.%fZ",
|
|
"%Y-%m-%dT%H:%M:%SZ",
|
|
"%Y-%m-%dT%H:%M:%S.%f",
|
|
"%Y-%m-%dT%H:%M:%S",
|
|
"%Y-%m-%d %H:%M:%S.%f",
|
|
"%Y-%m-%d %H:%M:%S",
|
|
):
|
|
try:
|
|
return datetime.strptime(val.strip(), fmt)
|
|
except ValueError:
|
|
continue
|
|
return None
|
|
|
|
|
|
class _HostBucket:
|
|
"""Mutable accumulator for a single host."""
|
|
|
|
__slots__ = (
|
|
"hostname", "ips", "users", "os_versions", "mac_addresses",
|
|
"protocols", "open_ports", "remote_targets", "datasets",
|
|
"connection_count", "first_seen", "last_seen",
|
|
)
|
|
|
|
def __init__(self, hostname: str):
|
|
self.hostname = hostname
|
|
self.ips: set[str] = set()
|
|
self.users: set[str] = set()
|
|
self.os_versions: set[str] = set()
|
|
self.mac_addresses: set[str] = set()
|
|
self.protocols: set[str] = set()
|
|
self.open_ports: set[str] = set()
|
|
self.remote_targets: set[str] = set()
|
|
self.datasets: set[str] = set()
|
|
self.connection_count: int = 0
|
|
self.first_seen: datetime | None = None
|
|
self.last_seen: datetime | None = None
|
|
|
|
def ingest(self, row: dict[str, Any], ds_name: str) -> None:
|
|
"""Merge one normalised row into this bucket."""
|
|
self.connection_count += 1
|
|
self.datasets.add(ds_name)
|
|
|
|
for k in _IP_KEYS:
|
|
v = _clean(row.get(k))
|
|
if v:
|
|
self.ips.add(v)
|
|
|
|
for k in _USER_KEYS:
|
|
v = _clean(row.get(k))
|
|
if v:
|
|
self.users.add(v)
|
|
|
|
for k in _OS_KEYS:
|
|
v = _clean(row.get(k))
|
|
if v:
|
|
self.os_versions.add(v)
|
|
|
|
for k in _MAC_KEYS:
|
|
v = _clean(row.get(k))
|
|
if v:
|
|
self.mac_addresses.add(v)
|
|
|
|
for k in _PROTO_KEYS:
|
|
v = _clean(row.get(k))
|
|
if v:
|
|
self.protocols.add(v.upper())
|
|
|
|
# Open ports = local (src) ports
|
|
for k in _PORT_SRC_KEYS:
|
|
v = _clean(row.get(k))
|
|
if v and v != "0":
|
|
self.open_ports.add(v)
|
|
|
|
# Remote targets = dst IPs
|
|
dst = _clean(row.get("dst_ip"))
|
|
if dst:
|
|
self.remote_targets.add(dst)
|
|
|
|
# Timestamps
|
|
for k in _TS_KEYS:
|
|
v = _clean(row.get(k))
|
|
if v:
|
|
ts = _try_parse_ts(v)
|
|
if ts:
|
|
if self.first_seen is None or ts < self.first_seen:
|
|
self.first_seen = ts
|
|
if self.last_seen is None or ts > self.last_seen:
|
|
self.last_seen = ts
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
return {
|
|
"hostname": self.hostname,
|
|
"ips": sorted(self.ips),
|
|
"users": sorted(self.users),
|
|
"os": sorted(self.os_versions),
|
|
"mac_addresses": sorted(self.mac_addresses),
|
|
"protocols": sorted(self.protocols),
|
|
"open_ports": sorted(self.open_ports, key=lambda p: int(p) if p.isdigit() else 0),
|
|
"remote_targets": sorted(self.remote_targets),
|
|
"datasets": sorted(self.datasets),
|
|
"connection_count": self.connection_count,
|
|
"first_seen": self.first_seen.isoformat() if self.first_seen else None,
|
|
"last_seen": self.last_seen.isoformat() if self.last_seen else None,
|
|
}
|
|
|
|
|
|
async def build_network_picture(
    db: AsyncSession,
    hunt_id: str,
) -> dict[str, Any]:
    """Build a deduplicated host inventory for all datasets in a hunt.

    Returns:
        {
            "hosts": [ {hostname, ips[], users[], os[], ...}, ... ],
            "summary": { total_hosts, total_connections, total_unique_ips, datasets_scanned }
        }

    Note: summary counts reflect the (at most MAX_HOSTS) hosts returned,
    not every bucket seen.
    """
    # 1. All datasets belonging to this hunt, oldest first.
    ds_query = select(Dataset).where(Dataset.hunt_id == hunt_id).order_by(Dataset.created_at)
    datasets: Sequence[Dataset] = (await db.execute(ds_query)).scalars().all()

    if not datasets:
        return {
            "hosts": [],
            "summary": {
                "total_hosts": 0,
                "total_connections": 0,
                "total_unique_ips": 0,
                "datasets_scanned": 0,
            },
        }

    # 2. Stream rows batch-by-batch and fold them into per-host buckets,
    #    keyed by uppercased hostname (or IP fallback).
    buckets: dict[str, _HostBucket] = {}

    for ds in datasets:
        label = ds.name or ds.filename
        offset = 0
        while True:
            page_query = (
                select(DatasetRow)
                .where(DatasetRow.dataset_id == ds.id)
                .order_by(DatasetRow.row_index)
                .limit(ROW_BATCH)
                .offset(offset)
            )
            page: Sequence[DatasetRow] = (await db.execute(page_query)).scalars().all()
            if not page:
                break

            for record in page:
                norm = record.normalized_data or record.data or {}

                # Grouping key: hostname preferred, else src_ip/ip_address.
                key_val = ""
                for field in (*_HOST_KEYS, "src_ip", "ip_address"):
                    key_val = _clean(norm.get(field))
                    if key_val:
                        break

                if not key_val:
                    continue  # row carries no host identifier

                bucket = buckets.setdefault(key_val.upper(), _HostBucket(key_val))
                bucket.ingest(norm, label)

            offset += ROW_BATCH

    # 3. Rank by connection count (descending) and cap the result size.
    ranked = sorted(
        buckets.values(), key=lambda b: b.connection_count, reverse=True
    )[:MAX_HOSTS]

    # 4. Summary stats over the returned hosts.
    unique_ips: set[str] = set()
    connection_total = 0
    for bucket in ranked:
        unique_ips |= bucket.ips
        connection_total += bucket.connection_count

    return {
        "hosts": [bucket.to_dict() for bucket in ranked],
        "summary": {
            "total_hosts": len(ranked),
            "total_connections": connection_total,
            "total_unique_ips": len(unique_ips),
            "datasets_scanned": len(datasets),
        },
    }