mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 14:00:20 -05:00
feat: interactive network map, IOC highlighting, AUP hunt selector, type filters
- NetworkMap: hunt-scoped force-directed graph with click-to-inspect popover - NetworkMap: zoom/pan (wheel, drag, buttons), viewport transform - NetworkMap: clickable IP/Host/Domain/URL legend chips to filter node types - NetworkMap: brighter colors, 20% smaller nodes - DatasetViewer: IOC columns highlighted with colored headers + cell tinting - AUPScanner: hunt dropdown replacing dataset checkboxes, auto-select all - Rename 'Social Media (Personal)' theme to 'Social Media' with DB migration - Fix /api/hunts timeout: Dataset.rows lazy='noload' (was selectin cascade) - Add OS column mapping to normalizer - Full backend services, DB models, alembic migrations, new routes - New components: Dashboard, HuntManager, FileUpload, NetworkMap, etc. - Docker Compose deployment with nginx reverse proxy
This commit is contained in:
362
backend/app/agents/providers_v2.py
Normal file
362
backend/app/agents/providers_v2.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""LLM providers — real implementations for Ollama nodes and Open WebUI cluster.
|
||||
|
||||
Three providers:
|
||||
- OllamaProvider: Direct calls to Ollama on Wile/Roadrunner via Tailscale
|
||||
- OpenWebUIProvider: Calls to the Open WebUI cluster (OpenAI-compatible)
|
||||
- EmbeddingProvider: Embedding generation via Ollama /api/embeddings
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import AsyncIterator
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import settings
|
||||
from .registry import ModelEntry, Node
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Shared HTTP client with reasonable timeouts
|
||||
_client: httpx.AsyncClient | None = None
|
||||
|
||||
|
||||
def _get_client() -> httpx.AsyncClient:
|
||||
global _client
|
||||
if _client is None or _client.is_closed:
|
||||
_client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=10, read=300, write=30, pool=10),
|
||||
limits=httpx.Limits(max_connections=20, max_keepalive_connections=10),
|
||||
)
|
||||
return _client
|
||||
|
||||
|
||||
async def cleanup_client():
|
||||
global _client
|
||||
if _client and not _client.is_closed:
|
||||
await _client.aclose()
|
||||
_client = None
|
||||
|
||||
|
||||
def _ollama_url(node: Node) -> str:
|
||||
"""Get the Ollama base URL for a node."""
|
||||
if node == Node.WILE:
|
||||
return settings.wile_url
|
||||
elif node == Node.ROADRUNNER:
|
||||
return settings.roadrunner_url
|
||||
else:
|
||||
raise ValueError(f"No direct Ollama URL for node: {node}")
|
||||
|
||||
|
||||
# ── Ollama Provider ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class OllamaProvider:
|
||||
"""Direct Ollama API calls to Wile or Roadrunner."""
|
||||
|
||||
def __init__(self, model: str, node: Node):
|
||||
self.model = model
|
||||
self.node = node
|
||||
self.base_url = _ollama_url(node)
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
system: str = "",
|
||||
max_tokens: int = 2048,
|
||||
temperature: float = 0.3,
|
||||
) -> dict:
|
||||
"""Generate a completion. Returns dict with 'response', 'model', 'total_duration', etc."""
|
||||
client = _get_client()
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"num_predict": max_tokens,
|
||||
"temperature": temperature,
|
||||
},
|
||||
}
|
||||
if system:
|
||||
payload["system"] = system
|
||||
|
||||
start = time.monotonic()
|
||||
try:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/api/generate",
|
||||
json=payload,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
latency_ms = int((time.monotonic() - start) * 1000)
|
||||
data["_latency_ms"] = latency_ms
|
||||
data["_node"] = self.node.value
|
||||
logger.info(
|
||||
f"Ollama [{self.node.value}] {self.model}: "
|
||||
f"{latency_ms}ms, {data.get('eval_count', '?')} tokens"
|
||||
)
|
||||
return data
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Ollama HTTP error [{self.node.value}]: {e.response.status_code} {e.response.text[:200]}")
|
||||
raise
|
||||
except httpx.ConnectError as e:
|
||||
logger.error(f"Cannot reach Ollama on {self.node.value} ({self.base_url}): {e}")
|
||||
raise
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict],
|
||||
max_tokens: int = 2048,
|
||||
temperature: float = 0.3,
|
||||
) -> dict:
|
||||
"""Chat completion via Ollama /api/chat."""
|
||||
client = _get_client()
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"num_predict": max_tokens,
|
||||
"temperature": temperature,
|
||||
},
|
||||
}
|
||||
|
||||
start = time.monotonic()
|
||||
resp = await client.post(f"{self.base_url}/api/chat", json=payload)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
data["_latency_ms"] = int((time.monotonic() - start) * 1000)
|
||||
data["_node"] = self.node.value
|
||||
return data
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
system: str = "",
|
||||
max_tokens: int = 2048,
|
||||
temperature: float = 0.3,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Stream tokens from Ollama."""
|
||||
client = _get_client()
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"prompt": prompt,
|
||||
"stream": True,
|
||||
"options": {
|
||||
"num_predict": max_tokens,
|
||||
"temperature": temperature,
|
||||
},
|
||||
}
|
||||
if system:
|
||||
payload["system"] = system
|
||||
|
||||
async with client.stream(
|
||||
"POST", f"{self.base_url}/api/generate", json=payload
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
async for line in resp.aiter_lines():
|
||||
if line.strip():
|
||||
try:
|
||||
chunk = json.loads(line)
|
||||
token = chunk.get("response", "")
|
||||
if token:
|
||||
yield token
|
||||
if chunk.get("done"):
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
async def is_available(self) -> bool:
|
||||
"""Ping the Ollama node."""
|
||||
try:
|
||||
client = _get_client()
|
||||
resp = await client.get(f"{self.base_url}/api/tags", timeout=5)
|
||||
return resp.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# ── Open WebUI Provider (OpenAI-compatible) ───────────────────────────
|
||||
|
||||
|
||||
class OpenWebUIProvider:
|
||||
"""Calls to Open WebUI cluster at ai.guapo613.beer.
|
||||
|
||||
Uses the OpenAI-compatible /v1/chat/completions endpoint.
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = ""):
|
||||
self.model = model or settings.DEFAULT_FAST_MODEL
|
||||
self.base_url = settings.OPENWEBUI_URL.rstrip("/")
|
||||
self.api_key = settings.OPENWEBUI_API_KEY
|
||||
|
||||
def _headers(self) -> dict:
|
||||
h = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
h["Authorization"] = f"Bearer {self.api_key}"
|
||||
return h
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict],
|
||||
max_tokens: int = 2048,
|
||||
temperature: float = 0.3,
|
||||
) -> dict:
|
||||
"""Chat completion via OpenAI-compatible endpoint."""
|
||||
client = _get_client()
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
start = time.monotonic()
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json=payload,
|
||||
headers=self._headers(),
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
latency_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
# Normalize to our format
|
||||
content = ""
|
||||
if data.get("choices"):
|
||||
content = data["choices"][0].get("message", {}).get("content", "")
|
||||
|
||||
result = {
|
||||
"response": content,
|
||||
"model": data.get("model", self.model),
|
||||
"_latency_ms": latency_ms,
|
||||
"_node": "cluster",
|
||||
"_usage": data.get("usage", {}),
|
||||
}
|
||||
logger.info(
|
||||
f"OpenWebUI cluster {self.model}: {latency_ms}ms"
|
||||
)
|
||||
return result
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
system: str = "",
|
||||
max_tokens: int = 2048,
|
||||
temperature: float = 0.3,
|
||||
) -> dict:
|
||||
"""Convert prompt-style call to chat format."""
|
||||
messages = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
return await self.chat(messages, max_tokens, temperature)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict],
|
||||
max_tokens: int = 2048,
|
||||
temperature: float = 0.3,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Stream tokens from OpenWebUI."""
|
||||
client = _get_client()
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json=payload,
|
||||
headers=self._headers(),
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
async for line in resp.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:].strip()
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
||||
token = delta.get("content", "")
|
||||
if token:
|
||||
yield token
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
async def is_available(self) -> bool:
|
||||
"""Check if Open WebUI is reachable."""
|
||||
try:
|
||||
client = _get_client()
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/v1/models",
|
||||
headers=self._headers(),
|
||||
timeout=5,
|
||||
)
|
||||
return resp.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# ── Embedding Provider ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class EmbeddingProvider:
|
||||
"""Generate embeddings via Ollama /api/embeddings."""
|
||||
|
||||
def __init__(self, model: str = "", node: Node = Node.ROADRUNNER):
|
||||
self.model = model or settings.DEFAULT_EMBEDDING_MODEL
|
||||
self.node = node
|
||||
self.base_url = _ollama_url(node)
|
||||
|
||||
async def embed(self, text: str) -> list[float]:
|
||||
"""Get embedding vector for a single text."""
|
||||
client = _get_client()
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/api/embeddings",
|
||||
json={"model": self.model, "prompt": text},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return data.get("embedding", [])
|
||||
|
||||
async def embed_batch(self, texts: list[str], concurrency: int = 5) -> list[list[float]]:
|
||||
"""Embed multiple texts with controlled concurrency."""
|
||||
sem = asyncio.Semaphore(concurrency)
|
||||
|
||||
async def _embed_one(t: str) -> list[float]:
|
||||
async with sem:
|
||||
return await self.embed(t)
|
||||
|
||||
return await asyncio.gather(*[_embed_one(t) for t in texts])
|
||||
|
||||
|
||||
# ── Health check for all nodes ────────────────────────────────────────
|
||||
|
||||
|
||||
async def check_all_nodes() -> dict:
|
||||
"""Check availability of all LLM nodes."""
|
||||
wile = OllamaProvider("", Node.WILE)
|
||||
roadrunner = OllamaProvider("", Node.ROADRUNNER)
|
||||
cluster = OpenWebUIProvider()
|
||||
|
||||
wile_ok, rr_ok, cl_ok = await asyncio.gather(
|
||||
wile.is_available(),
|
||||
roadrunner.is_available(),
|
||||
cluster.is_available(),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
return {
|
||||
"wile": {"available": wile_ok is True, "url": settings.wile_url},
|
||||
"roadrunner": {"available": rr_ok is True, "url": settings.roadrunner_url},
|
||||
"cluster": {"available": cl_ok is True, "url": settings.OPENWEBUI_URL},
|
||||
}
|
||||
Reference in New Issue
Block a user