Mirror of https://github.com/mblanke/ThreatHunt.git (synced 2026-03-01 14:00:20 -05:00)
Implement Phase 5: Distributed LLM Routing Architecture
Co-authored-by: mblanke <9078342+mblanke@users.noreply.github.com>
backend/app/api/routes/llm.py (new file, 250 lines)
@@ -0,0 +1,250 @@
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from pydantic import BaseModel

from app.core.database import get_db
from app.core.deps import get_current_active_user, require_role
from app.core.llm_router import get_llm_router, TaskType
from app.core.job_scheduler import get_job_scheduler, Job, NodeStatus
from app.core.llm_pool import get_llm_pool
from app.core.merger_agent import get_merger_agent, MergeStrategy
from app.models.user import User

router = APIRouter()


class LLMRequest(BaseModel):
    """Request for LLM processing"""
    prompt: str
    task_hints: Optional[List[str]] = []
    requires_parallel: bool = False
    requires_chaining: bool = False
    batch_size: int = 1
    operations: Optional[List[str]] = []
    parameters: Optional[Dict[str, Any]] = None


class LLMResponse(BaseModel):
    """Response from LLM processing"""
    job_id: str
    result: Any
    execution_mode: str
    models_used: List[str]
    strategy: Optional[str] = None


class NodeStatusUpdate(BaseModel):
    """Update node status"""
    node_id: str
    vram_used_gb: Optional[int] = None
    compute_utilization: Optional[float] = None
    status: Optional[str] = None


@router.post("/process", response_model=Dict[str, Any])
async def process_llm_request(
    request: LLMRequest,
    current_user: User = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """
    Process an LLM request through the distributed routing system

    The request flows through:
    1. Router Agent - classifies and routes to appropriate model
    2. Job Scheduler - determines execution strategy
    3. LLM Pool - executes on appropriate endpoints
    4. Merger Agent - combines results if multiple models used
    """
    # Step 1: Route the request
    router_agent = get_llm_router()
    routing_decision = router_agent.route_request(request.dict())

    # Step 2: Schedule the job
    scheduler = get_job_scheduler()
    job = Job(
        job_id=f"job_{current_user.id}_{hash(request.prompt) % 10000}",
        model=routing_decision["model"],
        priority=routing_decision["priority"],
        estimated_vram_gb=10,  # Estimate based on model
        requires_parallel=request.requires_parallel,
        requires_chaining=request.requires_chaining,
        payload=request.dict()
    )

    scheduling_decision = await scheduler.schedule_job(job)

    # Step 3: Execute on LLM pool
    pool = get_llm_pool()

    if scheduling_decision["execution_mode"] == "parallel":
        # Execute on multiple nodes
        model_names = [routing_decision["model"]] * len(scheduling_decision["nodes"])
        results = await pool.call_multiple_models(
            model_names,
            request.prompt,
            request.parameters
        )

        # Step 4: Merge results
        merger = get_merger_agent()
        final_result = merger.merge_results(
            results["results"],
            strategy=MergeStrategy.CONSENSUS
        )

        return {
            "job_id": job.job_id,
            "status": "completed",
            "routing": routing_decision,
            "scheduling": scheduling_decision,
            "result": final_result,
            "execution_mode": "parallel"
        }

    elif scheduling_decision["execution_mode"] == "queued":
        return {
            "job_id": job.job_id,
            "status": "queued",
            "queue_position": scheduling_decision["queue_position"],
            "message": "Job queued - no nodes available"
        }

    else:
        # Single node execution
        result = await pool.call_model(
            routing_decision["model"],
            request.prompt,
            request.parameters
        )

        return {
            "job_id": job.job_id,
            "status": "completed",
            "routing": routing_decision,
            "scheduling": scheduling_decision,
            "result": result,
            "execution_mode": scheduling_decision["execution_mode"]
        }


@router.get("/models")
async def list_available_models(
    current_user: User = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """
    List all available LLM models in the pool
    """
    pool = get_llm_pool()
    models = pool.list_available_models()

    return {
        "models": models,
        "total": len(models)
    }


@router.get("/nodes")
async def list_gpu_nodes(
    current_user: User = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """
    List all GPU nodes and their status
    """
    scheduler = get_job_scheduler()
    nodes = scheduler.get_available_nodes()

    return {
        "nodes": [
            {
                "node_id": node.node_id,
                "hostname": node.hostname,
                "vram_total_gb": node.vram_total_gb,
                "vram_used_gb": node.vram_used_gb,
                "vram_available_gb": node.vram_available_gb,
                "compute_utilization": node.compute_utilization,
                "status": node.status.value,
                "models_loaded": node.models_loaded
            }
            for node in scheduler.nodes.values()
        ],
        "available_count": len(nodes)
    }


@router.post("/nodes/status")
async def update_node_status(
    update: NodeStatusUpdate,
    current_user: User = Depends(require_role(["admin"])),
    db: Session = Depends(get_db)
):
    """
    Update GPU node status (admin only)
    """
    scheduler = get_job_scheduler()

    status_enum = None
    if update.status:
        try:
            status_enum = NodeStatus[update.status.upper()]
        except KeyError:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid status: {update.status}"
            )

    scheduler.update_node_status(
        update.node_id,
        vram_used_gb=update.vram_used_gb,
        compute_utilization=update.compute_utilization,
        status=status_enum
    )

    return {"message": "Node status updated"}


@router.get("/routing/rules")
async def get_routing_rules(
    current_user: User = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """
    Get current routing rules for task classification
    """
    router_agent = get_llm_router()

    return {
        "routing_rules": {
            task_type.value: {
                "model": rule["model"],
                "endpoint": rule["endpoint"],
                "priority": rule["priority"],
                "description": rule["description"]
            }
            for task_type, rule in router_agent.routing_rules.items()
        }
    }


@router.post("/test-classification")
async def test_classification(
    request: LLMRequest,
    current_user: User = Depends(get_current_active_user),
    db: Session = Depends(get_db)
):
    """
    Test task classification without executing
    """
    router_agent = get_llm_router()
    task_type = router_agent.classify_request(request.dict())
    routing_decision = router_agent.route_request(request.dict())

    return {
        "task_type": task_type.value,
        "routing_decision": routing_decision,
        "should_parallelize": router_agent.should_parallelize(request.dict()),
        "requires_chaining": router_agent.requires_serial_chaining(request.dict())
    }
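For reference, a minimal client sketch against the new /api/llm/process route. The base URL, port, and bearer token below are assumptions for a local deployment, not part of this commit; the request body simply mirrors the LLMRequest model above.

# Hypothetical client sketch for POST /api/llm/process (base URL and token are assumed placeholders).
import httpx

API_BASE = "http://localhost:8000"   # assumed local deployment
TOKEN = "<bearer-token>"             # placeholder, obtain via the existing auth routes

payload = {
    "prompt": "Summarize the suspicious PowerShell activity on host WS-042.",
    "task_hints": ["threat", "attack"],   # steers the router toward adversarial reasoning
    "requires_parallel": False,
    "batch_size": 1,
    "parameters": {"temperature": 0.2, "max_tokens": 1024},
}

resp = httpx.post(
    f"{API_BASE}/api/llm/process",
    json=payload,
    headers={"Authorization": f"Bearer {TOKEN}"},
    timeout=120.0,
)
resp.raise_for_status()
body = resp.json()
print(body["job_id"], body["status"], body["execution_mode"])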
backend/app/core/job_scheduler.py (new file, 263 lines)
@@ -0,0 +1,263 @@
"""
Job Scheduler

Manages job distribution across GPU nodes based on availability and load.
"""

from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from enum import Enum
import asyncio


class NodeStatus(Enum):
    """Status of GPU node"""
    AVAILABLE = "available"
    BUSY = "busy"
    OFFLINE = "offline"


@dataclass
class GPUNode:
    """Represents a GPU compute node"""
    node_id: str
    hostname: str
    port: int
    vram_total_gb: int
    vram_used_gb: int
    compute_utilization: float  # 0.0 to 1.0
    status: NodeStatus
    models_loaded: List[str]

    @property
    def vram_available_gb(self) -> int:
        """Calculate available VRAM"""
        return self.vram_total_gb - self.vram_used_gb

    @property
    def is_available(self) -> bool:
        """Check if node is available for work"""
        return self.status == NodeStatus.AVAILABLE and self.compute_utilization < 0.9


@dataclass
class Job:
    """Represents an LLM job"""
    job_id: str
    model: str
    priority: int
    estimated_vram_gb: int
    requires_parallel: bool
    requires_chaining: bool
    payload: Dict[str, Any]


class JobScheduler:
    """
    Job Scheduler - Manages distribution of LLM jobs across GPU nodes

    Decides:
    - Which GB10 device is available
    - GPU load (VRAM, compute utilization)
    - Whether to parallelize across both nodes
    - Whether job requires serial reasoning (chained)
    """

    def __init__(self):
        """Initialize job scheduler"""
        self.nodes: Dict[str, GPUNode] = {}
        self.job_queue: List[Job] = []
        self._initialize_nodes()

    def _initialize_nodes(self):
        """Initialize GPU node configuration"""
        # GB10 Node 1
        self.nodes["gb10-node-1"] = GPUNode(
            node_id="gb10-node-1",
            hostname="gb10-node-1",
            port=8001,
            vram_total_gb=80,
            vram_used_gb=0,
            compute_utilization=0.0,
            status=NodeStatus.AVAILABLE,
            models_loaded=["deepseek", "qwen72"]
        )

        # GB10 Node 2
        self.nodes["gb10-node-2"] = GPUNode(
            node_id="gb10-node-2",
            hostname="gb10-node-2",
            port=8001,
            vram_total_gb=80,
            vram_used_gb=0,
            compute_utilization=0.0,
            status=NodeStatus.AVAILABLE,
            models_loaded=["phi4", "qwen-coder", "llama31", "granite-guardian"]
        )

    def get_available_nodes(self) -> List[GPUNode]:
        """Get list of available nodes"""
        return [node for node in self.nodes.values() if node.is_available]

    def find_best_node(self, job: Job) -> Optional[GPUNode]:
        """
        Find best node for a job based on availability and requirements

        Args:
            job: Job to schedule

        Returns:
            Best GPU node or None if unavailable
        """
        available_nodes = self.get_available_nodes()

        # Filter nodes that have required model loaded
        suitable_nodes = [
            node for node in available_nodes
            if job.model in node.models_loaded
            and node.vram_available_gb >= job.estimated_vram_gb
        ]

        if not suitable_nodes:
            return None

        # Sort by compute utilization (prefer less loaded nodes)
        suitable_nodes.sort(key=lambda n: n.compute_utilization)

        return suitable_nodes[0]

    def should_parallelize(self, job: Job) -> bool:
        """
        Determine if job should be parallelized across multiple nodes

        Args:
            job: Job to evaluate

        Returns:
            True if should parallelize
        """
        available_nodes = self.get_available_nodes()

        # Need at least 2 nodes for parallelization
        if len(available_nodes) < 2:
            return False

        # Job explicitly requires parallel execution
        if job.requires_parallel:
            return True

        # High priority jobs with multiple available nodes
        if job.priority >= 1 and len(available_nodes) >= 2:
            return True

        return False

    def get_parallel_nodes(self, job: Job) -> List[GPUNode]:
        """
        Get nodes for parallel execution

        Args:
            job: Job to parallelize

        Returns:
            List of nodes to use
        """
        available_nodes = self.get_available_nodes()

        # Filter nodes with required model and sufficient VRAM
        suitable_nodes = [
            node for node in available_nodes
            if job.model in node.models_loaded
            and node.vram_available_gb >= job.estimated_vram_gb
        ]

        # Return up to 2 nodes for parallel execution
        return suitable_nodes[:2]

    async def schedule_job(self, job: Job) -> Dict[str, Any]:
        """
        Schedule a job for execution

        Args:
            job: Job to schedule

        Returns:
            Scheduling decision with node assignments
        """
        # Check if job should be parallelized
        if self.should_parallelize(job):
            nodes = self.get_parallel_nodes(job)
            if len(nodes) >= 2:
                return {
                    "job_id": job.job_id,
                    "execution_mode": "parallel",
                    "nodes": [
                        {"node_id": node.node_id, "endpoint": f"http://{node.hostname}:{node.port}/{job.model}"}
                        for node in nodes
                    ],
                    "estimated_time": "distributed"
                }

        # Serial execution on single node
        node = self.find_best_node(job)

        if not node:
            # Add to queue if no nodes available
            self.job_queue.append(job)
            return {
                "job_id": job.job_id,
                "execution_mode": "queued",
                "status": "waiting_for_resources",
                "queue_position": len(self.job_queue)
            }

        return {
            "job_id": job.job_id,
            "execution_mode": "serial" if job.requires_chaining else "single",
            "node": {
                "node_id": node.node_id,
                "endpoint": f"http://{node.hostname}:{node.port}/{job.model}"
            },
            "vram_allocated_gb": job.estimated_vram_gb,
            "estimated_time": "standard"
        }

    def update_node_status(
        self,
        node_id: str,
        vram_used_gb: Optional[int] = None,
        compute_utilization: Optional[float] = None,
        status: Optional[NodeStatus] = None
    ):
        """
        Update node status metrics

        Args:
            node_id: Node to update
            vram_used_gb: Current VRAM usage
            compute_utilization: Current compute utilization (0.0-1.0)
            status: Node status
        """
        if node_id not in self.nodes:
            return

        node = self.nodes[node_id]

        if vram_used_gb is not None:
            node.vram_used_gb = vram_used_gb

        if compute_utilization is not None:
            node.compute_utilization = compute_utilization

        if status is not None:
            node.status = status


def get_job_scheduler() -> JobScheduler:
    """
    Factory function to create job scheduler

    Returns:
        Configured JobScheduler instance
    """
    return JobScheduler()
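To illustrate the decision shapes this scheduler produces, a small standalone sketch that drives schedule_job directly; the job values are invented for demonstration and the outcomes follow from the two-node configuration above.

# Standalone sketch of the scheduler's decision shapes; the job values are illustrative only.
import asyncio

from app.core.job_scheduler import Job, JobScheduler, NodeStatus


async def main() -> None:
    scheduler = JobScheduler()

    job = Job(
        job_id="demo-001",
        model="phi4",
        priority=3,
        estimated_vram_gb=10,
        requires_parallel=False,
        requires_chaining=False,
        payload={"prompt": "parse this log"},
    )

    # Only gb10-node-2 has phi4 loaded, so this resolves to a single-node assignment there.
    print(await scheduler.schedule_job(job))

    # Mark that node offline and the same job gets queued ("waiting_for_resources").
    scheduler.update_node_status("gb10-node-2", status=NodeStatus.OFFLINE)
    print(await scheduler.schedule_job(job))


if __name__ == "__main__":
    asyncio.run(main())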
backend/app/core/llm_pool.py (new file, 211 lines)
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
LLM Pool Manager
|
||||
|
||||
Manages pool of LLM endpoints with OpenAI-compatible interface.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
import httpx
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMEndpoint:
|
||||
"""Represents an LLM endpoint"""
|
||||
model_name: str
|
||||
node_id: str
|
||||
base_url: str
|
||||
is_available: bool = True
|
||||
|
||||
@property
|
||||
def endpoint_url(self) -> str:
|
||||
"""Get full endpoint URL"""
|
||||
return f"{self.base_url}/{self.model_name}"
|
||||
|
||||
|
||||
class LLMPoolManager:
|
||||
"""
|
||||
Pool of LLM Endpoints
|
||||
|
||||
Each model is exposed via an OpenAI-compatible endpoint:
|
||||
- http://gb10-node-1:8001/deepseek
|
||||
- http://gb10-node-1:8001/qwen72
|
||||
- http://gb10-node-2:8001/phi4
|
||||
- http://gb10-node-2:8001/qwen-coder
|
||||
- http://gb10-node-2:8001/llama31
|
||||
- http://gb10-node-2:8001/granite-guardian
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize LLM pool"""
|
||||
self.endpoints: Dict[str, LLMEndpoint] = {}
|
||||
self._initialize_endpoints()
|
||||
|
||||
def _initialize_endpoints(self):
|
||||
"""Initialize all LLM endpoints"""
|
||||
# GB10 Node 1 endpoints
|
||||
self.endpoints["deepseek"] = LLMEndpoint(
|
||||
model_name="deepseek",
|
||||
node_id="gb10-node-1",
|
||||
base_url="http://gb10-node-1:8001"
|
||||
)
|
||||
|
||||
self.endpoints["qwen72"] = LLMEndpoint(
|
||||
model_name="qwen72",
|
||||
node_id="gb10-node-1",
|
||||
base_url="http://gb10-node-1:8001"
|
||||
)
|
||||
|
||||
# GB10 Node 2 endpoints
|
||||
self.endpoints["phi4"] = LLMEndpoint(
|
||||
model_name="phi4",
|
||||
node_id="gb10-node-2",
|
||||
base_url="http://gb10-node-2:8001"
|
||||
)
|
||||
|
||||
self.endpoints["qwen-coder"] = LLMEndpoint(
|
||||
model_name="qwen-coder",
|
||||
node_id="gb10-node-2",
|
||||
base_url="http://gb10-node-2:8001"
|
||||
)
|
||||
|
||||
self.endpoints["llama31"] = LLMEndpoint(
|
||||
model_name="llama31",
|
||||
node_id="gb10-node-2",
|
||||
base_url="http://gb10-node-2:8001"
|
||||
)
|
||||
|
||||
self.endpoints["granite-guardian"] = LLMEndpoint(
|
||||
model_name="granite-guardian",
|
||||
node_id="gb10-node-2",
|
||||
base_url="http://gb10-node-2:8001"
|
||||
)
|
||||
|
||||
def get_endpoint(self, model_name: str) -> Optional[LLMEndpoint]:
|
||||
"""
|
||||
Get endpoint for a specific model
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
|
||||
Returns:
|
||||
LLMEndpoint or None if not found
|
||||
"""
|
||||
return self.endpoints.get(model_name)
|
||||
|
||||
async def call_model(
|
||||
self,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
parameters: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Call an LLM model via its endpoint
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
prompt: Input prompt
|
||||
parameters: Optional model parameters
|
||||
|
||||
Returns:
|
||||
Model response
|
||||
"""
|
||||
endpoint = self.get_endpoint(model_name)
|
||||
|
||||
if not endpoint:
|
||||
return {
|
||||
"error": f"Model {model_name} not found",
|
||||
"available_models": list(self.endpoints.keys())
|
||||
}
|
||||
|
||||
if not endpoint.is_available:
|
||||
return {
|
||||
"error": f"Endpoint {model_name} is currently unavailable",
|
||||
"status": "offline"
|
||||
}
|
||||
|
||||
# Prepare OpenAI-compatible request
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"temperature": parameters.get("temperature", 0.7) if parameters else 0.7,
|
||||
"max_tokens": parameters.get("max_tokens", 2048) if parameters else 2048
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.post(
|
||||
f"{endpoint.endpoint_url}/v1/chat/completions",
|
||||
json=payload
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPError as e:
|
||||
return {
|
||||
"error": f"Failed to call {model_name}",
|
||||
"details": str(e),
|
||||
"endpoint": endpoint.endpoint_url
|
||||
}
|
||||
|
||||
async def call_multiple_models(
|
||||
self,
|
||||
model_names: List[str],
|
||||
prompt: str,
|
||||
parameters: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Call multiple models in parallel
|
||||
|
||||
Args:
|
||||
model_names: List of model names
|
||||
prompt: Input prompt
|
||||
parameters: Optional model parameters
|
||||
|
||||
Returns:
|
||||
Combined results from all models
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
tasks = [
|
||||
self.call_model(model, prompt, parameters)
|
||||
for model in model_names
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
return {
|
||||
"models": model_names,
|
||||
"results": [
|
||||
{"model": model, "response": result}
|
||||
for model, result in zip(model_names, results)
|
||||
]
|
||||
}
|
||||
|
||||
def list_available_models(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
List all available models
|
||||
|
||||
Returns:
|
||||
List of model information
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"model_name": endpoint.model_name,
|
||||
"node_id": endpoint.node_id,
|
||||
"endpoint_url": endpoint.endpoint_url,
|
||||
"is_available": endpoint.is_available
|
||||
}
|
||||
for endpoint in self.endpoints.values()
|
||||
]
|
||||
|
||||
|
||||
def get_llm_pool() -> LLMPoolManager:
|
||||
"""
|
||||
Factory function to create LLM pool manager
|
||||
|
||||
Returns:
|
||||
Configured LLMPoolManager instance
|
||||
"""
|
||||
return LLMPoolManager()
|
||||
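A short sketch of driving the pool directly, for single calls and fan-out. It assumes the gb10-node-* endpoints are reachable, which is only true inside the deployment this commit targets; the prompts and parameters are illustrative.

# Sketch of calling the pool directly; assumes the gb10-node-* OpenAI-compatible servers are up.
import asyncio

from app.core.llm_pool import get_llm_pool


async def main() -> None:
    pool = get_llm_pool()

    # Single call: POSTs an OpenAI-style chat payload to
    # http://gb10-node-2:8001/phi4/v1/chat/completions
    single = await pool.call_model(
        "phi4",
        "Extract the IOCs from this Sysmon event as JSON.",
        {"temperature": 0.1, "max_tokens": 512},
    )
    print(single)

    # Fan-out: the same prompt to two models concurrently via asyncio.gather.
    fanout = await pool.call_multiple_models(["deepseek", "llama31"], "Is this beaconing?")
    for entry in fanout["results"]:
        print(entry["model"], entry["response"])


if __name__ == "__main__":
    asyncio.run(main())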
backend/app/core/llm_router.py (new file, 187 lines)
@@ -0,0 +1,187 @@
"""
LLM Router Agent

Routes requests to appropriate LLM models based on task classification.
"""

from typing import Dict, Any, Optional, List
from enum import Enum
import httpx


class TaskType(Enum):
    """Types of tasks for LLM routing"""
    GENERAL_REASONING = "general_reasoning"          # DeepSeek
    MULTILINGUAL = "multilingual"                    # Qwen / Aya
    STRUCTURED_PARSING = "structured_parsing"        # Phi-4
    RULE_GENERATION = "rule_generation"              # Qwen-Coder
    ADVERSARIAL_REASONING = "adversarial_reasoning"  # LLaMA 3.1
    CLASSIFICATION = "classification"                # Granite Guardian


class LLMRouterAgent:
    """
    Router Agent - Interprets incoming requests and routes to appropriate LLM

    This agent classifies the incoming request and determines which specialized
    LLM should handle it based on the task type.
    """

    def __init__(self, policy_config: Optional[Dict[str, Any]] = None):
        """
        Initialize router agent

        Args:
            policy_config: Optional routing policy configuration
        """
        self.policy_config = policy_config or {}
        self.routing_rules = self._initialize_routing_rules()

    def _initialize_routing_rules(self) -> Dict[TaskType, Dict[str, Any]]:
        """Initialize routing rules for each task type"""
        return {
            TaskType.GENERAL_REASONING: {
                "model": "deepseek",
                "endpoint": "deepseek",
                "priority": 1,
                "description": "General reasoning and complex analysis"
            },
            TaskType.MULTILINGUAL: {
                "model": "qwen72",
                "endpoint": "qwen72",
                "priority": 2,
                "description": "Multilingual translation and analysis"
            },
            TaskType.STRUCTURED_PARSING: {
                "model": "phi4",
                "endpoint": "phi4",
                "priority": 3,
                "description": "Structured data parsing and extraction"
            },
            TaskType.RULE_GENERATION: {
                "model": "qwen-coder",
                "endpoint": "qwen-coder",
                "priority": 2,
                "description": "Code and rule generation"
            },
            TaskType.ADVERSARIAL_REASONING: {
                "model": "llama31",
                "endpoint": "llama31",
                "priority": 1,
                "description": "Adversarial threat analysis"
            },
            TaskType.CLASSIFICATION: {
                "model": "granite-guardian",
                "endpoint": "granite-guardian",
                "priority": 4,
                "description": "Pure classification tasks"
            }
        }

    def classify_request(self, request: Dict[str, Any]) -> TaskType:
        """
        Classify incoming request to determine task type

        Args:
            request: Request containing prompt and metadata

        Returns:
            Classified task type
        """
        prompt = request.get("prompt", "").lower()
        task_hints = request.get("task_hints", [])

        # Classification logic based on keywords and hints
        if any(hint in task_hints for hint in ["translate", "multilingual", "language"]):
            return TaskType.MULTILINGUAL

        if any(hint in task_hints for hint in ["parse", "extract", "structure"]):
            return TaskType.STRUCTURED_PARSING

        if any(hint in task_hints for hint in ["code", "rule", "generate", "script"]):
            return TaskType.RULE_GENERATION

        if any(hint in task_hints for hint in ["threat", "adversary", "attack", "malicious"]):
            return TaskType.ADVERSARIAL_REASONING

        if any(hint in task_hints for hint in ["classify", "categorize", "label"]):
            return TaskType.CLASSIFICATION

        # Default to general reasoning
        return TaskType.GENERAL_REASONING

    def route_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """
        Route request to appropriate LLM endpoint

        Args:
            request: Request to route

        Returns:
            Routing decision with endpoint and model info
        """
        task_type = self.classify_request(request)
        routing_rule = self.routing_rules[task_type]

        return {
            "task_type": task_type.value,
            "model": routing_rule["model"],
            "endpoint": routing_rule["endpoint"],
            "priority": routing_rule["priority"],
            "description": routing_rule["description"],
            "requires_parallel": request.get("requires_parallel", False),
            "requires_chaining": request.get("requires_chaining", False)
        }

    def should_parallelize(self, request: Dict[str, Any]) -> bool:
        """
        Determine if request should be parallelized across multiple nodes

        Args:
            request: Request to evaluate

        Returns:
            True if should be parallelized
        """
        # Large batch requests
        if request.get("batch_size", 1) > 10:
            return True

        # Explicit parallel flag
        if request.get("requires_parallel", False):
            return True

        return False

    def requires_serial_chaining(self, request: Dict[str, Any]) -> bool:
        """
        Determine if request requires serial reasoning (chained operations)

        Args:
            request: Request to evaluate

        Returns:
            True if requires chaining
        """
        # Complex multi-step reasoning
        if request.get("requires_chaining", False):
            return True

        # Multiple dependent operations
        if len(request.get("operations", [])) > 1:
            return True

        return False


def get_llm_router(policy_config: Optional[Dict[str, Any]] = None) -> LLMRouterAgent:
    """
    Factory function to create LLM router agent

    Args:
        policy_config: Optional routing policy configuration

    Returns:
        Configured LLMRouterAgent instance
    """
    return LLMRouterAgent(policy_config)
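A quick sketch of the hint-driven classification above; the example requests are invented, but the expected outcomes follow directly from the routing rules.

# Sketch of hint-driven routing; the example requests are illustrative only.
from app.core.llm_router import TaskType, get_llm_router

router = get_llm_router()

# A "rule"/"generate" hint routes to qwen-coder for rule generation.
decision = router.route_request({
    "prompt": "Write a Sigma rule for suspicious schtasks persistence",
    "task_hints": ["rule", "generate"],
})
assert decision["task_type"] == TaskType.RULE_GENERATION.value
print(decision["model"], decision["endpoint"])  # qwen-coder qwen-coder

# No recognized hints: falls through to general reasoning on deepseek.
fallback = router.route_request({"prompt": "Explain this timeline of events"})
print(fallback["task_type"])  # general_reasoning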
backend/app/core/merger_agent.py (new file, 259 lines)
@@ -0,0 +1,259 @@
"""
Merger Agent

Combines and synthesizes results from multiple LLM models.
"""

from typing import Dict, Any, List, Optional
from enum import Enum


class MergeStrategy(Enum):
    """Strategies for merging LLM results"""
    CONSENSUS = "consensus"        # Take majority vote
    WEIGHTED = "weighted"          # Weight by model confidence
    CONCATENATE = "concatenate"    # Combine all outputs
    BEST_QUALITY = "best_quality"  # Select highest quality response
    ENSEMBLE = "ensemble"          # Ensemble multiple results


class MergerAgent:
    """
    Merger Agent - Combines results from multiple LLM executions

    When multiple models process the same or related requests,
    this agent intelligently merges their outputs into a coherent response.
    """

    def __init__(self, default_strategy: MergeStrategy = MergeStrategy.CONSENSUS):
        """
        Initialize merger agent

        Args:
            default_strategy: Default merging strategy
        """
        self.default_strategy = default_strategy

    def merge_results(
        self,
        results: List[Dict[str, Any]],
        strategy: Optional[MergeStrategy] = None
    ) -> Dict[str, Any]:
        """
        Merge multiple LLM results using specified strategy

        Args:
            results: List of results from different models
            strategy: Merging strategy (uses default if not specified)

        Returns:
            Merged result
        """
        if not results:
            return {"error": "No results to merge"}

        if len(results) == 1:
            return results[0]

        merge_strategy = strategy or self.default_strategy

        if merge_strategy == MergeStrategy.CONSENSUS:
            return self._merge_consensus(results)
        elif merge_strategy == MergeStrategy.WEIGHTED:
            return self._merge_weighted(results)
        elif merge_strategy == MergeStrategy.CONCATENATE:
            return self._merge_concatenate(results)
        elif merge_strategy == MergeStrategy.BEST_QUALITY:
            return self._merge_best_quality(results)
        elif merge_strategy == MergeStrategy.ENSEMBLE:
            return self._merge_ensemble(results)
        else:
            return self._merge_consensus(results)

    def _merge_consensus(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Merge by consensus - take majority vote or most common response

        Args:
            results: List of results

        Returns:
            Consensus result
        """
        # For classification tasks, take majority vote
        if all("classification" in r for r in results):
            classifications = [r["classification"] for r in results]
            most_common = max(set(classifications), key=classifications.count)

            return {
                "strategy": "consensus",
                "result": most_common,
                "confidence": classifications.count(most_common) / len(classifications),
                "votes": dict((k, classifications.count(k)) for k in set(classifications))
            }

        # For text generation, use first high-quality result
        valid_results = [r for r in results if "response" in r and r["response"]]
        if valid_results:
            return {
                "strategy": "consensus",
                "result": valid_results[0]["response"],
                "num_models": len(results)
            }

        return {"strategy": "consensus", "result": None}

    def _merge_weighted(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Merge by weighted average based on model confidence

        Args:
            results: List of results with confidence scores

        Returns:
            Weighted result
        """
        # Extract results with confidence scores
        weighted_results = [
            (r.get("response", ""), r.get("confidence", 0.5))
            for r in results
        ]

        # Sort by confidence
        weighted_results.sort(key=lambda x: x[1], reverse=True)

        return {
            "strategy": "weighted",
            "result": weighted_results[0][0],
            "confidence": weighted_results[0][1],
            "all_results": [
                {"response": resp, "confidence": conf}
                for resp, conf in weighted_results
            ]
        }

    def _merge_concatenate(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Concatenate all results

        Args:
            results: List of results

        Returns:
            Concatenated result
        """
        responses = [
            r.get("response", "") for r in results if r.get("response")
        ]

        return {
            "strategy": "concatenate",
            "result": "\n\n---\n\n".join(responses),
            "num_responses": len(responses)
        }

    def _merge_best_quality(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Select the highest quality response

        Args:
            results: List of results

        Returns:
            Best quality result
        """
        # Score responses by length and presence of key indicators
        scored_results = []

        for r in results:
            response = r.get("response", "")
            if not response:
                continue

            score = 0
            score += len(response) / 100     # Longer responses get higher score
            score += response.count(".") * 0.5   # More complete sentences
            score += response.count("\n") * 0.3  # Better formatting

            scored_results.append((response, score, r))

        if not scored_results:
            return {"strategy": "best_quality", "result": None}

        scored_results.sort(key=lambda x: x[1], reverse=True)
        best_response, best_score, best_result = scored_results[0]

        return {
            "strategy": "best_quality",
            "result": best_response,
            "quality_score": best_score,
            "model": best_result.get("model", "unknown")
        }

    def _merge_ensemble(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Ensemble multiple results by combining their insights

        Args:
            results: List of results

        Returns:
            Ensemble result
        """
        # Collect unique insights from all models
        all_responses = [r.get("response", "") for r in results if r.get("response")]

        # Create ensemble summary
        ensemble_summary = {
            "strategy": "ensemble",
            "num_models": len(results),
            "individual_results": [
                {
                    "model": r.get("model", "unknown"),
                    "response": r.get("response", ""),
                    "confidence": r.get("confidence", 0.5)
                }
                for r in results
            ],
            "synthesized_result": self._synthesize_insights(all_responses)
        }

        return ensemble_summary

    def _synthesize_insights(self, responses: List[str]) -> str:
        """
        Synthesize insights from multiple responses

        Args:
            responses: List of response strings

        Returns:
            Synthesized summary
        """
        if not responses:
            return ""

        # Simple synthesis - in production, use another LLM to synthesize
        unique_points = []
        for response in responses:
            sentences = response.split(". ")
            for sentence in sentences:
                if sentence and sentence not in unique_points:
                    unique_points.append(sentence)

        return ". ".join(unique_points[:10])  # Top 10 insights


def get_merger_agent(
    default_strategy: MergeStrategy = MergeStrategy.CONSENSUS
) -> MergerAgent:
    """
    Factory function to create merger agent

    Args:
        default_strategy: Default merging strategy

    Returns:
        Configured MergerAgent instance
    """
    return MergerAgent(default_strategy)
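A small sketch of the consensus strategy on classification outputs; the votes below are made up, and the expected values follow from the majority-vote logic in _merge_consensus.

# Sketch of consensus merging over classification results; the inputs are illustrative only.
from app.core.merger_agent import MergeStrategy, get_merger_agent

merger = get_merger_agent()

results = [
    {"model": "granite-guardian", "classification": "malicious"},
    {"model": "llama31", "classification": "malicious"},
    {"model": "deepseek", "classification": "benign"},
]

merged = merger.merge_results(results, strategy=MergeStrategy.CONSENSUS)
print(merged["result"])      # malicious
print(merged["confidence"])  # 0.666... (2 of 3 votes)
print(merged["votes"])       # {'malicious': 2, 'benign': 1}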
Changed file (FastAPI application module):
@@ -3,14 +3,14 @@ from fastapi.middleware.cors import CORSMiddleware
 
 from app.api.routes import (
     auth, users, tenants, hosts, ingestion, vt, audit,
-    notifications, velociraptor, playbooks, threat_intel, reports
+    notifications, velociraptor, playbooks, threat_intel, reports, llm
 )
 from app.core.config import settings
 
 app = FastAPI(
     title=settings.app_name,
-    description="Multi-tenant threat hunting companion for Velociraptor with ML-powered threat detection",
-    version="1.0.0"
+    description="Multi-tenant threat hunting companion for Velociraptor with ML-powered threat detection and distributed LLM routing",
+    version="1.1.0"
 )
 
 # Configure CORS
@@ -35,6 +35,7 @@ app.include_router(velociraptor.router, prefix="/api/velociraptor", tags=["Velociraptor"])
 app.include_router(playbooks.router, prefix="/api/playbooks", tags=["Playbooks"])
 app.include_router(threat_intel.router, prefix="/api/threat-intel", tags=["Threat Intelligence"])
 app.include_router(reports.router, prefix="/api/reports", tags=["Reports"])
+app.include_router(llm.router, prefix="/api/llm", tags=["Distributed LLM"])
 
 
 @app.get("/")
@@ -42,7 +43,7 @@ async def root():
     """Root endpoint"""
     return {
         "message": f"Welcome to {settings.app_name}",
-        "version": "1.0.0",
+        "version": "1.1.0",
         "docs": "/docs",
         "features": [
             "JWT Authentication with 2FA",
@@ -52,7 +53,8 @@ async def root():
             "Velociraptor integration",
             "ML-powered threat detection",
             "Automated playbooks",
-            "Advanced reporting"
+            "Advanced reporting",
+            "Distributed LLM routing (Phase 5)"
         ]
     }
 
backend/app/schemas/llm.py (new file, 74 lines)
@@ -0,0 +1,74 @@
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
from datetime import datetime


class LLMRequestSchema(BaseModel):
    """Schema for LLM processing request"""
    prompt: str
    task_hints: Optional[List[str]] = []
    requires_parallel: bool = False
    requires_chaining: bool = False
    batch_size: int = 1
    operations: Optional[List[str]] = []
    parameters: Optional[Dict[str, Any]] = None


class RoutingDecision(BaseModel):
    """Schema for routing decision"""
    task_type: str
    model: str
    endpoint: str
    priority: int
    description: str
    requires_parallel: bool
    requires_chaining: bool


class NodeInfo(BaseModel):
    """Schema for GPU node information"""
    node_id: str
    hostname: str
    vram_total_gb: int
    vram_used_gb: int
    vram_available_gb: int
    compute_utilization: float
    status: str
    models_loaded: List[str]


class SchedulingDecision(BaseModel):
    """Schema for job scheduling decision"""
    job_id: str
    execution_mode: str
    nodes: Optional[List[Dict[str, str]]] = None
    node: Optional[Dict[str, str]] = None
    status: Optional[str] = None
    queue_position: Optional[int] = None


class LLMResponseSchema(BaseModel):
    """Schema for LLM response"""
    job_id: str
    status: str
    routing: Optional[RoutingDecision] = None
    scheduling: Optional[SchedulingDecision] = None
    result: Any
    execution_mode: str


class ModelInfo(BaseModel):
    """Schema for model information"""
    model_name: str
    node_id: str
    endpoint_url: str
    is_available: bool


class MergedResult(BaseModel):
    """Schema for merged result"""
    strategy: str
    result: Any
    confidence: Optional[float] = None
    num_models: Optional[int] = None
    all_results: Optional[List[Dict[str, Any]]] = None