Mirror of https://github.com/mblanke/ThreatHunt.git (synced 2026-03-01 14:00:20 -05:00)
- NetworkMap: hunt-scoped force-directed graph with click-to-inspect popover
- NetworkMap: zoom/pan (wheel, drag, buttons), viewport transform
- NetworkMap: clickable IP/Host/Domain/URL legend chips to filter node types
- NetworkMap: brighter colors, 20% smaller nodes
- DatasetViewer: IOC columns highlighted with colored headers + cell tinting
- AUPScanner: hunt dropdown replacing dataset checkboxes, auto-select all
- Rename 'Social Media (Personal)' theme to 'Social Media' with DB migration
- Fix /api/hunts timeout: Dataset.rows lazy='noload' (was selectin cascade)
- Add OS column mapping to normalizer
- Full backend services, DB models, alembic migrations, new routes
- New components: Dashboard, HuntManager, FileUpload, NetworkMap, etc.
- Docker Compose deployment with nginx reverse proxy
363 lines · 11 KiB · Python
"""LLM providers — real implementations for Ollama nodes and Open WebUI cluster.
|
|
|
|
Three providers:
|
|
- OllamaProvider: Direct calls to Ollama on Wile/Roadrunner via Tailscale
|
|
- OpenWebUIProvider: Calls to the Open WebUI cluster (OpenAI-compatible)
|
|
- EmbeddingProvider: Embedding generation via Ollama /api/embeddings
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
from typing import AsyncIterator
|
|
|
|
import httpx
|
|
|
|
from app.config import settings
|
|
from .registry import ModelEntry, Node
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Shared HTTP client with reasonable timeouts
|
|
_client: httpx.AsyncClient | None = None
|
|
|
|
|
|
def _get_client() -> httpx.AsyncClient:
|
|
global _client
|
|
if _client is None or _client.is_closed:
|
|
_client = httpx.AsyncClient(
|
|
timeout=httpx.Timeout(connect=10, read=300, write=30, pool=10),
|
|
limits=httpx.Limits(max_connections=20, max_keepalive_connections=10),
|
|
)
|
|
return _client
|
|
|
|
|
|
async def cleanup_client():
|
|
global _client
|
|
if _client and not _client.is_closed:
|
|
await _client.aclose()
|
|
_client = None
|
|
|
|
|
|
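# Shutdown wiring is left to the application. A minimal sketch (assuming an ASGI
# app object with a shutdown hook; the name "app" is illustrative and is not
# defined in this module):
#
#     @app.on_event("shutdown")
#     async def _close_llm_client() -> None:
#         await cleanup_client()

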
def _ollama_url(node: Node) -> str:
    """Get the Ollama base URL for a node."""
    if node == Node.WILE:
        return settings.wile_url
    elif node == Node.ROADRUNNER:
        return settings.roadrunner_url
    else:
        raise ValueError(f"No direct Ollama URL for node: {node}")


# ── Ollama Provider ──────────────────────────────────────────────────


class OllamaProvider:
    """Direct Ollama API calls to Wile or Roadrunner."""

    def __init__(self, model: str, node: Node):
        self.model = model
        self.node = node
        self.base_url = _ollama_url(node)

    async def generate(
        self,
        prompt: str,
        system: str = "",
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Generate a completion. Returns dict with 'response', 'model', 'total_duration', etc."""
        client = _get_client()
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        }
        if system:
            payload["system"] = system

        start = time.monotonic()
        try:
            resp = await client.post(
                f"{self.base_url}/api/generate",
                json=payload,
            )
            resp.raise_for_status()
            data = resp.json()
            latency_ms = int((time.monotonic() - start) * 1000)
            data["_latency_ms"] = latency_ms
            data["_node"] = self.node.value
            logger.info(
                f"Ollama [{self.node.value}] {self.model}: "
                f"{latency_ms}ms, {data.get('eval_count', '?')} tokens"
            )
            return data
        except httpx.HTTPStatusError as e:
            logger.error(f"Ollama HTTP error [{self.node.value}]: {e.response.status_code} {e.response.text[:200]}")
            raise
        except httpx.ConnectError as e:
            logger.error(f"Cannot reach Ollama on {self.node.value} ({self.base_url}): {e}")
            raise

    async def chat(
        self,
        messages: list[dict],
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Chat completion via Ollama /api/chat."""
        client = _get_client()
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": False,
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        }

        start = time.monotonic()
        resp = await client.post(f"{self.base_url}/api/chat", json=payload)
        resp.raise_for_status()
        data = resp.json()
        data["_latency_ms"] = int((time.monotonic() - start) * 1000)
        data["_node"] = self.node.value
        return data

    async def generate_stream(
        self,
        prompt: str,
        system: str = "",
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> AsyncIterator[str]:
        """Stream tokens from Ollama."""
        client = _get_client()
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": True,
            "options": {
                "num_predict": max_tokens,
                "temperature": temperature,
            },
        }
        if system:
            payload["system"] = system

        async with client.stream(
            "POST", f"{self.base_url}/api/generate", json=payload
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if line.strip():
                    try:
                        chunk = json.loads(line)
                        token = chunk.get("response", "")
                        if token:
                            yield token
                        if chunk.get("done"):
                            break
                    except json.JSONDecodeError:
                        continue

    async def is_available(self) -> bool:
        """Ping the Ollama node."""
        try:
            client = _get_client()
            resp = await client.get(f"{self.base_url}/api/tags", timeout=5)
            return resp.status_code == 200
        except Exception:
            return False


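# Usage sketch (illustrative only; "llama3.1" and the prompts are placeholders,
# not model names taken from the registry). Inside an existing event loop:
#
#     provider = OllamaProvider("llama3.1", Node.WILE)
#     if await provider.is_available():
#         result = await provider.generate("Summarize the suspicious logons.")
#         print(result["response"], f"({result['_latency_ms']}ms)")
#         async for token in provider.generate_stream("Explain the finding."):
#             print(token, end="", flush=True)

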
# ── Open WebUI Provider (OpenAI-compatible) ───────────────────────────


class OpenWebUIProvider:
    """Calls to Open WebUI cluster at ai.guapo613.beer.

    Uses the OpenAI-compatible /v1/chat/completions endpoint.
    """

    def __init__(self, model: str = ""):
        self.model = model or settings.DEFAULT_FAST_MODEL
        self.base_url = settings.OPENWEBUI_URL.rstrip("/")
        self.api_key = settings.OPENWEBUI_API_KEY

    def _headers(self) -> dict:
        h = {"Content-Type": "application/json"}
        if self.api_key:
            h["Authorization"] = f"Bearer {self.api_key}"
        return h

    async def chat(
        self,
        messages: list[dict],
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Chat completion via OpenAI-compatible endpoint."""
        client = _get_client()
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": False,
        }

        start = time.monotonic()
        resp = await client.post(
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            headers=self._headers(),
        )
        resp.raise_for_status()
        data = resp.json()
        latency_ms = int((time.monotonic() - start) * 1000)

        # Normalize to our format
        content = ""
        if data.get("choices"):
            content = data["choices"][0].get("message", {}).get("content", "")

        result = {
            "response": content,
            "model": data.get("model", self.model),
            "_latency_ms": latency_ms,
            "_node": "cluster",
            "_usage": data.get("usage", {}),
        }
        logger.info(f"OpenWebUI cluster {self.model}: {latency_ms}ms")
        return result

    async def generate(
        self,
        prompt: str,
        system: str = "",
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> dict:
        """Convert prompt-style call to chat format."""
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        return await self.chat(messages, max_tokens, temperature)

    async def chat_stream(
        self,
        messages: list[dict],
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> AsyncIterator[str]:
        """Stream tokens from OpenWebUI."""
        client = _get_client()
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": True,
        }

        async with client.stream(
            "POST",
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            headers=self._headers(),
        ) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    data_str = line[6:].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data_str)
                        delta = chunk.get("choices", [{}])[0].get("delta", {})
                        token = delta.get("content", "")
                        if token:
                            yield token
                    except json.JSONDecodeError:
                        continue

    async def is_available(self) -> bool:
        """Check if Open WebUI is reachable."""
        try:
            client = _get_client()
            resp = await client.get(
                f"{self.base_url}/v1/models",
                headers=self._headers(),
                timeout=5,
            )
            return resp.status_code == 200
        except Exception:
            return False


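# Usage sketch (illustrative): the cluster provider normalizes its reply to the
# same "response"/"_latency_ms"/"_node" shape as OllamaProvider, so callers can
# swap providers without branching. Uses settings.DEFAULT_FAST_MODEL unless a
# model name is passed explicitly:
#
#     cluster = OpenWebUIProvider()
#     reply = await cluster.chat([{"role": "user", "content": "Classify this URL."}])
#     print(reply["response"], reply["_usage"])

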
# ── Embedding Provider ────────────────────────────────────────────────


class EmbeddingProvider:
    """Generate embeddings via Ollama /api/embeddings."""

    def __init__(self, model: str = "", node: Node = Node.ROADRUNNER):
        self.model = model or settings.DEFAULT_EMBEDDING_MODEL
        self.node = node
        self.base_url = _ollama_url(node)

    async def embed(self, text: str) -> list[float]:
        """Get embedding vector for a single text."""
        client = _get_client()
        resp = await client.post(
            f"{self.base_url}/api/embeddings",
            json={"model": self.model, "prompt": text},
        )
        resp.raise_for_status()
        data = resp.json()
        return data.get("embedding", [])

    async def embed_batch(self, texts: list[str], concurrency: int = 5) -> list[list[float]]:
        """Embed multiple texts with controlled concurrency."""
        sem = asyncio.Semaphore(concurrency)

        async def _embed_one(t: str) -> list[float]:
            async with sem:
                return await self.embed(t)

        return await asyncio.gather(*[_embed_one(t) for t in texts])


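# Example helper (a sketch, not referenced elsewhere in this module): vectors
# returned by embed()/embed_batch() are plain float lists, so two texts can be
# compared with a cosine over the raw vectors.
def _cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two embedding vectors; 0.0 if either is empty."""
    if not a or not b:
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(y * y for y in b) ** 0.5
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)

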
# ── Health check for all nodes ────────────────────────────────────────


async def check_all_nodes() -> dict:
    """Check availability of all LLM nodes."""
    wile = OllamaProvider("", Node.WILE)
    roadrunner = OllamaProvider("", Node.ROADRUNNER)
    cluster = OpenWebUIProvider()

    wile_ok, rr_ok, cl_ok = await asyncio.gather(
        wile.is_available(),
        roadrunner.is_available(),
        cluster.is_available(),
        return_exceptions=True,
    )

    return {
        "wile": {"available": wile_ok is True, "url": settings.wile_url},
        "roadrunner": {"available": rr_ok is True, "url": settings.roadrunner_url},
        "cluster": {"available": cl_ok is True, "url": settings.OPENWEBUI_URL},
    }
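

if __name__ == "__main__":
    # Quick manual check (a convenience sketch, not used by the service): print
    # the availability of every configured node as JSON.
    print(json.dumps(asyncio.run(check_all_nodes()), indent=2))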