aut flow
This commit is contained in:
316
.history/app/security/middleware_20251210221744.py
Normal file
316
.history/app/security/middleware_20251210221744.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
FastAPI Middleware Stack - Authentication, Authorization, and Security
|
||||
"""
|
||||
from fastapi import FastAPI, Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Callable, Any
|
||||
from datetime import datetime
|
||||
import redis
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.security.jwt_manager import jwt_manager, TokenPayload
|
||||
from app.security.hmac_manager import hmac_manager
|
||||
from app.security.rbac import RBACEngine, UserContext, MemberRole, Permission
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Add security headers to all responses"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
response = await call_next(request)
|
||||
|
||||
# Security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
response.headers["Content-Security-Policy"] = "default-src 'self'"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Rate limiting using Redis"""
|
||||
|
||||
def __init__(self, app: FastAPI, redis_client: redis.Redis):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client
|
||||
self.rate_limit_requests = 100 # requests
|
||||
self.rate_limit_window = 60 # seconds
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip rate limiting for health checks
|
||||
if request.url.path == "/health":
|
||||
return await call_next(request)
|
||||
|
||||
# Get client IP
|
||||
client_ip = request.client.host
|
||||
|
||||
# Rate limit key
|
||||
rate_key = f"rate_limit:{client_ip}"
|
||||
|
||||
# Check rate limit
|
||||
try:
|
||||
current = self.redis_client.get(rate_key)
|
||||
if current and int(current) >= self.rate_limit_requests:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"}
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
pipe = self.redis_client.pipeline()
|
||||
pipe.incr(rate_key)
|
||||
pipe.expire(rate_key, self.rate_limit_window)
|
||||
pipe.execute()
|
||||
except Exception as e:
|
||||
logger.warning(f"Rate limiting error: {e}")
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class HMACVerificationMiddleware(BaseHTTPMiddleware):
|
||||
"""HMAC signature verification and anti-replay protection"""
|
||||
|
||||
def __init__(self, app: FastAPI, redis_client: redis.Redis):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client
|
||||
hmac_manager.redis_client = redis_client
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip verification for public endpoints
|
||||
public_paths = [
|
||||
"/health",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/telegram/start",
|
||||
"/api/v1/auth/telegram/register",
|
||||
"/api/v1/auth/telegram/authenticate",
|
||||
]
|
||||
if request.url.path in public_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract HMAC headers
|
||||
signature = request.headers.get("X-Signature")
|
||||
timestamp = request.headers.get("X-Timestamp")
|
||||
client_id = request.headers.get("X-Client-Id", "unknown")
|
||||
|
||||
# HMAC verification is optional in MVP (configurable)
|
||||
if settings.require_hmac_verification:
|
||||
if not signature or not timestamp:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"detail": "Missing HMAC headers"}
|
||||
)
|
||||
|
||||
try:
|
||||
timestamp_int = int(timestamp)
|
||||
except ValueError:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"detail": "Invalid timestamp format"}
|
||||
)
|
||||
|
||||
# Read body for signature verification
|
||||
body = await request.body()
|
||||
body_dict = {}
|
||||
if body:
|
||||
try:
|
||||
import json
|
||||
body_dict = json.loads(body)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Verify HMAC
|
||||
# Get client secret (hardcoded for MVP, should be from DB)
|
||||
client_secret = settings.hmac_secret_key
|
||||
|
||||
is_valid, error_msg = hmac_manager.verify_signature(
|
||||
method=request.method,
|
||||
endpoint=request.url.path,
|
||||
timestamp=timestamp_int,
|
||||
signature=signature,
|
||||
body=body_dict,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"HMAC verification failed: {error_msg} (client: {client_id})")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": f"HMAC verification failed: {error_msg}"}
|
||||
)
|
||||
|
||||
# Store in request state for logging
|
||||
request.state.client_id = client_id
|
||||
request.state.timestamp = timestamp
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class JWTAuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
"""JWT token verification and extraction"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip auth for public endpoints
|
||||
public_paths = ["/health", "/docs", "/openapi.json", "/api/v1/auth/login", "/api/v1/auth/telegram/start"]
|
||||
if request.url.path in public_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract token from Authorization header
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Missing Authorization header"}
|
||||
)
|
||||
|
||||
# Parse "Bearer <token>"
|
||||
try:
|
||||
scheme, token = auth_header.split()
|
||||
if scheme.lower() != "bearer":
|
||||
raise ValueError("Invalid auth scheme")
|
||||
except (ValueError, IndexError):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid Authorization header format"}
|
||||
)
|
||||
|
||||
# Verify JWT
|
||||
try:
|
||||
token_payload = jwt_manager.verify_token(token)
|
||||
except ValueError as e:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid or expired token"}
|
||||
)
|
||||
|
||||
# Store in request state
|
||||
request.state.user_id = token_payload.sub
|
||||
request.state.token_type = token_payload.type
|
||||
request.state.device_id = token_payload.device_id
|
||||
request.state.family_ids = token_payload.family_ids
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class RBACMiddleware(BaseHTTPMiddleware):
|
||||
"""Role-Based Access Control enforcement"""
|
||||
|
||||
def __init__(self, app: FastAPI, db_session: Any):
|
||||
super().__init__(app)
|
||||
self.db_session = db_session
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip RBAC for public endpoints
|
||||
if request.url.path in ["/health", "/docs", "/openapi.json"]:
|
||||
return await call_next(request)
|
||||
|
||||
# Get user context from JWT
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
family_ids = getattr(request.state, "family_ids", [])
|
||||
|
||||
if not user_id:
|
||||
# Already handled by JWTAuthenticationMiddleware
|
||||
return await call_next(request)
|
||||
|
||||
# Extract family_id from URL or body
|
||||
family_id = self._extract_family_id(request)
|
||||
|
||||
if family_id and family_id not in family_ids:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Access denied to this family"}
|
||||
)
|
||||
|
||||
# Load user role (would need DB query in production)
|
||||
# For MVP: Store in request state, resolved in endpoint handlers
|
||||
request.state.family_id = family_id
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
@staticmethod
|
||||
def _extract_family_id(request: Request) -> Optional[int]:
|
||||
"""Extract family_id from URL or request body"""
|
||||
# From URL path: /api/v1/families/{family_id}/...
|
||||
if "{family_id}" in request.url.path:
|
||||
# Parse from actual path
|
||||
parts = request.url.path.split("/")
|
||||
for i, part in enumerate(parts):
|
||||
if part == "families" and i + 1 < len(parts):
|
||||
try:
|
||||
return int(parts[i + 1])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Log all requests and responses for audit"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Start timer
|
||||
start_time = time.time()
|
||||
|
||||
# Get client info
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
|
||||
# Process request
|
||||
try:
|
||||
response = await call_next(request)
|
||||
response_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Log successful request
|
||||
logger.info(
|
||||
f"Endpoint={request.url.path} "
|
||||
f"Method={request.method} "
|
||||
f"Status={response.status_code} "
|
||||
f"Time={response_time_ms}ms "
|
||||
f"User={user_id} "
|
||||
f"IP={client_ip}"
|
||||
)
|
||||
|
||||
# Add timing header
|
||||
response.headers["X-Response-Time"] = str(response_time_ms)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
response_time_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(
|
||||
f"Request error - Endpoint={request.url.path} "
|
||||
f"Error={str(e)} "
|
||||
f"Time={response_time_ms}ms "
|
||||
f"User={user_id} "
|
||||
f"IP={client_ip}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def add_security_middleware(app: FastAPI, redis_client: redis.Redis, db_session: Any):
|
||||
"""Register all security middleware in correct order"""
|
||||
|
||||
# Order matters! Process in reverse order of registration:
|
||||
# 1. RequestLoggingMiddleware (innermost, executes last)
|
||||
# 2. RBACMiddleware
|
||||
# 3. JWTAuthenticationMiddleware
|
||||
# 4. HMACVerificationMiddleware
|
||||
# 5. RateLimitMiddleware
|
||||
# 6. SecurityHeadersMiddleware (outermost, executes first)
|
||||
|
||||
app.add_middleware(RequestLoggingMiddleware)
|
||||
app.add_middleware(RBACMiddleware, db_session=db_session)
|
||||
app.add_middleware(JWTAuthenticationMiddleware)
|
||||
app.add_middleware(HMACVerificationMiddleware, redis_client=redis_client)
|
||||
app.add_middleware(RateLimitMiddleware, redis_client=redis_client)
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
324
.history/app/security/middleware_20251210221754.py
Normal file
324
.history/app/security/middleware_20251210221754.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
FastAPI Middleware Stack - Authentication, Authorization, and Security
|
||||
"""
|
||||
from fastapi import FastAPI, Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Callable, Any
|
||||
from datetime import datetime
|
||||
import redis
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.security.jwt_manager import jwt_manager, TokenPayload
|
||||
from app.security.hmac_manager import hmac_manager
|
||||
from app.security.rbac import RBACEngine, UserContext, MemberRole, Permission
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Add security headers to all responses"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
response = await call_next(request)
|
||||
|
||||
# Security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
response.headers["Content-Security-Policy"] = "default-src 'self'"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Rate limiting using Redis"""
|
||||
|
||||
def __init__(self, app: FastAPI, redis_client: redis.Redis):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client
|
||||
self.rate_limit_requests = 100 # requests
|
||||
self.rate_limit_window = 60 # seconds
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip rate limiting for health checks
|
||||
if request.url.path == "/health":
|
||||
return await call_next(request)
|
||||
|
||||
# Get client IP
|
||||
client_ip = request.client.host
|
||||
|
||||
# Rate limit key
|
||||
rate_key = f"rate_limit:{client_ip}"
|
||||
|
||||
# Check rate limit
|
||||
try:
|
||||
current = self.redis_client.get(rate_key)
|
||||
if current and int(current) >= self.rate_limit_requests:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"}
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
pipe = self.redis_client.pipeline()
|
||||
pipe.incr(rate_key)
|
||||
pipe.expire(rate_key, self.rate_limit_window)
|
||||
pipe.execute()
|
||||
except Exception as e:
|
||||
logger.warning(f"Rate limiting error: {e}")
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class HMACVerificationMiddleware(BaseHTTPMiddleware):
|
||||
"""HMAC signature verification and anti-replay protection"""
|
||||
|
||||
def __init__(self, app: FastAPI, redis_client: redis.Redis):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client
|
||||
hmac_manager.redis_client = redis_client
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip verification for public endpoints
|
||||
public_paths = [
|
||||
"/health",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/telegram/start",
|
||||
"/api/v1/auth/telegram/register",
|
||||
"/api/v1/auth/telegram/authenticate",
|
||||
]
|
||||
if request.url.path in public_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract HMAC headers
|
||||
signature = request.headers.get("X-Signature")
|
||||
timestamp = request.headers.get("X-Timestamp")
|
||||
client_id = request.headers.get("X-Client-Id", "unknown")
|
||||
|
||||
# HMAC verification is optional in MVP (configurable)
|
||||
if settings.require_hmac_verification:
|
||||
if not signature or not timestamp:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"detail": "Missing HMAC headers"}
|
||||
)
|
||||
|
||||
try:
|
||||
timestamp_int = int(timestamp)
|
||||
except ValueError:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"detail": "Invalid timestamp format"}
|
||||
)
|
||||
|
||||
# Read body for signature verification
|
||||
body = await request.body()
|
||||
body_dict = {}
|
||||
if body:
|
||||
try:
|
||||
import json
|
||||
body_dict = json.loads(body)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Verify HMAC
|
||||
# Get client secret (hardcoded for MVP, should be from DB)
|
||||
client_secret = settings.hmac_secret_key
|
||||
|
||||
is_valid, error_msg = hmac_manager.verify_signature(
|
||||
method=request.method,
|
||||
endpoint=request.url.path,
|
||||
timestamp=timestamp_int,
|
||||
signature=signature,
|
||||
body=body_dict,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"HMAC verification failed: {error_msg} (client: {client_id})")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": f"HMAC verification failed: {error_msg}"}
|
||||
)
|
||||
|
||||
# Store in request state for logging
|
||||
request.state.client_id = client_id
|
||||
request.state.timestamp = timestamp
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class JWTAuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
"""JWT token verification and extraction"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip auth for public endpoints
|
||||
public_paths = [
|
||||
"/health",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/telegram/start",
|
||||
"/api/v1/auth/telegram/register",
|
||||
"/api/v1/auth/telegram/authenticate",
|
||||
]
|
||||
if request.url.path in public_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract token from Authorization header
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Missing Authorization header"}
|
||||
)
|
||||
|
||||
# Parse "Bearer <token>"
|
||||
try:
|
||||
scheme, token = auth_header.split()
|
||||
if scheme.lower() != "bearer":
|
||||
raise ValueError("Invalid auth scheme")
|
||||
except (ValueError, IndexError):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid Authorization header format"}
|
||||
)
|
||||
|
||||
# Verify JWT
|
||||
try:
|
||||
token_payload = jwt_manager.verify_token(token)
|
||||
except ValueError as e:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid or expired token"}
|
||||
)
|
||||
|
||||
# Store in request state
|
||||
request.state.user_id = token_payload.sub
|
||||
request.state.token_type = token_payload.type
|
||||
request.state.device_id = token_payload.device_id
|
||||
request.state.family_ids = token_payload.family_ids
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class RBACMiddleware(BaseHTTPMiddleware):
|
||||
"""Role-Based Access Control enforcement"""
|
||||
|
||||
def __init__(self, app: FastAPI, db_session: Any):
|
||||
super().__init__(app)
|
||||
self.db_session = db_session
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip RBAC for public endpoints
|
||||
if request.url.path in ["/health", "/docs", "/openapi.json"]:
|
||||
return await call_next(request)
|
||||
|
||||
# Get user context from JWT
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
family_ids = getattr(request.state, "family_ids", [])
|
||||
|
||||
if not user_id:
|
||||
# Already handled by JWTAuthenticationMiddleware
|
||||
return await call_next(request)
|
||||
|
||||
# Extract family_id from URL or body
|
||||
family_id = self._extract_family_id(request)
|
||||
|
||||
if family_id and family_id not in family_ids:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Access denied to this family"}
|
||||
)
|
||||
|
||||
# Load user role (would need DB query in production)
|
||||
# For MVP: Store in request state, resolved in endpoint handlers
|
||||
request.state.family_id = family_id
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
@staticmethod
|
||||
def _extract_family_id(request: Request) -> Optional[int]:
|
||||
"""Extract family_id from URL or request body"""
|
||||
# From URL path: /api/v1/families/{family_id}/...
|
||||
if "{family_id}" in request.url.path:
|
||||
# Parse from actual path
|
||||
parts = request.url.path.split("/")
|
||||
for i, part in enumerate(parts):
|
||||
if part == "families" and i + 1 < len(parts):
|
||||
try:
|
||||
return int(parts[i + 1])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Log all requests and responses for audit"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Start timer
|
||||
start_time = time.time()
|
||||
|
||||
# Get client info
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
|
||||
# Process request
|
||||
try:
|
||||
response = await call_next(request)
|
||||
response_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Log successful request
|
||||
logger.info(
|
||||
f"Endpoint={request.url.path} "
|
||||
f"Method={request.method} "
|
||||
f"Status={response.status_code} "
|
||||
f"Time={response_time_ms}ms "
|
||||
f"User={user_id} "
|
||||
f"IP={client_ip}"
|
||||
)
|
||||
|
||||
# Add timing header
|
||||
response.headers["X-Response-Time"] = str(response_time_ms)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
response_time_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(
|
||||
f"Request error - Endpoint={request.url.path} "
|
||||
f"Error={str(e)} "
|
||||
f"Time={response_time_ms}ms "
|
||||
f"User={user_id} "
|
||||
f"IP={client_ip}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def add_security_middleware(app: FastAPI, redis_client: redis.Redis, db_session: Any):
|
||||
"""Register all security middleware in correct order"""
|
||||
|
||||
# Order matters! Process in reverse order of registration:
|
||||
# 1. RequestLoggingMiddleware (innermost, executes last)
|
||||
# 2. RBACMiddleware
|
||||
# 3. JWTAuthenticationMiddleware
|
||||
# 4. HMACVerificationMiddleware
|
||||
# 5. RateLimitMiddleware
|
||||
# 6. SecurityHeadersMiddleware (outermost, executes first)
|
||||
|
||||
app.add_middleware(RequestLoggingMiddleware)
|
||||
app.add_middleware(RBACMiddleware, db_session=db_session)
|
||||
app.add_middleware(JWTAuthenticationMiddleware)
|
||||
app.add_middleware(HMACVerificationMiddleware, redis_client=redis_client)
|
||||
app.add_middleware(RateLimitMiddleware, redis_client=redis_client)
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
324
.history/app/security/middleware_20251210221758.py
Normal file
324
.history/app/security/middleware_20251210221758.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
FastAPI Middleware Stack - Authentication, Authorization, and Security
|
||||
"""
|
||||
from fastapi import FastAPI, Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Callable, Any
|
||||
from datetime import datetime
|
||||
import redis
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.security.jwt_manager import jwt_manager, TokenPayload
|
||||
from app.security.hmac_manager import hmac_manager
|
||||
from app.security.rbac import RBACEngine, UserContext, MemberRole, Permission
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Add security headers to all responses"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
response = await call_next(request)
|
||||
|
||||
# Security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
response.headers["Content-Security-Policy"] = "default-src 'self'"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Rate limiting using Redis"""
|
||||
|
||||
def __init__(self, app: FastAPI, redis_client: redis.Redis):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client
|
||||
self.rate_limit_requests = 100 # requests
|
||||
self.rate_limit_window = 60 # seconds
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip rate limiting for health checks
|
||||
if request.url.path == "/health":
|
||||
return await call_next(request)
|
||||
|
||||
# Get client IP
|
||||
client_ip = request.client.host
|
||||
|
||||
# Rate limit key
|
||||
rate_key = f"rate_limit:{client_ip}"
|
||||
|
||||
# Check rate limit
|
||||
try:
|
||||
current = self.redis_client.get(rate_key)
|
||||
if current and int(current) >= self.rate_limit_requests:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"}
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
pipe = self.redis_client.pipeline()
|
||||
pipe.incr(rate_key)
|
||||
pipe.expire(rate_key, self.rate_limit_window)
|
||||
pipe.execute()
|
||||
except Exception as e:
|
||||
logger.warning(f"Rate limiting error: {e}")
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class HMACVerificationMiddleware(BaseHTTPMiddleware):
|
||||
"""HMAC signature verification and anti-replay protection"""
|
||||
|
||||
def __init__(self, app: FastAPI, redis_client: redis.Redis):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client
|
||||
hmac_manager.redis_client = redis_client
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip verification for public endpoints
|
||||
public_paths = [
|
||||
"/health",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/telegram/start",
|
||||
"/api/v1/auth/telegram/register",
|
||||
"/api/v1/auth/telegram/authenticate",
|
||||
]
|
||||
if request.url.path in public_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract HMAC headers
|
||||
signature = request.headers.get("X-Signature")
|
||||
timestamp = request.headers.get("X-Timestamp")
|
||||
client_id = request.headers.get("X-Client-Id", "unknown")
|
||||
|
||||
# HMAC verification is optional in MVP (configurable)
|
||||
if settings.require_hmac_verification:
|
||||
if not signature or not timestamp:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"detail": "Missing HMAC headers"}
|
||||
)
|
||||
|
||||
try:
|
||||
timestamp_int = int(timestamp)
|
||||
except ValueError:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"detail": "Invalid timestamp format"}
|
||||
)
|
||||
|
||||
# Read body for signature verification
|
||||
body = await request.body()
|
||||
body_dict = {}
|
||||
if body:
|
||||
try:
|
||||
import json
|
||||
body_dict = json.loads(body)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Verify HMAC
|
||||
# Get client secret (hardcoded for MVP, should be from DB)
|
||||
client_secret = settings.hmac_secret_key
|
||||
|
||||
is_valid, error_msg = hmac_manager.verify_signature(
|
||||
method=request.method,
|
||||
endpoint=request.url.path,
|
||||
timestamp=timestamp_int,
|
||||
signature=signature,
|
||||
body=body_dict,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"HMAC verification failed: {error_msg} (client: {client_id})")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": f"HMAC verification failed: {error_msg}"}
|
||||
)
|
||||
|
||||
# Store in request state for logging
|
||||
request.state.client_id = client_id
|
||||
request.state.timestamp = timestamp
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class JWTAuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
"""JWT token verification and extraction"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip auth for public endpoints
|
||||
public_paths = [
|
||||
"/health",
|
||||
"/docs",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/telegram/start",
|
||||
"/api/v1/auth/telegram/register",
|
||||
"/api/v1/auth/telegram/authenticate",
|
||||
]
|
||||
if request.url.path in public_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract token from Authorization header
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Missing Authorization header"}
|
||||
)
|
||||
|
||||
# Parse "Bearer <token>"
|
||||
try:
|
||||
scheme, token = auth_header.split()
|
||||
if scheme.lower() != "bearer":
|
||||
raise ValueError("Invalid auth scheme")
|
||||
except (ValueError, IndexError):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid Authorization header format"}
|
||||
)
|
||||
|
||||
# Verify JWT
|
||||
try:
|
||||
token_payload = jwt_manager.verify_token(token)
|
||||
except ValueError as e:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid or expired token"}
|
||||
)
|
||||
|
||||
# Store in request state
|
||||
request.state.user_id = token_payload.sub
|
||||
request.state.token_type = token_payload.type
|
||||
request.state.device_id = token_payload.device_id
|
||||
request.state.family_ids = token_payload.family_ids
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class RBACMiddleware(BaseHTTPMiddleware):
|
||||
"""Role-Based Access Control enforcement"""
|
||||
|
||||
def __init__(self, app: FastAPI, db_session: Any):
|
||||
super().__init__(app)
|
||||
self.db_session = db_session
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip RBAC for public endpoints
|
||||
if request.url.path in ["/health", "/docs", "/openapi.json"]:
|
||||
return await call_next(request)
|
||||
|
||||
# Get user context from JWT
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
family_ids = getattr(request.state, "family_ids", [])
|
||||
|
||||
if not user_id:
|
||||
# Already handled by JWTAuthenticationMiddleware
|
||||
return await call_next(request)
|
||||
|
||||
# Extract family_id from URL or body
|
||||
family_id = self._extract_family_id(request)
|
||||
|
||||
if family_id and family_id not in family_ids:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Access denied to this family"}
|
||||
)
|
||||
|
||||
# Load user role (would need DB query in production)
|
||||
# For MVP: Store in request state, resolved in endpoint handlers
|
||||
request.state.family_id = family_id
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
@staticmethod
|
||||
def _extract_family_id(request: Request) -> Optional[int]:
|
||||
"""Extract family_id from URL or request body"""
|
||||
# From URL path: /api/v1/families/{family_id}/...
|
||||
if "{family_id}" in request.url.path:
|
||||
# Parse from actual path
|
||||
parts = request.url.path.split("/")
|
||||
for i, part in enumerate(parts):
|
||||
if part == "families" and i + 1 < len(parts):
|
||||
try:
|
||||
return int(parts[i + 1])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Log all requests and responses for audit"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Start timer
|
||||
start_time = time.time()
|
||||
|
||||
# Get client info
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
|
||||
# Process request
|
||||
try:
|
||||
response = await call_next(request)
|
||||
response_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Log successful request
|
||||
logger.info(
|
||||
f"Endpoint={request.url.path} "
|
||||
f"Method={request.method} "
|
||||
f"Status={response.status_code} "
|
||||
f"Time={response_time_ms}ms "
|
||||
f"User={user_id} "
|
||||
f"IP={client_ip}"
|
||||
)
|
||||
|
||||
# Add timing header
|
||||
response.headers["X-Response-Time"] = str(response_time_ms)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
response_time_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(
|
||||
f"Request error - Endpoint={request.url.path} "
|
||||
f"Error={str(e)} "
|
||||
f"Time={response_time_ms}ms "
|
||||
f"User={user_id} "
|
||||
f"IP={client_ip}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def add_security_middleware(app: FastAPI, redis_client: redis.Redis, db_session: Any):
|
||||
"""Register all security middleware in correct order"""
|
||||
|
||||
# Order matters! Process in reverse order of registration:
|
||||
# 1. RequestLoggingMiddleware (innermost, executes last)
|
||||
# 2. RBACMiddleware
|
||||
# 3. JWTAuthenticationMiddleware
|
||||
# 4. HMACVerificationMiddleware
|
||||
# 5. RateLimitMiddleware
|
||||
# 6. SecurityHeadersMiddleware (outermost, executes first)
|
||||
|
||||
app.add_middleware(RequestLoggingMiddleware)
|
||||
app.add_middleware(RBACMiddleware, db_session=db_session)
|
||||
app.add_middleware(JWTAuthenticationMiddleware)
|
||||
app.add_middleware(HMACVerificationMiddleware, redis_client=redis_client)
|
||||
app.add_middleware(RateLimitMiddleware, redis_client=redis_client)
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
Reference in New Issue
Block a user