""" FastAPI Middleware Stack - Authentication, Authorization, and Security """ from fastapi import FastAPI, Request, HTTPException, status from fastapi.responses import JSONResponse import logging import time from typing import Optional, Callable, Any from datetime import datetime import redis from starlette.middleware.base import BaseHTTPMiddleware from app.security.jwt_manager import jwt_manager, TokenPayload from app.security.hmac_manager import hmac_manager from app.security.rbac import RBACEngine, UserContext, MemberRole, Permission from app.core.config import settings logger = logging.getLogger(__name__) class SecurityHeadersMiddleware(BaseHTTPMiddleware): """Add security headers to all responses""" async def dispatch(self, request: Request, call_next: Callable) -> Any: response = await call_next(request) # Security headers response.headers["X-Content-Type-Options"] = "nosniff" response.headers["X-Frame-Options"] = "DENY" response.headers["X-XSS-Protection"] = "1; mode=block" response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" response.headers["Content-Security-Policy"] = "default-src 'self'" response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" return response class RateLimitMiddleware(BaseHTTPMiddleware): """Rate limiting using Redis""" def __init__(self, app: FastAPI, redis_client: redis.Redis): super().__init__(app) self.redis_client = redis_client self.rate_limit_requests = 100 # requests self.rate_limit_window = 60 # seconds async def dispatch(self, request: Request, call_next: Callable) -> Any: # Skip rate limiting for health checks if request.url.path == "/health": return await call_next(request) # Get client IP client_ip = request.client.host # Rate limit key rate_key = f"rate_limit:{client_ip}" # Check rate limit try: current = self.redis_client.get(rate_key) if current and int(current) >= self.rate_limit_requests: return JSONResponse( status_code=429, content={"detail": "Rate limit exceeded"} ) # Increment counter pipe = self.redis_client.pipeline() pipe.incr(rate_key) pipe.expire(rate_key, self.rate_limit_window) pipe.execute() except Exception as e: logger.warning(f"Rate limiting error: {e}") return await call_next(request) class HMACVerificationMiddleware(BaseHTTPMiddleware): """HMAC signature verification and anti-replay protection""" def __init__(self, app: FastAPI, redis_client: redis.Redis): super().__init__(app) self.redis_client = redis_client hmac_manager.redis_client = redis_client async def dispatch(self, request: Request, call_next: Callable) -> Any: # Skip verification for public endpoints public_paths = [ "/health", "/docs", "/openapi.json", "/api/v1/auth/login", "/api/v1/auth/telegram/start", "/api/v1/auth/telegram/register", "/api/v1/auth/telegram/authenticate", ] if request.url.path in public_paths: return await call_next(request) # Extract HMAC headers signature = request.headers.get("X-Signature") timestamp = request.headers.get("X-Timestamp") client_id = request.headers.get("X-Client-Id", "unknown") # HMAC verification is optional in MVP (configurable) if settings.require_hmac_verification: if not signature or not timestamp: return JSONResponse( status_code=400, content={"detail": "Missing HMAC headers"} ) try: timestamp_int = int(timestamp) except ValueError: return JSONResponse( status_code=400, content={"detail": "Invalid timestamp format"} ) # Read body for signature verification body = await request.body() body_dict = {} if body: try: import json body_dict = json.loads(body) except: pass # Verify HMAC # Get client secret (hardcoded for MVP, should be from DB) client_secret = settings.hmac_secret_key is_valid, error_msg = hmac_manager.verify_signature( method=request.method, endpoint=request.url.path, timestamp=timestamp_int, signature=signature, body=body_dict, client_secret=client_secret, ) if not is_valid: logger.warning(f"HMAC verification failed: {error_msg} (client: {client_id})") return JSONResponse( status_code=401, content={"detail": f"HMAC verification failed: {error_msg}"} ) # Store in request state for logging request.state.client_id = client_id request.state.timestamp = timestamp return await call_next(request) class JWTAuthenticationMiddleware(BaseHTTPMiddleware): """JWT token verification and extraction""" async def dispatch(self, request: Request, call_next: Callable) -> Any: # Skip auth for public endpoints public_paths = [ "/health", "/docs", "/openapi.json", "/api/v1/auth/login", "/api/v1/auth/telegram/start", "/api/v1/auth/telegram/register", "/api/v1/auth/telegram/authenticate", ] if request.url.path in public_paths: return await call_next(request) # Extract token from Authorization header auth_header = request.headers.get("Authorization") if not auth_header: return JSONResponse( status_code=401, content={"detail": "Missing Authorization header"} ) # Parse "Bearer " try: scheme, token = auth_header.split() if scheme.lower() != "bearer": raise ValueError("Invalid auth scheme") except (ValueError, IndexError): return JSONResponse( status_code=401, content={"detail": "Invalid Authorization header format"} ) # Verify JWT try: token_payload = jwt_manager.verify_token(token) except ValueError as e: logger.warning(f"JWT verification failed: {e}") return JSONResponse( status_code=401, content={"detail": "Invalid or expired token"} ) # Store in request state request.state.user_id = token_payload.sub request.state.token_type = token_payload.type request.state.device_id = token_payload.device_id request.state.family_ids = token_payload.family_ids return await call_next(request) class RBACMiddleware(BaseHTTPMiddleware): """Role-Based Access Control enforcement""" def __init__(self, app: FastAPI, db_session: Any): super().__init__(app) self.db_session = db_session async def dispatch(self, request: Request, call_next: Callable) -> Any: # Skip RBAC for public endpoints if request.url.path in ["/health", "/docs", "/openapi.json"]: return await call_next(request) # Get user context from JWT user_id = getattr(request.state, "user_id", None) family_ids = getattr(request.state, "family_ids", []) if not user_id: # Already handled by JWTAuthenticationMiddleware return await call_next(request) # Extract family_id from URL or body family_id = self._extract_family_id(request) if family_id and family_id not in family_ids: return JSONResponse( status_code=403, content={"detail": "Access denied to this family"} ) # Load user role (would need DB query in production) # For MVP: Store in request state, resolved in endpoint handlers request.state.family_id = family_id return await call_next(request) @staticmethod def _extract_family_id(request: Request) -> Optional[int]: """Extract family_id from URL or request body""" # From URL path: /api/v1/families/{family_id}/... if "{family_id}" in request.url.path: # Parse from actual path parts = request.url.path.split("/") for i, part in enumerate(parts): if part == "families" and i + 1 < len(parts): try: return int(parts[i + 1]) except ValueError: pass return None class RequestLoggingMiddleware(BaseHTTPMiddleware): """Log all requests and responses for audit""" async def dispatch(self, request: Request, call_next: Callable) -> Any: # Start timer start_time = time.time() # Get client info client_ip = request.client.host if request.client else "unknown" user_id = getattr(request.state, "user_id", None) # Process request try: response = await call_next(request) response_time_ms = int((time.time() - start_time) * 1000) # Log successful request logger.info( f"Endpoint={request.url.path} " f"Method={request.method} " f"Status={response.status_code} " f"Time={response_time_ms}ms " f"User={user_id} " f"IP={client_ip}" ) # Add timing header response.headers["X-Response-Time"] = str(response_time_ms) return response except Exception as e: response_time_ms = int((time.time() - start_time) * 1000) logger.error( f"Request error - Endpoint={request.url.path} " f"Error={str(e)} " f"Time={response_time_ms}ms " f"User={user_id} " f"IP={client_ip}" ) raise def add_security_middleware(app: FastAPI, redis_client: redis.Redis, db_session: Any): """Register all security middleware in correct order""" # Order matters! Process in reverse order of registration: # 1. RequestLoggingMiddleware (innermost, executes last) # 2. RBACMiddleware # 3. JWTAuthenticationMiddleware # 4. HMACVerificationMiddleware # 5. RateLimitMiddleware # 6. SecurityHeadersMiddleware (outermost, executes first) app.add_middleware(RequestLoggingMiddleware) app.add_middleware(RBACMiddleware, db_session=db_session) app.add_middleware(JWTAuthenticationMiddleware) app.add_middleware(HMACVerificationMiddleware, redis_client=redis_client) app.add_middleware(RateLimitMiddleware, redis_client=redis_client) app.add_middleware(SecurityHeadersMiddleware)