Files
ThreatHunt/backend/app/agents/providers_v2.py
mblanke 9b98ab9614 feat: interactive network map, IOC highlighting, AUP hunt selector, type filters
- NetworkMap: hunt-scoped force-directed graph with click-to-inspect popover
- NetworkMap: zoom/pan (wheel, drag, buttons), viewport transform
- NetworkMap: clickable IP/Host/Domain/URL legend chips to filter node types
- NetworkMap: brighter colors, 20% smaller nodes
- DatasetViewer: IOC columns highlighted with colored headers + cell tinting
- AUPScanner: hunt dropdown replacing dataset checkboxes, auto-select all
- Rename 'Social Media (Personal)' theme to 'Social Media' with DB migration
- Fix /api/hunts timeout: Dataset.rows lazy='noload' (was selectin cascade)
- Add OS column mapping to normalizer
- Full backend services, DB models, alembic migrations, new routes
- New components: Dashboard, HuntManager, FileUpload, NetworkMap, etc.
- Docker Compose deployment with nginx reverse proxy
2026-02-19 15:41:15 -05:00

363 lines
11 KiB
Python

"""LLM providers — real implementations for Ollama nodes and Open WebUI cluster.
Three providers:
- OllamaProvider: Direct calls to Ollama on Wile/Roadrunner via Tailscale
- OpenWebUIProvider: Calls to the Open WebUI cluster (OpenAI-compatible)
- EmbeddingProvider: Embedding generation via Ollama /api/embeddings
"""
import asyncio
import json
import logging
import time
from typing import AsyncIterator
import httpx
from app.config import settings
from .registry import ModelEntry, Node
logger = logging.getLogger(__name__)
# Shared HTTP client with reasonable timeouts
_client: httpx.AsyncClient | None = None
def _get_client() -> httpx.AsyncClient:
global _client
if _client is None or _client.is_closed:
_client = httpx.AsyncClient(
timeout=httpx.Timeout(connect=10, read=300, write=30, pool=10),
limits=httpx.Limits(max_connections=20, max_keepalive_connections=10),
)
return _client
async def cleanup_client():
global _client
if _client and not _client.is_closed:
await _client.aclose()
_client = None
def _ollama_url(node: Node) -> str:
"""Get the Ollama base URL for a node."""
if node == Node.WILE:
return settings.wile_url
elif node == Node.ROADRUNNER:
return settings.roadrunner_url
else:
raise ValueError(f"No direct Ollama URL for node: {node}")
# ── Ollama Provider ──────────────────────────────────────────────────
class OllamaProvider:
"""Direct Ollama API calls to Wile or Roadrunner."""
def __init__(self, model: str, node: Node):
self.model = model
self.node = node
self.base_url = _ollama_url(node)
async def generate(
self,
prompt: str,
system: str = "",
max_tokens: int = 2048,
temperature: float = 0.3,
) -> dict:
"""Generate a completion. Returns dict with 'response', 'model', 'total_duration', etc."""
client = _get_client()
payload = {
"model": self.model,
"prompt": prompt,
"stream": False,
"options": {
"num_predict": max_tokens,
"temperature": temperature,
},
}
if system:
payload["system"] = system
start = time.monotonic()
try:
resp = await client.post(
f"{self.base_url}/api/generate",
json=payload,
)
resp.raise_for_status()
data = resp.json()
latency_ms = int((time.monotonic() - start) * 1000)
data["_latency_ms"] = latency_ms
data["_node"] = self.node.value
logger.info(
f"Ollama [{self.node.value}] {self.model}: "
f"{latency_ms}ms, {data.get('eval_count', '?')} tokens"
)
return data
except httpx.HTTPStatusError as e:
logger.error(f"Ollama HTTP error [{self.node.value}]: {e.response.status_code} {e.response.text[:200]}")
raise
except httpx.ConnectError as e:
logger.error(f"Cannot reach Ollama on {self.node.value} ({self.base_url}): {e}")
raise
async def chat(
self,
messages: list[dict],
max_tokens: int = 2048,
temperature: float = 0.3,
) -> dict:
"""Chat completion via Ollama /api/chat."""
client = _get_client()
payload = {
"model": self.model,
"messages": messages,
"stream": False,
"options": {
"num_predict": max_tokens,
"temperature": temperature,
},
}
start = time.monotonic()
resp = await client.post(f"{self.base_url}/api/chat", json=payload)
resp.raise_for_status()
data = resp.json()
data["_latency_ms"] = int((time.monotonic() - start) * 1000)
data["_node"] = self.node.value
return data
async def generate_stream(
self,
prompt: str,
system: str = "",
max_tokens: int = 2048,
temperature: float = 0.3,
) -> AsyncIterator[str]:
"""Stream tokens from Ollama."""
client = _get_client()
payload = {
"model": self.model,
"prompt": prompt,
"stream": True,
"options": {
"num_predict": max_tokens,
"temperature": temperature,
},
}
if system:
payload["system"] = system
async with client.stream(
"POST", f"{self.base_url}/api/generate", json=payload
) as resp:
resp.raise_for_status()
async for line in resp.aiter_lines():
if line.strip():
try:
chunk = json.loads(line)
token = chunk.get("response", "")
if token:
yield token
if chunk.get("done"):
break
except json.JSONDecodeError:
continue
async def is_available(self) -> bool:
"""Ping the Ollama node."""
try:
client = _get_client()
resp = await client.get(f"{self.base_url}/api/tags", timeout=5)
return resp.status_code == 200
except Exception:
return False
# ── Open WebUI Provider (OpenAI-compatible) ───────────────────────────
class OpenWebUIProvider:
"""Calls to Open WebUI cluster at ai.guapo613.beer.
Uses the OpenAI-compatible /v1/chat/completions endpoint.
"""
def __init__(self, model: str = ""):
self.model = model or settings.DEFAULT_FAST_MODEL
self.base_url = settings.OPENWEBUI_URL.rstrip("/")
self.api_key = settings.OPENWEBUI_API_KEY
def _headers(self) -> dict:
h = {"Content-Type": "application/json"}
if self.api_key:
h["Authorization"] = f"Bearer {self.api_key}"
return h
async def chat(
self,
messages: list[dict],
max_tokens: int = 2048,
temperature: float = 0.3,
) -> dict:
"""Chat completion via OpenAI-compatible endpoint."""
client = _get_client()
payload = {
"model": self.model,
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
"stream": False,
}
start = time.monotonic()
resp = await client.post(
f"{self.base_url}/v1/chat/completions",
json=payload,
headers=self._headers(),
)
resp.raise_for_status()
data = resp.json()
latency_ms = int((time.monotonic() - start) * 1000)
# Normalize to our format
content = ""
if data.get("choices"):
content = data["choices"][0].get("message", {}).get("content", "")
result = {
"response": content,
"model": data.get("model", self.model),
"_latency_ms": latency_ms,
"_node": "cluster",
"_usage": data.get("usage", {}),
}
logger.info(
f"OpenWebUI cluster {self.model}: {latency_ms}ms"
)
return result
async def generate(
self,
prompt: str,
system: str = "",
max_tokens: int = 2048,
temperature: float = 0.3,
) -> dict:
"""Convert prompt-style call to chat format."""
messages = []
if system:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": prompt})
return await self.chat(messages, max_tokens, temperature)
async def chat_stream(
self,
messages: list[dict],
max_tokens: int = 2048,
temperature: float = 0.3,
) -> AsyncIterator[str]:
"""Stream tokens from OpenWebUI."""
client = _get_client()
payload = {
"model": self.model,
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
"stream": True,
}
async with client.stream(
"POST",
f"{self.base_url}/v1/chat/completions",
json=payload,
headers=self._headers(),
) as resp:
resp.raise_for_status()
async for line in resp.aiter_lines():
if line.startswith("data: "):
data_str = line[6:].strip()
if data_str == "[DONE]":
break
try:
chunk = json.loads(data_str)
delta = chunk.get("choices", [{}])[0].get("delta", {})
token = delta.get("content", "")
if token:
yield token
except json.JSONDecodeError:
continue
async def is_available(self) -> bool:
"""Check if Open WebUI is reachable."""
try:
client = _get_client()
resp = await client.get(
f"{self.base_url}/v1/models",
headers=self._headers(),
timeout=5,
)
return resp.status_code == 200
except Exception:
return False
# ── Embedding Provider ────────────────────────────────────────────────
class EmbeddingProvider:
"""Generate embeddings via Ollama /api/embeddings."""
def __init__(self, model: str = "", node: Node = Node.ROADRUNNER):
self.model = model or settings.DEFAULT_EMBEDDING_MODEL
self.node = node
self.base_url = _ollama_url(node)
async def embed(self, text: str) -> list[float]:
"""Get embedding vector for a single text."""
client = _get_client()
resp = await client.post(
f"{self.base_url}/api/embeddings",
json={"model": self.model, "prompt": text},
)
resp.raise_for_status()
data = resp.json()
return data.get("embedding", [])
async def embed_batch(self, texts: list[str], concurrency: int = 5) -> list[list[float]]:
"""Embed multiple texts with controlled concurrency."""
sem = asyncio.Semaphore(concurrency)
async def _embed_one(t: str) -> list[float]:
async with sem:
return await self.embed(t)
return await asyncio.gather(*[_embed_one(t) for t in texts])
# ── Health check for all nodes ────────────────────────────────────────
async def check_all_nodes() -> dict:
"""Check availability of all LLM nodes."""
wile = OllamaProvider("", Node.WILE)
roadrunner = OllamaProvider("", Node.ROADRUNNER)
cluster = OpenWebUIProvider()
wile_ok, rr_ok, cl_ok = await asyncio.gather(
wile.is_available(),
roadrunner.is_available(),
cluster.is_available(),
return_exceptions=True,
)
return {
"wile": {"available": wile_ok is True, "url": settings.wile_url},
"roadrunner": {"available": rr_ok is True, "url": settings.roadrunner_url},
"cluster": {"available": cl_ok is True, "url": settings.OPENWEBUI_URL},
}