Files
finance_bot/.history/app/security/middleware_20251210220328.py
2025-12-10 22:09:31 +09:00

308 lines
11 KiB
Python

"""
FastAPI Middleware Stack - Authentication, Authorization, and Security
"""
from fastapi import FastAPI, Request, HTTPException, status
from fastapi.responses import JSONResponse
import logging
import time
from typing import Optional, Callable, Any
from datetime import datetime
import redis
from starlette.middleware.base import BaseHTTPMiddleware
from app.security.jwt_manager import jwt_manager, TokenPayload
from app.security.hmac_manager import hmac_manager
from app.security.rbac import RBACEngine, UserContext, MemberRole, Permission
from app.core.config import settings
logger = logging.getLogger(__name__)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a fixed set of hardening headers to every outgoing response.

    The inner application produces the response first; each header below is
    then assigned onto it, replacing any value an inner layer already set.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        outgoing = await call_next(request)
        hardening = (
            ("X-Content-Type-Options", "nosniff"),
            ("X-Frame-Options", "DENY"),
            ("X-XSS-Protection", "1; mode=block"),
            ("Strict-Transport-Security", "max-age=31536000; includeSubDomains"),
            # NOTE(review): a bare default-src 'self' policy likely blocks the
            # CDN-hosted assets of FastAPI's default /docs page - confirm it renders.
            ("Content-Security-Policy", "default-src 'self'"),
            ("Referrer-Policy", "strict-origin-when-cross-origin"),
        )
        for header_name, header_value in hardening:
            outgoing.headers[header_name] = header_value
        return outgoing
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client-IP, fixed-window rate limiting backed by Redis.

    Each client IP may make at most ``rate_limit_requests`` requests per
    ``rate_limit_window`` seconds; further requests in the same window receive
    a 429 JSON response. ``/health`` is never limited. If Redis raises, the
    error is logged and the request is let through (fail-open).
    """

    def __init__(
        self,
        app: FastAPI,
        redis_client: redis.Redis,
        *,
        rate_limit_requests: int = 100,
        rate_limit_window: int = 60,
    ):
        """
        Args:
            app: The wrapped ASGI application.
            redis_client: Redis connection holding the per-IP counters.
            rate_limit_requests: Maximum requests allowed per window (default 100).
            rate_limit_window: Window length in seconds (default 60).
        """
        super().__init__(app)
        self.redis_client = redis_client
        self.rate_limit_requests = rate_limit_requests
        self.rate_limit_window = rate_limit_window

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        # Skip rate limiting for health checks
        if request.url.path == "/health":
            return await call_next(request)

        # request.client can be None for some transports; such requests share a
        # single "unknown" bucket instead of failing with an AttributeError (500).
        client_ip = request.client.host if request.client else "unknown"
        rate_key = f"rate_limit:{client_ip}"

        try:
            # SET NX EX creates the counter (with its TTL) only at the start of a
            # window, and INCR preserves an existing TTL. Re-issuing EXPIRE on every
            # request (the previous approach) slid the window forward, so a slow but
            # steady client never had its counter reset and was eventually blocked.
            # The pipeline executes as one MULTI/EXEC transaction, so the
            # increment-and-read is atomic (no GET-then-INCR race).
            # NOTE(review): synchronous Redis client inside async code - each call
            # briefly blocks the event loop.
            pipe = self.redis_client.pipeline()
            pipe.set(rate_key, 0, ex=self.rate_limit_window, nx=True)
            pipe.incr(rate_key)
            _, request_count = pipe.execute()
        except Exception as e:
            # Deliberately fail open: a Redis outage must not take the API down.
            logger.warning("Rate limiting error: %s", e)
        else:
            # request_count includes this request, so the first N requests pass.
            if int(request_count) > self.rate_limit_requests:
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Rate limit exceeded"},
                )

        return await call_next(request)
class HMACVerificationMiddleware(BaseHTTPMiddleware):
    """HMAC signature verification and anti-replay protection.

    For any path not in ``_PUBLIC_PATHS``, when
    ``settings.require_hmac_verification`` is enabled the request must carry
    ``X-Signature`` and an integer ``X-Timestamp`` header (else 400) and pass
    ``hmac_manager.verify_signature`` (else 401). ``X-Client-Id`` (default
    ``"unknown"``) and the raw ``X-Timestamp`` value are stored on
    ``request.state`` for later logging.
    """

    # Exact-match paths that bypass HMAC verification entirely.
    _PUBLIC_PATHS = frozenset({
        "/health",
        "/docs",
        "/openapi.json",
        "/api/v1/auth/login",
        "/api/v1/auth/telegram/start",
    })

    def __init__(self, app: FastAPI, redis_client: redis.Redis):
        super().__init__(app)
        self.redis_client = redis_client
        # Hands the same client to the module-level hmac_manager singleton
        # (process-wide side effect); presumably used for replay tracking - TODO confirm.
        hmac_manager.redis_client = redis_client

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        # Skip verification for public endpoints
        if request.url.path in self._PUBLIC_PATHS:
            return await call_next(request)

        # Extract HMAC headers
        signature = request.headers.get("X-Signature")
        timestamp = request.headers.get("X-Timestamp")
        client_id = request.headers.get("X-Client-Id", "unknown")

        # HMAC verification is optional in MVP (configurable)
        if settings.require_hmac_verification:
            if not signature or not timestamp:
                return JSONResponse(
                    status_code=400,
                    content={"detail": "Missing HMAC headers"}
                )
            try:
                timestamp_int = int(timestamp)
            except ValueError:
                return JSONResponse(
                    status_code=400,
                    content={"detail": "Invalid timestamp format"}
                )

            body_dict = self._parse_json_body(await request.body())

            # Single shared secret for the MVP; should be looked up per client.
            client_secret = settings.hmac_secret_key
            is_valid, error_msg = hmac_manager.verify_signature(
                method=request.method,
                endpoint=request.url.path,
                timestamp=timestamp_int,
                signature=signature,
                body=body_dict,
                client_secret=client_secret,
            )
            if not is_valid:
                logger.warning(
                    "HMAC verification failed: %s (client: %s)", error_msg, client_id
                )
                return JSONResponse(
                    status_code=401,
                    content={"detail": f"HMAC verification failed: {error_msg}"}
                )

        # Store in request state for logging
        request.state.client_id = client_id
        request.state.timestamp = timestamp
        return await call_next(request)

    @staticmethod
    def _parse_json_body(body: bytes) -> Any:
        """Decode a request body as JSON for signature verification.

        Returns ``{}`` when the body is empty or is not decodable JSON
        (best-effort, as before), but no longer swallows unrelated exceptions.
        """
        import json

        if not body:
            return {}
        try:
            return json.loads(body)
        except (ValueError, RecursionError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
            # RecursionError comes from pathologically nested input.
            # NOTE(review): falling back to {} means a non-JSON body is not covered
            # by the signature at all - consider signing the raw bytes instead.
            return {}
class JWTAuthenticationMiddleware(BaseHTTPMiddleware):
    """Require a valid ``Bearer`` JWT on every non-public request.

    On success the token's claims are copied onto ``request.state``
    (``user_id``, ``token_type``, ``device_id``, ``family_ids``) for later
    layers and handlers. A missing, malformed, or rejected token yields a 401
    JSON response.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        # NOTE(review): unlike HMACVerificationMiddleware, this list omits
        # "/api/v1/auth/telegram/start" - confirm that endpoint should require a JWT.
        unauthenticated_paths = ("/health", "/docs", "/openapi.json", "/api/v1/auth/login")
        if request.url.path in unauthenticated_paths:
            return await call_next(request)

        header_value = request.headers.get("Authorization")
        if not header_value:
            return JSONResponse(
                status_code=401,
                content={"detail": "Missing Authorization header"},
            )

        # Exactly two whitespace-separated parts are accepted: "<scheme> <token>",
        # where the scheme must be "bearer" in any letter case.
        pieces = header_value.split()
        if len(pieces) != 2 or pieces[0].lower() != "bearer":
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid Authorization header format"},
            )
        bearer_token = pieces[1]

        try:
            claims = jwt_manager.verify_token(bearer_token)
        except ValueError as exc:
            logger.warning(f"JWT verification failed: {exc}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid or expired token"},
            )

        state = request.state
        state.user_id = claims.sub
        state.token_type = claims.type
        state.device_id = claims.device_id
        state.family_ids = claims.family_ids
        return await call_next(request)
