mirror of
https://github.com/mblanke/ThreatHunt.git
synced 2026-03-01 14:00:20 -05:00
- NetworkMap: hunt-scoped force-directed graph with click-to-inspect popover - NetworkMap: zoom/pan (wheel, drag, buttons), viewport transform - NetworkMap: clickable IP/Host/Domain/URL legend chips to filter node types - NetworkMap: brighter colors, 20% smaller nodes - DatasetViewer: IOC columns highlighted with colored headers + cell tinting - AUPScanner: hunt dropdown replacing dataset checkboxes, auto-select all - Rename 'Social Media (Personal)' theme to 'Social Media' with DB migration - Fix /api/hunts timeout: Dataset.rows lazy='noload' (was selectin cascade) - Add OS column mapping to normalizer - Full backend services, DB models, alembic migrations, new routes - New components: Dashboard, HuntManager, FileUpload, NetworkMap, etc. - Docker Compose deployment with nginx reverse proxy
198 lines
5.4 KiB
Python
198 lines
5.4 KiB
Python
"""API routes for authentication — register, login, refresh, profile."""
|
|
|
|
import logging
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from pydantic import BaseModel, Field, EmailStr
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db import get_db
|
|
from app.db.models import User
|
|
from app.services.auth import (
|
|
hash_password,
|
|
verify_password,
|
|
create_token_pair,
|
|
decode_token,
|
|
get_current_user,
|
|
TokenPair,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
|
|
|
|
|
# ── Request / Response models ─────────────────────────────────────────
|
|
|
|
|
|
class RegisterRequest(BaseModel):
|
|
username: str = Field(..., min_length=3, max_length=64)
|
|
email: str = Field(..., max_length=256)
|
|
password: str = Field(..., min_length=8, max_length=128)
|
|
display_name: str | None = None
|
|
|
|
|
|
class LoginRequest(BaseModel):
|
|
username: str
|
|
password: str
|
|
|
|
|
|
class RefreshRequest(BaseModel):
|
|
refresh_token: str
|
|
|
|
|
|
class UserResponse(BaseModel):
|
|
id: str
|
|
username: str
|
|
email: str
|
|
display_name: str | None
|
|
role: str
|
|
is_active: bool
|
|
created_at: str
|
|
|
|
|
|
class AuthResponse(BaseModel):
|
|
user: UserResponse
|
|
tokens: TokenPair
|
|
|
|
|
|
# ── Routes ────────────────────────────────────────────────────────────
|
|
|
|
|
|
@router.post(
|
|
"/register",
|
|
response_model=AuthResponse,
|
|
status_code=status.HTTP_201_CREATED,
|
|
summary="Register a new user",
|
|
)
|
|
async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
|
# Check for existing username
|
|
result = await db.execute(select(User).where(User.username == body.username))
|
|
if result.scalar_one_or_none():
|
|
raise HTTPException(
|
|
status_code=status.HTTP_409_CONFLICT,
|
|
detail="Username already taken",
|
|
)
|
|
|
|
# Check for existing email
|
|
result = await db.execute(select(User).where(User.email == body.email))
|
|
if result.scalar_one_or_none():
|
|
raise HTTPException(
|
|
status_code=status.HTTP_409_CONFLICT,
|
|
detail="Email already registered",
|
|
)
|
|
|
|
user = User(
|
|
username=body.username,
|
|
email=body.email,
|
|
password_hash=hash_password(body.password),
|
|
display_name=body.display_name or body.username,
|
|
role="analyst", # Default role
|
|
)
|
|
db.add(user)
|
|
await db.flush()
|
|
|
|
tokens = create_token_pair(user.id, user.role)
|
|
|
|
logger.info(f"New user registered: {user.username} ({user.id})")
|
|
|
|
return AuthResponse(
|
|
user=UserResponse(
|
|
id=user.id,
|
|
username=user.username,
|
|
email=user.email,
|
|
display_name=user.display_name,
|
|
role=user.role,
|
|
is_active=user.is_active,
|
|
created_at=user.created_at.isoformat(),
|
|
),
|
|
tokens=tokens,
|
|
)
|
|
|
|
|
|
@router.post(
|
|
"/login",
|
|
response_model=AuthResponse,
|
|
summary="Login with username and password",
|
|
)
|
|
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)):
|
|
result = await db.execute(select(User).where(User.username == body.username))
|
|
user = result.scalar_one_or_none()
|
|
|
|
if not user or not user.password_hash:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Invalid username or password",
|
|
)
|
|
|
|
if not verify_password(body.password, user.password_hash):
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Invalid username or password",
|
|
)
|
|
|
|
if not user.is_active:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_403_FORBIDDEN,
|
|
detail="Account is disabled",
|
|
)
|
|
|
|
tokens = create_token_pair(user.id, user.role)
|
|
|
|
return AuthResponse(
|
|
user=UserResponse(
|
|
id=user.id,
|
|
username=user.username,
|
|
email=user.email,
|
|
display_name=user.display_name,
|
|
role=user.role,
|
|
is_active=user.is_active,
|
|
created_at=user.created_at.isoformat(),
|
|
),
|
|
tokens=tokens,
|
|
)
|
|
|
|
|
|
@router.post(
|
|
"/refresh",
|
|
response_model=TokenPair,
|
|
summary="Refresh access token",
|
|
)
|
|
async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db)):
|
|
token_data = decode_token(body.refresh_token)
|
|
|
|
if token_data.type != "refresh":
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Invalid token type — use refresh token",
|
|
)
|
|
|
|
result = await db.execute(select(User).where(User.id == token_data.sub))
|
|
user = result.scalar_one_or_none()
|
|
|
|
if not user or not user.is_active:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
detail="Invalid user",
|
|
)
|
|
|
|
return create_token_pair(user.id, user.role)
|
|
|
|
|
|
@router.get(
|
|
"/me",
|
|
response_model=UserResponse,
|
|
summary="Get current user profile",
|
|
)
|
|
async def get_profile(user: User = Depends(get_current_user)):
|
|
return UserResponse(
|
|
id=user.id,
|
|
username=user.username,
|
|
email=user.email,
|
|
display_name=user.display_name,
|
|
role=user.role,
|
|
is_active=user.is_active,
|
|
created_at=user.created_at.isoformat() if hasattr(user.created_at, 'isoformat') else str(user.created_at),
|
|
)
|