mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 05:50:21 -05:00
- Rewrote NetworkMap to use deduplicated host inventory (163 hosts from 394K rows) - New host_inventory.py service: scans datasets, groups by FQDN/ClientId, extracts IPs/users/OS - New /api/network/host-inventory endpoint - Added AnalysisDashboard with 6 tabs (IOC, anomaly, host profile, query, triage, reports) - Added 16 analysis API endpoints with job queue and load balancer - Added 4 AI/analysis ORM models (ProcessingJob, AnalysisResult, HostProfile, IOCEntry) - Filters system accounts (DWM-*, UMFD-*, LOCAL/NETWORK SERVICE) - Infers OS from hostname patterns (W10-* -> Windows 10) - Canvas 2D force-directed graph with host/external-IP node types - Click popover shows hostname, FQDN, IPs, OS, users, datasets, connections
193 lines
6.1 KiB
Python
193 lines
6.1 KiB
Python
"""Smart load balancer for Wile & Roadrunner LLM nodes.
|
|
|
|
Tracks active jobs per node, health status, and routes new work
|
|
to the least-busy healthy node. Periodically pings both nodes
|
|
to maintain an up-to-date health map.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Optional
|
|
|
|
from app.config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class NodeId(str, Enum):
    """Identifier for one of the two LLM compute nodes.

    The ``str`` mixin keeps member values JSON- and log-friendly
    (``NodeId.WILE == "wile"``).
    """

    WILE = "wile"
    ROADRUNNER = "roadrunner"
|
|
|
|
|
|
class WorkloadTier(str, Enum):
    """Coarse classification of an LLM job, used to pick a capable node."""

    HEAVY = "heavy"      # 70B models, deep analysis, reports
    FAST = "fast"        # 7-14B models, triage, quick queries
    EMBEDDING = "embed"  # bge-m3 embeddings
    ANY = "any"          # no preference; any node may take it
|
|
|
|
|
|
@dataclass
class NodeStatus:
    """Mutable per-node bookkeeping: health, load, and latency statistics.

    One instance exists per LLM node; the load balancer mutates it as jobs
    start, finish, or fail.
    """

    # Size of the rolling window used for avg_latency_ms. Plain (unannotated)
    # class attribute, so the dataclass machinery does not treat it as a field.
    LATENCY_WINDOW = 50

    node_id: NodeId
    url: str
    healthy: bool = True            # last known health-check result
    last_check: float = 0.0         # time.time() of the most recent health probe
    active_jobs: int = 0            # jobs currently dispatched and not yet released
    total_completed: int = 0        # lifetime successful completions
    total_errors: int = 0           # lifetime failed jobs
    avg_latency_ms: float = 0.0     # rolling mean over the last LATENCY_WINDOW completions
    _latencies: list[float] = field(default_factory=list)

    def record_completion(self, latency_ms: float) -> None:
        """Mark one job finished successfully and fold its latency into the average."""
        # max(0, ...) guards against a double-release driving the count negative.
        self.active_jobs = max(0, self.active_jobs - 1)
        self.total_completed += 1
        self._latencies.append(latency_ms)
        # Keep only the most recent LATENCY_WINDOW samples.
        if len(self._latencies) > self.LATENCY_WINDOW:
            self._latencies = self._latencies[-self.LATENCY_WINDOW:]
        self.avg_latency_ms = sum(self._latencies) / len(self._latencies)

    def record_error(self) -> None:
        """Mark one job finished with an error; latency is not recorded."""
        self.active_jobs = max(0, self.active_jobs - 1)
        self.total_errors += 1

    def record_start(self) -> None:
        """Mark one job as dispatched to this node."""
        self.active_jobs += 1
|
|
|
|
|
|
class LoadBalancer:
    """Routes LLM work to the least-busy healthy node.

    Node capabilities:
    - Wile: Heavy models (70B), code models (32B)
    - Roadrunner: Fast models (7-14B), embeddings (bge-m3), vision
    """

    # Which nodes can handle which tiers. Lists are in preference order;
    # select_node() breaks active-job ties in favor of earlier entries.
    TIER_NODES = {
        WorkloadTier.HEAVY: [NodeId.WILE],
        WorkloadTier.FAST: [NodeId.ROADRUNNER, NodeId.WILE],
        WorkloadTier.EMBEDDING: [NodeId.ROADRUNNER],
        WorkloadTier.ANY: [NodeId.ROADRUNNER, NodeId.WILE],
    }

    def __init__(self):
        self._nodes: dict[NodeId, NodeStatus] = {
            NodeId.WILE: NodeStatus(
                node_id=NodeId.WILE,
                url=f"http://{settings.WILE_HOST}:{settings.WILE_OLLAMA_PORT}",
            ),
            NodeId.ROADRUNNER: NodeStatus(
                node_id=NodeId.ROADRUNNER,
                url=f"http://{settings.ROADRUNNER_HOST}:{settings.ROADRUNNER_OLLAMA_PORT}",
            ),
        }
        # NOTE(review): this lock is never acquired anywhere in this class --
        # select/acquire/release are synchronous and presumably run on a single
        # event-loop thread. Kept for interface compatibility; confirm whether
        # it can be removed or should guard future async critical sections.
        self._lock = asyncio.Lock()
        self._health_task: Optional[asyncio.Task] = None

    async def start_health_loop(self, interval: float = 30.0):
        """Start the background health-check loop (idempotent).

        :param interval: seconds between health probes.
        """
        if self._health_task and not self._health_task.done():
            return  # already running
        self._health_task = asyncio.create_task(self._health_loop(interval))
        logger.info("Load balancer health loop started (%.0fs interval)", interval)

    async def stop_health_loop(self):
        """Cancel the health-check loop and wait for it to finish."""
        if self._health_task:
            self._health_task.cancel()
            try:
                await self._health_task
            except asyncio.CancelledError:
                pass  # expected: the task exits via cancellation
            self._health_task = None

    async def _health_loop(self, interval: float):
        """Probe node health every *interval* seconds until cancelled."""
        while True:
            try:
                await self.check_health()
            except Exception as e:
                # Top-level boundary for the background task: log and keep going.
                logger.warning("Health check error: %s", e)
            await asyncio.sleep(interval)

    async def check_health(self):
        """Ping both nodes and update status."""
        # Local import keeps httpx out of module import time.
        import httpx

        async with httpx.AsyncClient(timeout=5) as client:
            for nid, status in self._nodes.items():
                try:
                    resp = await client.get(f"{status.url}/api/tags")
                    status.healthy = resp.status_code == 200
                except Exception:
                    # Any transport error counts as the node being down.
                    status.healthy = False
                status.last_check = time.time()
                # Lazy %-style args: no string formatting unless DEBUG is on.
                logger.debug(
                    "Health: %s = %s (active=%d)",
                    nid.value,
                    "OK" if status.healthy else "DOWN",
                    status.active_jobs,
                )

    def select_node(self, tier: WorkloadTier = WorkloadTier.ANY) -> NodeId:
        """Select the best node for a workload tier.

        Strategy: among healthy nodes that support the tier,
        pick the one with fewest active jobs.
        Falls back to any node if none healthy.
        """
        candidates = self.TIER_NODES.get(tier, [NodeId.ROADRUNNER, NodeId.WILE])

        # Filter to healthy candidates.
        healthy = [nid for nid in candidates if self._nodes[nid].healthy]

        if not healthy:
            # Better to attempt a possibly-down node than to fail outright.
            logger.warning("No healthy nodes for tier %s, using first candidate", tier.value)
            healthy = candidates

        # Least busy wins; min() preserves list order on ties (preference order).
        return min(healthy, key=lambda nid: self._nodes[nid].active_jobs)

    def acquire(self, tier: WorkloadTier = WorkloadTier.ANY) -> NodeId:
        """Select node and mark a job started.

        Callers must pair this with :meth:`release` when the job ends.
        """
        node = self.select_node(tier)
        self._nodes[node].record_start()
        logger.info(
            "LB: dispatched %s -> %s (active=%d)",
            tier.value,
            node.value,
            self._nodes[node].active_jobs,
        )
        return node

    def release(self, node: NodeId, latency_ms: float = 0, error: bool = False):
        """Mark a job completed (or failed) on a node; unknown nodes are ignored."""
        status = self._nodes.get(node)
        if not status:
            return
        if error:
            status.record_error()
        else:
            status.record_completion(latency_ms)

    def get_status(self) -> dict:
        """Return a JSON-serializable snapshot of every node's state."""
        return {
            nid.value: {
                "healthy": s.healthy,
                "active_jobs": s.active_jobs,
                "total_completed": s.total_completed,
                "total_errors": s.total_errors,
                "avg_latency_ms": round(s.avg_latency_ms, 1),
                "last_check": s.last_check,
            }
            for nid, s in self._nodes.items()
        }
|
|
|
|
|
|
# Module-level singleton: importers share this one balancer instance so
# job counts and health state are tracked process-wide.
lb = LoadBalancer()