class RBACMiddleware(BaseHTTPMiddleware):
    """Role-Based Access Control enforcement.

    For authenticated requests whose URL path contains ``/families/<id>``,
    respond 403 unless ``<id>`` appears in ``request.state.family_ids`` (set by
    JWTAuthenticationMiddleware). The extracted id (or None) is stored as
    ``request.state.family_id``; resolving the member's role is left to the
    endpoint handlers.
    """

    def __init__(self, app: FastAPI, db_session: Any):
        super().__init__(app)
        # Not read by dispatch yet; kept for a future per-request role lookup.
        self.db_session = db_session

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        # Skip RBAC for public endpoints
        if request.url.path in ["/health", "/docs", "/openapi.json"]:
            return await call_next(request)

        user_id = getattr(request.state, "user_id", None)
        if not user_id:
            # No authenticated user: assuming JWTAuthenticationMiddleware runs
            # before this layer, only its public paths get here - nothing to check.
            return await call_next(request)

        # A token may carry family_ids=None; treat that as "member of no family"
        # instead of raising TypeError on the membership test.
        family_ids = getattr(request.state, "family_ids", None) or []

        family_id = self._extract_family_id(request)
        # Compare against None explicitly so a family id of 0 is also enforced.
        if family_id is not None and not self._is_member(family_id, family_ids):
            return JSONResponse(
                status_code=403,
                content={"detail": "Access denied to this family"}
            )

        # Load user role (would need DB query in production)
        # For MVP: Store in request state, resolved in endpoint handlers
        request.state.family_id = family_id
        return await call_next(request)

    @staticmethod
    def _is_member(family_id: int, family_ids: Any) -> bool:
        """Return True if *family_id* is listed in *family_ids*.

        Ids are compared as strings because the element type of the token's
        ``family_ids`` claim is not visible here (int or str both work).
        """
        return str(family_id) in {str(fid) for fid in family_ids}

    @staticmethod
    def _extract_family_id(request: Request) -> Optional[int]:
        """Extract family_id from a URL path of the form ``.../families/{family_id}/...``.

        Returns the integer following the first ``families`` segment that is
        followed by an int-parsable segment, or None if there is none.

        The previous implementation tested for the literal route template
        ``"{family_id}"`` in the concrete request path, which never matches, so
        the family check was silently skipped for every request.
        NOTE(review): handlers should still verify membership themselves, since
        their path-parameter parsing may differ from ``int()`` used here.
        """
        segments = request.url.path.split("/")
        for index, segment in enumerate(segments[:-1]):
            if segment == "families":
                try:
                    return int(segments[index + 1])
                except ValueError:
                    continue
        return None
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log all requests and responses for audit.

    For every request reaching this layer, one INFO line is logged with the
    path, method, status, latency, ``request.state.user_id`` and client IP,
    and an ``X-Response-Time`` header (whole milliseconds) is added. If the
    downstream app raises, an ERROR line is logged and the exception re-raised.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        # perf_counter is monotonic; time.time() follows the wall clock, which can
        # jump (NTP sync, manual changes) and yield negative or inflated durations.
        start_time = time.perf_counter()

        client_ip = request.client.host if request.client else "unknown"
        # Populated by JWTAuthenticationMiddleware when that layer runs first.
        user_id = getattr(request.state, "user_id", None)

        try:
            response = await call_next(request)
        except Exception as e:
            response_time_ms = int((time.perf_counter() - start_time) * 1000)
            logger.error(
                "Request error - Endpoint=%s Error=%s Time=%dms User=%s IP=%s",
                request.url.path, e, response_time_ms, user_id, client_ip,
            )
            raise

        response_time_ms = int((time.perf_counter() - start_time) * 1000)
        logger.info(
            "Endpoint=%s Method=%s Status=%s Time=%dms User=%s IP=%s",
            request.url.path, request.method, response.status_code,
            response_time_ms, user_id, client_ip,
        )
        # Add timing header
        response.headers["X-Response-Time"] = str(response_time_ms)
        return response
