- Add email/password registration endpoint (/api/v1/auth/register)
- Add JWT token endpoints for Telegram users (/api/v1/auth/token/get, /api/v1/auth/token/refresh-telegram)
- Enhance User model to support both email and Telegram authentication
- Fix JWT token handling: convert sub to string (RFC compliance with PyJWT 2.10.1+)
- Fix bot API calls: filter None values from query parameters
- Fix JWT extraction from Redis: handle both bytes and string returns
- Add public endpoints to JWT middleware: /api/v1/auth/register, /api/v1/auth/token/*
- Update bot commands: /register (one-tap), /link (account linking), /start (options)
- Create complete database schema migration with email auth support
- Remove deprecated version attribute from docker-compose.yml
- Add service dependency: bot waits for web service startup

Features:
- Dual authentication: email/password OR Telegram ID
- JWT tokens with 15-min access + 30-day refresh lifetime
- Redis-based token storage with TTL
- Comprehensive API documentation and integration guides
- Test scripts and Python examples
- Full deployment checklist

Database changes:
- User model: added email, password_hash, email_verified (nullable fields)
- telegram_id now nullable to support email-only users
- Complete schema with families, accounts, categories, transactions, budgets, goals

Status: Production-ready with all tests passing
328 lines
12 KiB
Python
328 lines
12 KiB
Python
"""
|
|
FastAPI Middleware Stack - Authentication, Authorization, and Security
|
|
"""
|
|
from fastapi import FastAPI, Request, HTTPException, status
|
|
from fastapi.responses import JSONResponse
|
|
import logging
|
|
import time
|
|
from typing import Optional, Callable, Any
|
|
from datetime import datetime
|
|
import redis
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from app.security.jwt_manager import jwt_manager, TokenPayload
|
|
from app.security.hmac_manager import hmac_manager
|
|
from app.security.rbac import RBACEngine, UserContext, MemberRole, Permission
|
|
from app.core.config import settings
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Stamp a fixed set of security-related headers onto every response."""

    # Header name -> value, written to each response in this order.
    _SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "X-XSS-Protection": "1; mode=block",
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
        "Content-Security-Policy": "default-src 'self'",
        "Referrer-Policy": "strict-origin-when-cross-origin",
    }

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        """Run the downstream handler, then set the headers on its response.

        Any header of the same name already present on the response is
        overwritten.
        """
        outgoing = await call_next(request)

        for header_name, header_value in self._SECURITY_HEADERS.items():
            outgoing.headers[header_name] = header_value

        return outgoing
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client-IP rate limiting backed by a Redis counter.

    Each client IP owns a counter key ``rate_limit:<ip>``. The counter is
    incremented on every request and expires ``rate_limit_window`` seconds
    after the first request of the window (fixed window). Once the counter
    exceeds ``rate_limit_requests`` the request is answered with HTTP 429.

    Redis failures are logged and the request is let through (fail-open).
    """

    def __init__(self, app: FastAPI, redis_client: redis.Redis):
        """Store the Redis client and the limiter settings.

        Args:
            app: The ASGI application being wrapped.
            redis_client: Client used to keep the per-IP counters.
        """
        super().__init__(app)
        self.redis_client = redis_client
        self.rate_limit_requests = 100  # max requests per window
        self.rate_limit_window = 60  # window length in seconds

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        """Count the request against its client IP and reject it when over limit.

        Returns:
            A 429 JSON response when the limit is exceeded, otherwise the
            downstream response.
        """
        # Health checks are never rate limited.
        if request.url.path == "/health":
            return await call_next(request)

        # request.client can be None (no peer address available); fall back to
        # a shared bucket instead of raising AttributeError.
        client_ip = request.client.host if request.client else "unknown"
        rate_key = f"rate_limit:{client_ip}"

        try:
            # INCR and TTL are sent in one pipeline (transactional by default
            # in redis-py), so the count we test is the one we just produced.
            # NOTE(review): these are blocking calls inside an async handler,
            # as in the original; consider an async client.
            pipe = self.redis_client.pipeline()
            pipe.incr(rate_key)
            pipe.ttl(rate_key)
            request_count, ttl = pipe.execute()

            # Only arm the expiry when the key has none (TTL is negative for a
            # key without expiry). Refreshing it on every request would keep
            # the counter alive forever for a steady client, so the window
            # would never reset.
            if ttl is None or ttl < 0:
                self.redis_client.expire(rate_key, self.rate_limit_window)

            if request_count > self.rate_limit_requests:
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Rate limit exceeded"},
                )
        except Exception as e:
            # Deliberate fail-open: a broken limiter must not take the API down.
            logger.warning(f"Rate limiting error: {e}")

        return await call_next(request)
|
|
|
|
|
|
class HMACVerificationMiddleware(BaseHTTPMiddleware):
    """Verify the HMAC signature headers of incoming requests.

    Verification only runs when ``settings.require_hmac_verification`` is
    truthy. The signature check itself is delegated to
    ``hmac_manager.verify_signature``; any replay protection presumably lives
    there too (the Redis client is handed to ``hmac_manager`` in ``__init__``)
    -- confirm against hmac_manager.
    """

    # Paths that bypass HMAC verification entirely.
    # NOTE(review): JWTAuthenticationMiddleware in this file additionally
    # treats /api/v1/auth/register and /api/v1/auth/token/* as public; confirm
    # whether those endpoints are meant to require HMAC headers.
    _PUBLIC_PATHS = frozenset(
        {
            "/health",
            "/docs",
            "/openapi.json",
            "/api/v1/auth/login",
            "/api/v1/auth/telegram/start",
            "/api/v1/auth/telegram/register",
            "/api/v1/auth/telegram/authenticate",
        }
    )

    def __init__(self, app: FastAPI, redis_client: redis.Redis):
        """Store the Redis client and share it with the global ``hmac_manager``.

        Args:
            app: The ASGI application being wrapped.
            redis_client: Redis client, also assigned to
                ``hmac_manager.redis_client``.
        """
        super().__init__(app)
        self.redis_client = redis_client
        hmac_manager.redis_client = redis_client

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        """Validate the ``X-Signature`` / ``X-Timestamp`` headers when required.

        Returns:
            400 when required headers are missing or the timestamp is not an
            integer, 401 when the signature is rejected, otherwise the
            downstream response.
        """
        if request.url.path in self._PUBLIC_PATHS:
            return await call_next(request)

        signature = request.headers.get("X-Signature")
        timestamp = request.headers.get("X-Timestamp")
        client_id = request.headers.get("X-Client-Id", "unknown")

        # HMAC verification is optional in MVP (configurable).
        if settings.require_hmac_verification:
            if not signature or not timestamp:
                return JSONResponse(
                    status_code=400,
                    content={"detail": "Missing HMAC headers"},
                )

            try:
                timestamp_int = int(timestamp)
            except ValueError:
                return JSONResponse(
                    status_code=400,
                    content={"detail": "Invalid timestamp format"},
                )

            # NOTE(review): the body is read inside a BaseHTTPMiddleware;
            # whether downstream handlers can read it again depends on the
            # installed Starlette version -- confirm.
            body_dict = self._parse_json_body(await request.body())

            # Client secret is a single shared setting for the MVP
            # (should come from the DB per client).
            client_secret = settings.hmac_secret_key

            is_valid, error_msg = hmac_manager.verify_signature(
                method=request.method,
                endpoint=request.url.path,
                timestamp=timestamp_int,
                signature=signature,
                body=body_dict,
                client_secret=client_secret,
            )

            if not is_valid:
                logger.warning(f"HMAC verification failed: {error_msg} (client: {client_id})")
                return JSONResponse(
                    status_code=401,
                    content={"detail": f"HMAC verification failed: {error_msg}"},
                )

        # Store in request state for logging.
        request.state.client_id = client_id
        request.state.timestamp = timestamp

        return await call_next(request)

    @staticmethod
    def _parse_json_body(raw_body: bytes) -> Any:
        """Decode a request body as JSON, falling back to an empty dict.

        An empty or unparsable body yields ``{}``, which is then what gets
        passed to ``hmac_manager.verify_signature``.
        NOTE(review): a non-JSON body is therefore not passed to the signature
        check at all -- confirm that this is acceptable.
        """
        if not raw_body:
            return {}

        import json

        try:
            return json.loads(raw_body)
        except (ValueError, RecursionError):
            # ValueError covers JSONDecodeError and UnicodeDecodeError;
            # RecursionError covers pathologically nested documents.
            return {}
|
|
|
|
|
|
class JWTAuthenticationMiddleware(BaseHTTPMiddleware):
    """Authenticate requests with a Bearer JWT and expose its claims on request.state."""

    # Paths served without any Authorization header.
    _PUBLIC_PATHS = (
        "/health",
        "/docs",
        "/openapi.json",
        "/api/v1/auth/login",
        "/api/v1/auth/register",
        "/api/v1/auth/token/get",
        "/api/v1/auth/token/refresh-telegram",
        "/api/v1/auth/telegram/start",
        "/api/v1/auth/telegram/register",
        "/api/v1/auth/telegram/authenticate",
    )

    @staticmethod
    def _unauthorized(detail: str) -> JSONResponse:
        """Build the 401 JSON response used for every rejection below."""
        return JSONResponse(status_code=401, content={"detail": detail})

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        """Check the ``Authorization: Bearer <token>`` header and verify the token.

        On success the token's ``sub``, ``type``, ``device_id`` and
        ``family_ids`` are copied onto ``request.state`` before the request is
        passed downstream. Any failure short-circuits with a 401 response.
        """
        if request.url.path in self._PUBLIC_PATHS:
            return await call_next(request)

        header_value = request.headers.get("Authorization")
        if not header_value:
            return self._unauthorized("Missing Authorization header")

        # Expect exactly two whitespace-separated pieces: scheme and token.
        pieces = header_value.split()
        if len(pieces) != 2 or pieces[0].lower() != "bearer":
            return self._unauthorized("Invalid Authorization header format")
        raw_token = pieces[1]

        try:
            claims = jwt_manager.verify_token(raw_token)
        except ValueError as exc:
            logger.warning(f"JWT verification failed: {exc}")
            return self._unauthorized("Invalid or expired token")

        # Make the verified claims available to later middleware and handlers.
        state = request.state
        state.user_id = claims.sub
        state.token_type = claims.type
        state.device_id = claims.device_id
        state.family_ids = claims.family_ids

        return await call_next(request)
|
|
|
|
|
|
class RBACMiddleware(BaseHTTPMiddleware):
    """Family-level access check based on the claims found on request.state.

    For an authenticated request whose path contains ``/families/<int>``, the
    family id must be one of the ``family_ids`` placed on ``request.state``
    (by JWTAuthenticationMiddleware in this file); otherwise 403 is returned.
    Role resolution is left to the endpoint handlers.
    """

    def __init__(self, app: FastAPI, db_session: Any):
        """Store the DB session.

        Args:
            app: The ASGI application being wrapped.
            db_session: Kept on the instance; not used by this middleware yet.
        """
        super().__init__(app)
        self.db_session = db_session

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        """Reject access to a family the caller's token does not list.

        Returns:
            403 when the path names a family outside the caller's
            ``family_ids``, otherwise the downstream response.
        """
        # Skip RBAC for public endpoints.
        if request.url.path in ["/health", "/docs", "/openapi.json"]:
            return await call_next(request)

        user_id = getattr(request.state, "user_id", None)
        # Guard against an explicit None as well as a missing attribute.
        family_ids = getattr(request.state, "family_ids", None) or []

        if not user_id:
            # No authenticated user on the state: nothing to check here.
            return await call_next(request)

        family_id = self._extract_family_id(request)

        if family_id is not None:
            # The element type of family_ids is not visible here, so compare
            # on the string form to accept both int and str ids.
            allowed_ids = {str(allowed) for allowed in family_ids}
            if str(family_id) not in allowed_ids:
                return JSONResponse(
                    status_code=403,
                    content={"detail": "Access denied to this family"},
                )

        # For MVP: the role is resolved in endpoint handlers; only the family
        # id is passed along.
        request.state.family_id = family_id

        return await call_next(request)

    @staticmethod
    def _extract_family_id(request: Request) -> Optional[int]:
        """Return the integer path segment following ``families``, if any.

        The concrete request path is scanned (e.g. ``/api/v1/families/42/...``
        yields ``42``). The previous implementation first required the literal
        text ``{family_id}`` to appear in the path, which a real request path
        does not contain, so it always returned None. Only the URL is
        inspected; the request body is not read.
        """
        segments = request.url.path.split("/")

        for segment, following in zip(segments, segments[1:]):
            if segment != "families":
                continue
            try:
                return int(following)
            except ValueError:
                # e.g. /families/join -- keep scanning the rest of the path.
                continue

        return None
|
|
|
|
|
|
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log every request with its outcome and duration, for audit."""

    async def dispatch(self, request: Request, call_next: Callable) -> Any:
        """Time the downstream call, log the result and add ``X-Response-Time``.

        On success an INFO line is written and the elapsed milliseconds are
        set in the ``X-Response-Time`` response header. If the downstream call
        raises, an ERROR line is written and the exception is re-raised.
        """
        # perf_counter is monotonic; wall-clock time can jump and would give
        # wrong or negative durations.
        started_at = time.perf_counter()

        def elapsed_ms() -> int:
            return int((time.perf_counter() - started_at) * 1000)

        client_ip = request.client.host if request.client else "unknown"
        # Read before the downstream call: reflects what earlier middleware
        # has already placed on request.state.
        user_id = getattr(request.state, "user_id", None)

        try:
            response = await call_next(request)
        except Exception as e:
            response_time_ms = elapsed_ms()
            logger.error(
                f"Request error - Endpoint={request.url.path} "
                f"Error={str(e)} "
                f"Time={response_time_ms}ms "
                f"User={user_id} "
                f"IP={client_ip}"
            )
            raise

        response_time_ms = elapsed_ms()

        logger.info(
            f"Endpoint={request.url.path} "
            f"Method={request.method} "
            f"Status={response.status_code} "
            f"Time={response_time_ms}ms "
            f"User={user_id} "
            f"IP={client_ip}"
        )

        # Add timing header.
        response.headers["X-Response-Time"] = str(response_time_ms)

        return response
|
|
|
|
|
|
def add_security_middleware(app: FastAPI, redis_client: redis.Redis, db_session: Any):
    """Register the whole security middleware stack on ``app``.

    Order matters: middleware runs in reverse order of registration, so the
    first entry below is the innermost layer (executes last on the way in) and
    the final entry is the outermost layer (executes first).

    Args:
        app: Application to register the middleware on.
        redis_client: Passed to the HMAC and rate-limit middleware.
        db_session: Passed to the RBAC middleware.
    """
    # (middleware class, constructor keyword arguments), innermost first.
    registration_order = (
        (RequestLoggingMiddleware, {}),
        (RBACMiddleware, {"db_session": db_session}),
        (JWTAuthenticationMiddleware, {}),
        (HMACVerificationMiddleware, {"redis_client": redis_client}),
        (RateLimitMiddleware, {"redis_client": redis_client}),
        (SecurityHeadersMiddleware, {}),
    )

    for middleware_cls, options in registration_order:
        app.add_middleware(middleware_cls, **options)
|