"""API routes for authentication — register, login, refresh, profile.""" import logging from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel, Field, EmailStr from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.db import get_db from app.db.models import User from app.services.auth import ( hash_password, verify_password, create_token_pair, decode_token, get_current_user, TokenPair, ) logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/auth", tags=["auth"]) # ── Request / Response models ───────────────────────────────────────── class RegisterRequest(BaseModel): username: str = Field(..., min_length=3, max_length=64) email: str = Field(..., max_length=256) password: str = Field(..., min_length=8, max_length=128) display_name: str | None = None class LoginRequest(BaseModel): username: str password: str class RefreshRequest(BaseModel): refresh_token: str class UserResponse(BaseModel): id: str username: str email: str display_name: str | None role: str is_active: bool created_at: str class AuthResponse(BaseModel): user: UserResponse tokens: TokenPair # ── Routes ──────────────────────────────────────────────────────────── @router.post( "/register", response_model=AuthResponse, status_code=status.HTTP_201_CREATED, summary="Register a new user", ) async def register(body: RegisterRequest, db: AsyncSession = Depends(get_db)): # Check for existing username result = await db.execute(select(User).where(User.username == body.username)) if result.scalar_one_or_none(): raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail="Username already taken", ) # Check for existing email result = await db.execute(select(User).where(User.email == body.email)) if result.scalar_one_or_none(): raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail="Email already registered", ) user = User( username=body.username, email=body.email, hashed_password=hash_password(body.password), display_name=body.display_name or body.username, role="analyst", # Default role ) db.add(user) await db.flush() tokens = create_token_pair(user.id, user.role) logger.info(f"New user registered: {user.username} ({user.id})") return AuthResponse( user=UserResponse( id=user.id, username=user.username, email=user.email, display_name=user.display_name, role=user.role, is_active=user.is_active, created_at=user.created_at.isoformat(), ), tokens=tokens, ) @router.post( "/login", response_model=AuthResponse, summary="Login with username and password", ) async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)): result = await db.execute(select(User).where(User.username == body.username)) user = result.scalar_one_or_none() if not user or not user.hashed_password: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password", ) if not verify_password(body.password, user.hashed_password): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password", ) if not user.is_active: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Account is disabled", ) tokens = create_token_pair(user.id, user.role) return AuthResponse( user=UserResponse( id=user.id, username=user.username, email=user.email, display_name=user.display_name, role=user.role, is_active=user.is_active, created_at=user.created_at.isoformat(), ), tokens=tokens, ) @router.post( "/refresh", response_model=TokenPair, summary="Refresh access token", ) async def refresh_token(body: RefreshRequest, db: AsyncSession = Depends(get_db)): token_data = decode_token(body.refresh_token) if token_data.type != "refresh": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type — use refresh token", ) result = await db.execute(select(User).where(User.id == token_data.sub)) user = result.scalar_one_or_none() if not user or not user.is_active: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid user", ) return create_token_pair(user.id, user.role) @router.get( "/me", response_model=UserResponse, summary="Get current user profile", ) async def get_profile(user: User = Depends(get_current_user)): return UserResponse( id=user.id, username=user.username, email=user.email, display_name=user.display_name, role=user.role, is_active=user.is_active, created_at=user.created_at.isoformat() if hasattr(user.created_at, 'isoformat') else str(user.created_at), )