mirror of
https://github.com/mblanke/StrikePackageGPT.git
synced 2026-03-01 14:20:21 -05:00
feat: Add interactive installer and multi-endpoint LLM load balancing
- Add install.ps1 (PowerShell) and install.sh (Bash) interactive installers
- Support local, networked, and cloud AI providers in installer
- Add multi-Ollama endpoint configuration with high-speed NIC support
- Implement load balancing strategies: round-robin, failover, random
- Update LLM router with endpoint health checking and automatic failover
- Add /endpoints API for monitoring all Ollama instances
- Update docker-compose with OLLAMA_ENDPOINTS and LOAD_BALANCE_STRATEGY
- Rebrand to GooseStrike with custom icon and flag assets
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
LLM Router Service
|
||||
Routes requests to different LLM providers (OpenAI, Anthropic, Ollama)
|
||||
Supports multiple Ollama endpoints with load balancing
|
||||
"""
|
||||
import asyncio
import os
import random
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Literal, Optional

import httpx
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
|
||||
|
||||
app = FastAPI(
|
||||
title="LLM Router",
|
||||
description="Routes requests to multiple LLM providers",
|
||||
version="0.1.0"
|
||||
description="Routes requests to multiple LLM providers with load balancing",
|
||||
version="0.2.0"
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
@@ -26,7 +31,22 @@ app.add_middleware(
|
||||
# Configuration from environment
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
# Legacy single-endpoint variable; kept for backward compatibility and used
# as the fallback default when OLLAMA_ENDPOINTS is not set.
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://192.168.1.50:11434")
# Support multiple Ollama endpoints (comma-separated)
OLLAMA_ENDPOINTS_STR = os.getenv("OLLAMA_ENDPOINTS", OLLAMA_BASE_URL)
OLLAMA_ENDPOINTS = [url.strip() for url in OLLAMA_ENDPOINTS_STR.split(",") if url.strip()]
LOAD_BALANCE_STRATEGY = os.getenv("LOAD_BALANCE_STRATEGY", "round-robin")  # round-robin, random, failover


@dataclass
class EndpointHealth:
    """Mutable health record for a single Ollama endpoint."""

    # Base URL of the endpoint, e.g. "http://192.168.1.50:11434".
    url: str
    # Last known health status; starts optimistic so new endpoints get tried.
    healthy: bool = True
    # When the endpoint was last probed; None means "never checked".
    # (Fixed annotation: the default is None, so the type must be Optional.)
    last_check: Optional[datetime] = None
    # Consecutive failures observed against this endpoint.
    failure_count: int = 0
    # Model names reported by the endpoint's /api/tags.
    # (Fixed: was `models: list = None`; use a per-instance empty list.)
    models: list = field(default_factory=list)


# Track endpoint health, keyed by endpoint URL.
endpoint_health: dict[str, EndpointHealth] = {url: EndpointHealth(url=url) for url in OLLAMA_ENDPOINTS}
# Round-robin cursor into OLLAMA_ENDPOINTS.
current_endpoint_index = 0
||||
class ChatMessage(BaseModel):
|
||||
@@ -52,33 +72,119 @@ class ChatResponse(BaseModel):
|
||||
@app.get("/health")
async def health_check():
    """Liveness probe: reports service identity and configured endpoint count."""
    return {
        "status": "healthy",
        "service": "llm-router",
        "endpoints": len(OLLAMA_ENDPOINTS),
    }
|
||||
|
||||
|
||||
async def check_endpoint_health(url: str) -> tuple[bool, list]:
    """Probe one Ollama endpoint.

    Returns ``(True, model_names)`` when ``GET {url}/api/tags`` answers 200,
    and ``(False, [])`` on any error, timeout, or non-200 response.
    """
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(f"{url}/api/tags", timeout=5.0)
            if resp.status_code == 200:
                payload = resp.json()
                names = [entry["name"] for entry in payload.get("models", [])]
                return True, names
    except Exception:
        # Any transport/parse failure simply means "unhealthy" for a probe.
        pass
    return False, []
|
||||
|
||||
|
||||
async def get_healthy_endpoint() -> Optional[str]:
    """Pick a healthy Ollama endpoint according to LOAD_BALANCE_STRATEGY.

    Any endpoint whose health record is older than 30 seconds is re-probed
    first. Then a healthy endpoint is chosen:
      - "random":   uniform random choice
      - "failover": first healthy endpoint in configured order
      - otherwise:  round-robin (default)

    Returns:
        The chosen endpoint URL, or None when no endpoint is healthy.
    """
    global current_endpoint_index

    # Refresh health status for stale checks (older than 30 seconds)
    now = datetime.now()
    for url, health in endpoint_health.items():
        if health.last_check is None or (now - health.last_check) > timedelta(seconds=30):
            is_healthy, models = await check_endpoint_health(url)
            health.healthy = is_healthy
            health.models = models
            health.last_check = now
            if is_healthy:
                health.failure_count = 0
            else:
                # Fix: previously a failed refresh left failure_count stale;
                # count it so /endpoints reports consecutive failures accurately.
                health.failure_count += 1

    healthy_endpoints = [url for url, h in endpoint_health.items() if h.healthy]

    if not healthy_endpoints:
        return None

    if LOAD_BALANCE_STRATEGY == "random":
        return random.choice(healthy_endpoints)
    if LOAD_BALANCE_STRATEGY == "failover":
        # Always use the first available healthy endpoint in configured order.
        return healthy_endpoints[0]

    # round-robin (default): advance the cursor until a healthy endpoint is hit.
    for _ in range(len(OLLAMA_ENDPOINTS)):
        current_endpoint_index = (current_endpoint_index + 1) % len(OLLAMA_ENDPOINTS)
        candidate = OLLAMA_ENDPOINTS[current_endpoint_index]
        if candidate in healthy_endpoints:
            return candidate
    return healthy_endpoints[0]
|
||||
|
||||
|
||||
@app.get("/providers")
async def list_providers():
    """List available LLM providers and their status.

    OpenAI/Anthropic availability is inferred from configured API keys;
    Ollama availability is determined by probing every configured endpoint.
    """
    # Check all Ollama endpoints concurrently instead of one-by-one —
    # total latency becomes max(probe) rather than sum(probe).
    checks = await asyncio.gather(*(check_endpoint_health(url) for url in OLLAMA_ENDPOINTS))
    checked_at = datetime.now()

    ollama_info = []
    all_models = set()
    any_available = False

    for url, (is_healthy, models) in zip(OLLAMA_ENDPOINTS, checks):
        record = endpoint_health[url]
        record.healthy = is_healthy
        record.models = models
        record.last_check = checked_at

        ollama_info.append({
            "url": url,
            "available": is_healthy,
            "models": models
        })
        if is_healthy:
            any_available = True
            all_models.update(models)

    providers = {
        "openai": {"available": bool(OPENAI_API_KEY), "models": ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo"]},
        "anthropic": {"available": bool(ANTHROPIC_API_KEY), "models": ["claude-sonnet-4-20250514", "claude-3-5-haiku-20241022"]},
        "ollama": {
            "available": any_available,
            "endpoints": ollama_info,
            "load_balance_strategy": LOAD_BALANCE_STRATEGY,
            # Fallback list shown when no endpoint reported any models.
            "models": list(all_models) if all_models else ["llama3", "mistral", "codellama"]
        }
    }
    return providers
|
||||
|
||||
|
||||
@app.get("/endpoints")
async def list_endpoints():
    """List all Ollama endpoints with detailed status.

    Probes every configured endpoint, updates the shared health registry,
    and returns per-endpoint status plus aggregate counts.
    """
    # Probe concurrently: total latency is the slowest probe, not the sum.
    checks = await asyncio.gather(*(check_endpoint_health(url) for url in OLLAMA_ENDPOINTS))
    checked_at = datetime.now()

    results = []
    for url, (is_healthy, models) in zip(OLLAMA_ENDPOINTS, checks):
        record = endpoint_health[url]
        record.healthy = is_healthy
        record.models = models
        record.last_check = checked_at

        results.append({
            "url": url,
            "healthy": is_healthy,
            "models": models,
            "failure_count": record.failure_count
        })

    return {
        "strategy": LOAD_BALANCE_STRATEGY,
        "endpoints": results,
        "healthy_count": sum(1 for r in results if r["healthy"]),
        "total_count": len(results)
    }
|
||||
|
||||
|
||||
@app.post("/chat", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest):
|
||||
"""Route chat request to specified LLM provider"""
|
||||
@@ -174,11 +280,16 @@ async def _call_anthropic(request: ChatRequest) -> ChatResponse:
|
||||
|
||||
|
||||
async def _call_ollama(request: ChatRequest) -> ChatResponse:
|
||||
"""Call Ollama API (local models)"""
|
||||
"""Call Ollama API with load balancing across endpoints"""
|
||||
endpoint = await get_healthy_endpoint()
|
||||
|
||||
if not endpoint:
|
||||
raise HTTPException(status_code=503, detail="No healthy Ollama endpoints available")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{OLLAMA_BASE_URL}/api/chat",
|
||||
f"{endpoint}/api/chat",
|
||||
json={
|
||||
"model": request.model,
|
||||
"messages": [m.model_dump() for m in request.messages],
|
||||
@@ -192,8 +303,15 @@ async def _call_ollama(request: ChatRequest) -> ChatResponse:
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
# Mark endpoint as failed
|
||||
endpoint_health[endpoint].failure_count += 1
|
||||
if endpoint_health[endpoint].failure_count >= 3:
|
||||
endpoint_health[endpoint].healthy = False
|
||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
|
||||
# Reset failure count on success
|
||||
endpoint_health[endpoint].failure_count = 0
|
||||
|
||||
data = response.json()
|
||||
return ChatResponse(
|
||||
provider="ollama",
|
||||
@@ -201,11 +319,22 @@ async def _call_ollama(request: ChatRequest) -> ChatResponse:
|
||||
content=data["message"]["content"],
|
||||
usage={
|
||||
"prompt_tokens": data.get("prompt_eval_count", 0),
|
||||
"completion_tokens": data.get("eval_count", 0)
|
||||
"completion_tokens": data.get("eval_count", 0),
|
||||
"endpoint": endpoint
|
||||
}
|
||||
)
|
||||
except httpx.ConnectError:
|
||||
raise HTTPException(status_code=503, detail="Ollama service not available")
|
||||
# Mark endpoint as unhealthy
|
||||
endpoint_health[endpoint].healthy = False
|
||||
endpoint_health[endpoint].failure_count += 1
|
||||
|
||||
# Try another endpoint if available
|
||||
other_endpoint = await get_healthy_endpoint()
|
||||
if other_endpoint and other_endpoint != endpoint:
|
||||
# Recursive call will use different endpoint
|
||||
return await _call_ollama(request)
|
||||
|
||||
raise HTTPException(status_code=503, detail="All Ollama endpoints unavailable")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user