def add_security_middleware(app: FastAPI, redis_client: redis.Redis, db_session: Any):
    """Install the security middleware stack on *app*.

    Registration order matters: each middleware added later wraps the ones
    added before it, so the list below runs from innermost to outermost:

    * RequestLoggingMiddleware - innermost; reached last on the way in
    * RBACMiddleware
    * JWTAuthenticationMiddleware
    * HMACVerificationMiddleware
    * RateLimitMiddleware
    * SecurityHeadersMiddleware - outermost; reached first on the way in

    Args:
        app: The FastAPI application to configure.
        redis_client: Shared Redis client for rate limiting and HMAC checks.
        db_session: Handle passed through to RBACMiddleware.
    """
    # NOTE(review): because logging is innermost, requests rejected by an outer
    # layer (429/400/401/403) never reach it and are not logged by it.
    layers = (
        (RequestLoggingMiddleware, {}),
        (RBACMiddleware, {"db_session": db_session}),
        (JWTAuthenticationMiddleware, {}),
        (HMACVerificationMiddleware, {"redis_client": redis_client}),
        (RateLimitMiddleware, {"redis_client": redis_client}),
        (SecurityHeadersMiddleware, {}),
    )
    for middleware_cls, options in layers:
        app.add_middleware(middleware_cls, **options)