init commit
This commit is contained in:
1
app/security/__init__.py
Normal file
1
app/security/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Security module: JWT, HMAC, RBAC
|
||||
145
app/security/hmac_manager.py
Normal file
145
app/security/hmac_manager.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
HMAC Signature Verification - Replay Attack Prevention & Request Integrity
|
||||
"""
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
from urllib.parse import urlencode
|
||||
from app.core.config import settings
|
||||
import redis
|
||||
|
||||
|
||||
class HMACManager:
|
||||
"""
|
||||
Request signing and verification using HMAC-SHA256.
|
||||
|
||||
Signature Format:
|
||||
────────────────────────────────────────────────────
|
||||
base_string = METHOD + ENDPOINT + TIMESTAMP + hash(BODY)
|
||||
signature = HMAC_SHA256(base_string, client_secret)
|
||||
|
||||
Headers Required:
|
||||
- X-Signature: base64(signature)
|
||||
- X-Timestamp: unix timestamp (seconds)
|
||||
- X-Client-Id: client identifier
|
||||
|
||||
Anti-Replay Protection:
|
||||
- Check timestamp freshness (±30 seconds)
|
||||
- Store signature hash in Redis with 1-minute TTL
|
||||
- Reject duplicate signatures (nonce check)
|
||||
"""
|
||||
|
||||
# Configuration
|
||||
TIMESTAMP_TOLERANCE_SECONDS = 30
|
||||
REPLAY_NONCE_TTL_SECONDS = 60
|
||||
|
||||
def __init__(self, redis_client: redis.Redis = None):
|
||||
self.redis_client = redis_client
|
||||
self.algorithm = "sha256"
|
||||
|
||||
def create_signature(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
timestamp: int,
|
||||
body: dict = None,
|
||||
client_secret: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create HMAC signature for request.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, PUT, DELETE)
|
||||
endpoint: API endpoint path (/api/v1/transactions)
|
||||
timestamp: Unix timestamp
|
||||
body: Request body dictionary
|
||||
client_secret: Shared secret key
|
||||
|
||||
Returns:
|
||||
Base64-encoded signature
|
||||
"""
|
||||
if client_secret is None:
|
||||
client_secret = settings.hmac_secret_key
|
||||
|
||||
# Create base string
|
||||
base_string = self._build_base_string(method, endpoint, timestamp, body)
|
||||
|
||||
# Generate HMAC
|
||||
signature = hmac.new(
|
||||
client_secret.encode(),
|
||||
base_string.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
return signature
|
||||
|
||||
def verify_signature(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
timestamp: int,
|
||||
signature: str,
|
||||
body: dict = None,
|
||||
client_secret: str = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Verify HMAC signature and check for replay attacks.
|
||||
|
||||
Returns:
|
||||
(is_valid, error_message)
|
||||
"""
|
||||
if client_secret is None:
|
||||
client_secret = settings.hmac_secret_key
|
||||
|
||||
# Step 1: Check timestamp freshness
|
||||
now = datetime.utcnow().timestamp()
|
||||
time_diff = abs(now - timestamp)
|
||||
|
||||
if time_diff > self.TIMESTAMP_TOLERANCE_SECONDS:
|
||||
return False, f"Timestamp too old (diff: {time_diff}s)"
|
||||
|
||||
# Step 2: Verify signature match
|
||||
expected_signature = self.create_signature(
|
||||
method, endpoint, timestamp, body, client_secret
|
||||
)
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
return False, "Signature mismatch"
|
||||
|
||||
# Step 3: Check for replay (signature already used)
|
||||
if self.redis_client:
|
||||
nonce_key = f"hmac:nonce:{signature}"
|
||||
if self.redis_client.exists(nonce_key):
|
||||
return False, "Signature already used (replay attack)"
|
||||
|
||||
# Store nonce
|
||||
self.redis_client.setex(nonce_key, self.REPLAY_NONCE_TTL_SECONDS, "1")
|
||||
|
||||
return True, ""
|
||||
|
||||
def _build_base_string(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
timestamp: int,
|
||||
body: dict = None,
|
||||
) -> str:
|
||||
"""Construct base string for signing"""
|
||||
# Normalize method
|
||||
method = method.upper()
|
||||
|
||||
# Hash body (sorted JSON)
|
||||
body_hash = ""
|
||||
if body:
|
||||
body_json = json.dumps(body, sort_keys=True, separators=(',', ':'))
|
||||
body_hash = hashlib.sha256(body_json.encode()).hexdigest()
|
||||
|
||||
# Base string format
|
||||
base_string = f"{method}:{endpoint}:{timestamp}:{body_hash}"
|
||||
return base_string
|
||||
|
||||
|
||||
# Singleton instance
|
||||
hmac_manager = HMACManager()
|
||||
149
app/security/jwt_manager.py
Normal file
149
app/security/jwt_manager.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
JWT Token Management - Access & Refresh Token Handling
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from enum import Enum
|
||||
import jwt
|
||||
from pydantic import BaseModel
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TokenType(str, Enum):
|
||||
ACCESS = "access"
|
||||
REFRESH = "refresh"
|
||||
SERVICE = "service" # For bot/workers
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""JWT Token Payload Structure"""
|
||||
sub: int # user_id
|
||||
type: TokenType
|
||||
device_id: Optional[str] = None
|
||||
scope: str = "default" # For granular permissions
|
||||
family_ids: list[int] = [] # Accessible families
|
||||
iat: int # issued at
|
||||
exp: int # expiration
|
||||
|
||||
|
||||
class JWTManager:
|
||||
"""
|
||||
JWT token generation, validation, and management.
|
||||
|
||||
Algorithms:
|
||||
- Production: RS256 (asymmetric) - more secure, scalable
|
||||
- MVP: HS256 (symmetric) - simpler setup
|
||||
"""
|
||||
|
||||
# Token lifetimes (configurable in settings)
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 15 # Short-lived
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = 30 # Long-lived
|
||||
SERVICE_TOKEN_EXPIRE_HOURS = 8760 # 1 year
|
||||
|
||||
def __init__(self, secret_key: str = None):
|
||||
self.secret_key = secret_key or settings.jwt_secret_key
|
||||
self.algorithm = "HS256"
|
||||
|
||||
def create_access_token(
|
||||
self,
|
||||
user_id: int,
|
||||
device_id: Optional[str] = None,
|
||||
family_ids: list[int] = None,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
) -> str:
|
||||
"""Generate short-lived access token"""
|
||||
if expires_delta is None:
|
||||
expires_delta = timedelta(minutes=self.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
return self._create_token(
|
||||
user_id=user_id,
|
||||
token_type=TokenType.ACCESS,
|
||||
expires_delta=expires_delta,
|
||||
device_id=device_id,
|
||||
family_ids=family_ids or [],
|
||||
)
|
||||
|
||||
def create_refresh_token(
|
||||
self,
|
||||
user_id: int,
|
||||
device_id: Optional[str] = None,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
) -> str:
|
||||
"""Generate long-lived refresh token"""
|
||||
if expires_delta is None:
|
||||
expires_delta = timedelta(days=self.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
return self._create_token(
|
||||
user_id=user_id,
|
||||
token_type=TokenType.REFRESH,
|
||||
expires_delta=expires_delta,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
def create_service_token(
|
||||
self,
|
||||
service_name: str,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
) -> str:
|
||||
"""Generate service-to-service token (e.g., for bot)"""
|
||||
if expires_delta is None:
|
||||
expires_delta = timedelta(hours=self.SERVICE_TOKEN_EXPIRE_HOURS)
|
||||
|
||||
now = datetime.utcnow()
|
||||
expire = now + expires_delta
|
||||
|
||||
payload = {
|
||||
"sub": f"service:{service_name}",
|
||||
"type": TokenType.SERVICE,
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int(expire.timestamp()),
|
||||
}
|
||||
|
||||
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||
|
||||
def _create_token(
|
||||
self,
|
||||
user_id: int,
|
||||
token_type: TokenType,
|
||||
expires_delta: timedelta,
|
||||
device_id: Optional[str] = None,
|
||||
family_ids: list[int] = None,
|
||||
) -> str:
|
||||
"""Internal token creation"""
|
||||
now = datetime.utcnow()
|
||||
expire = now + expires_delta
|
||||
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"type": token_type.value,
|
||||
"device_id": device_id,
|
||||
"family_ids": family_ids or [],
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int(expire.timestamp()),
|
||||
}
|
||||
|
||||
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||
|
||||
def verify_token(self, token: str) -> TokenPayload:
|
||||
"""
|
||||
Verify token signature and expiration.
|
||||
|
||||
Raises:
|
||||
- jwt.InvalidTokenError
|
||||
- jwt.ExpiredSignatureError
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
return TokenPayload(**payload)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise ValueError("Token has expired")
|
||||
except jwt.InvalidTokenError:
|
||||
raise ValueError("Invalid token")
|
||||
|
||||
def decode_token(self, token: str) -> Dict[str, Any]:
|
||||
"""Decode token without verification (for debugging only)"""
|
||||
return jwt.decode(token, options={"verify_signature": False})
|
||||
|
||||
|
||||
# Singleton instance
|
||||
jwt_manager = JWTManager()
|
||||
308
app/security/middleware.py
Normal file
308
app/security/middleware.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
FastAPI Middleware Stack - Authentication, Authorization, and Security
|
||||
"""
|
||||
from fastapi import FastAPI, Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Callable, Any
|
||||
from datetime import datetime
|
||||
import redis
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.security.jwt_manager import jwt_manager, TokenPayload
|
||||
from app.security.hmac_manager import hmac_manager
|
||||
from app.security.rbac import RBACEngine, UserContext, MemberRole, Permission
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Add security headers to all responses"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
response = await call_next(request)
|
||||
|
||||
# Security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
response.headers["Content-Security-Policy"] = "default-src 'self'"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Rate limiting using Redis"""
|
||||
|
||||
def __init__(self, app: FastAPI, redis_client: redis.Redis):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client
|
||||
self.rate_limit_requests = 100 # requests
|
||||
self.rate_limit_window = 60 # seconds
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip rate limiting for health checks
|
||||
if request.url.path == "/health":
|
||||
return await call_next(request)
|
||||
|
||||
# Get client IP
|
||||
client_ip = request.client.host
|
||||
|
||||
# Rate limit key
|
||||
rate_key = f"rate_limit:{client_ip}"
|
||||
|
||||
# Check rate limit
|
||||
try:
|
||||
current = self.redis_client.get(rate_key)
|
||||
if current and int(current) >= self.rate_limit_requests:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"}
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
pipe = self.redis_client.pipeline()
|
||||
pipe.incr(rate_key)
|
||||
pipe.expire(rate_key, self.rate_limit_window)
|
||||
pipe.execute()
|
||||
except Exception as e:
|
||||
logger.warning(f"Rate limiting error: {e}")
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class HMACVerificationMiddleware(BaseHTTPMiddleware):
|
||||
"""HMAC signature verification and anti-replay protection"""
|
||||
|
||||
def __init__(self, app: FastAPI, redis_client: redis.Redis):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client
|
||||
hmac_manager.redis_client = redis_client
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip verification for public endpoints
|
||||
public_paths = ["/health", "/docs", "/openapi.json", "/api/v1/auth/login", "/api/v1/auth/telegram/start"]
|
||||
if request.url.path in public_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract HMAC headers
|
||||
signature = request.headers.get("X-Signature")
|
||||
timestamp = request.headers.get("X-Timestamp")
|
||||
client_id = request.headers.get("X-Client-Id", "unknown")
|
||||
|
||||
# HMAC verification is optional in MVP (configurable)
|
||||
if settings.require_hmac_verification:
|
||||
if not signature or not timestamp:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"detail": "Missing HMAC headers"}
|
||||
)
|
||||
|
||||
try:
|
||||
timestamp_int = int(timestamp)
|
||||
except ValueError:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"detail": "Invalid timestamp format"}
|
||||
)
|
||||
|
||||
# Read body for signature verification
|
||||
body = await request.body()
|
||||
body_dict = {}
|
||||
if body:
|
||||
try:
|
||||
import json
|
||||
body_dict = json.loads(body)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Verify HMAC
|
||||
# Get client secret (hardcoded for MVP, should be from DB)
|
||||
client_secret = settings.hmac_secret_key
|
||||
|
||||
is_valid, error_msg = hmac_manager.verify_signature(
|
||||
method=request.method,
|
||||
endpoint=request.url.path,
|
||||
timestamp=timestamp_int,
|
||||
signature=signature,
|
||||
body=body_dict,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"HMAC verification failed: {error_msg} (client: {client_id})")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": f"HMAC verification failed: {error_msg}"}
|
||||
)
|
||||
|
||||
# Store in request state for logging
|
||||
request.state.client_id = client_id
|
||||
request.state.timestamp = timestamp
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class JWTAuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
"""JWT token verification and extraction"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip auth for public endpoints
|
||||
public_paths = ["/health", "/docs", "/openapi.json", "/api/v1/auth/login", "/api/v1/auth/telegram/start"]
|
||||
if request.url.path in public_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract token from Authorization header
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Missing Authorization header"}
|
||||
)
|
||||
|
||||
# Parse "Bearer <token>"
|
||||
try:
|
||||
scheme, token = auth_header.split()
|
||||
if scheme.lower() != "bearer":
|
||||
raise ValueError("Invalid auth scheme")
|
||||
except (ValueError, IndexError):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid Authorization header format"}
|
||||
)
|
||||
|
||||
# Verify JWT
|
||||
try:
|
||||
token_payload = jwt_manager.verify_token(token)
|
||||
except ValueError as e:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid or expired token"}
|
||||
)
|
||||
|
||||
# Store in request state
|
||||
request.state.user_id = token_payload.sub
|
||||
request.state.token_type = token_payload.type
|
||||
request.state.device_id = token_payload.device_id
|
||||
request.state.family_ids = token_payload.family_ids
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class RBACMiddleware(BaseHTTPMiddleware):
|
||||
"""Role-Based Access Control enforcement"""
|
||||
|
||||
def __init__(self, app: FastAPI, db_session: Any):
|
||||
super().__init__(app)
|
||||
self.db_session = db_session
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip RBAC for public endpoints
|
||||
if request.url.path in ["/health", "/docs", "/openapi.json"]:
|
||||
return await call_next(request)
|
||||
|
||||
# Get user context from JWT
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
family_ids = getattr(request.state, "family_ids", [])
|
||||
|
||||
if not user_id:
|
||||
# Already handled by JWTAuthenticationMiddleware
|
||||
return await call_next(request)
|
||||
|
||||
# Extract family_id from URL or body
|
||||
family_id = self._extract_family_id(request)
|
||||
|
||||
if family_id and family_id not in family_ids:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Access denied to this family"}
|
||||
)
|
||||
|
||||
# Load user role (would need DB query in production)
|
||||
# For MVP: Store in request state, resolved in endpoint handlers
|
||||
request.state.family_id = family_id
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
@staticmethod
|
||||
def _extract_family_id(request: Request) -> Optional[int]:
|
||||
"""Extract family_id from URL or request body"""
|
||||
# From URL path: /api/v1/families/{family_id}/...
|
||||
if "{family_id}" in request.url.path:
|
||||
# Parse from actual path
|
||||
parts = request.url.path.split("/")
|
||||
for i, part in enumerate(parts):
|
||||
if part == "families" and i + 1 < len(parts):
|
||||
try:
|
||||
return int(parts[i + 1])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Log all requests and responses for audit"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Start timer
|
||||
start_time = time.time()
|
||||
|
||||
# Get client info
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
|
||||
# Process request
|
||||
try:
|
||||
response = await call_next(request)
|
||||
response_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Log successful request
|
||||
logger.info(
|
||||
f"Endpoint={request.url.path} "
|
||||
f"Method={request.method} "
|
||||
f"Status={response.status_code} "
|
||||
f"Time={response_time_ms}ms "
|
||||
f"User={user_id} "
|
||||
f"IP={client_ip}"
|
||||
)
|
||||
|
||||
# Add timing header
|
||||
response.headers["X-Response-Time"] = str(response_time_ms)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
response_time_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(
|
||||
f"Request error - Endpoint={request.url.path} "
|
||||
f"Error={str(e)} "
|
||||
f"Time={response_time_ms}ms "
|
||||
f"User={user_id} "
|
||||
f"IP={client_ip}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def add_security_middleware(app: FastAPI, redis_client: redis.Redis, db_session: Any):
|
||||
"""Register all security middleware in correct order"""
|
||||
|
||||
# Order matters! Process in reverse order of registration:
|
||||
# 1. RequestLoggingMiddleware (innermost, executes last)
|
||||
# 2. RBACMiddleware
|
||||
# 3. JWTAuthenticationMiddleware
|
||||
# 4. HMACVerificationMiddleware
|
||||
# 5. RateLimitMiddleware
|
||||
# 6. SecurityHeadersMiddleware (outermost, executes first)
|
||||
|
||||
app.add_middleware(RequestLoggingMiddleware)
|
||||
app.add_middleware(RBACMiddleware, db_session=db_session)
|
||||
app.add_middleware(JWTAuthenticationMiddleware)
|
||||
app.add_middleware(HMACVerificationMiddleware, redis_client=redis_client)
|
||||
app.add_middleware(RateLimitMiddleware, redis_client=redis_client)
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
228
app/security/rbac.py
Normal file
228
app/security/rbac.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Role-Based Access Control (RBAC) - Authorization Engine
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import Optional, Set, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class MemberRole(str, Enum):
|
||||
"""Family member roles with hierarchy"""
|
||||
OWNER = "owner" # Full access
|
||||
ADULT = "adult" # Can create/edit own transactions
|
||||
MEMBER = "member" # Can create/edit own transactions, restricted budget
|
||||
CHILD = "child" # Limited access, read mostly
|
||||
READ_ONLY = "read_only" # Audit/observer only
|
||||
|
||||
|
||||
class Permission(str, Enum):
|
||||
"""Fine-grained permissions"""
|
||||
# Transaction permissions
|
||||
CREATE_TRANSACTION = "create_transaction"
|
||||
EDIT_OWN_TRANSACTION = "edit_own_transaction"
|
||||
EDIT_ANY_TRANSACTION = "edit_any_transaction"
|
||||
DELETE_OWN_TRANSACTION = "delete_own_transaction"
|
||||
DELETE_ANY_TRANSACTION = "delete_any_transaction"
|
||||
APPROVE_TRANSACTION = "approve_transaction"
|
||||
|
||||
# Wallet permissions
|
||||
CREATE_WALLET = "create_wallet"
|
||||
EDIT_WALLET = "edit_wallet"
|
||||
DELETE_WALLET = "delete_wallet"
|
||||
VIEW_WALLET_BALANCE = "view_wallet_balance"
|
||||
|
||||
# Budget permissions
|
||||
CREATE_BUDGET = "create_budget"
|
||||
EDIT_BUDGET = "edit_budget"
|
||||
DELETE_BUDGET = "delete_budget"
|
||||
|
||||
# Goal permissions
|
||||
CREATE_GOAL = "create_goal"
|
||||
EDIT_GOAL = "edit_goal"
|
||||
DELETE_GOAL = "delete_goal"
|
||||
|
||||
# Category permissions
|
||||
CREATE_CATEGORY = "create_category"
|
||||
EDIT_CATEGORY = "edit_category"
|
||||
DELETE_CATEGORY = "delete_category"
|
||||
|
||||
# Member management
|
||||
INVITE_MEMBERS = "invite_members"
|
||||
EDIT_MEMBER_ROLE = "edit_member_role"
|
||||
REMOVE_MEMBER = "remove_member"
|
||||
|
||||
# Family settings
|
||||
EDIT_FAMILY_SETTINGS = "edit_family_settings"
|
||||
DELETE_FAMILY = "delete_family"
|
||||
|
||||
# Audit & reports
|
||||
VIEW_AUDIT_LOG = "view_audit_log"
|
||||
EXPORT_DATA = "export_data"
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserContext:
|
||||
"""Request context with authorization info"""
|
||||
user_id: int
|
||||
family_id: int
|
||||
role: MemberRole
|
||||
permissions: Set[Permission]
|
||||
family_ids: list[int] # All accessible families
|
||||
device_id: Optional[str] = None
|
||||
client_id: Optional[str] = None # "telegram_bot", "web_frontend", etc.
|
||||
|
||||
|
||||
class RBACEngine:
|
||||
"""
|
||||
Role-Based Access Control with permission inheritance.
|
||||
"""
|
||||
|
||||
# Define role -> permissions mapping
|
||||
ROLE_PERMISSIONS: Dict[MemberRole, Set[Permission]] = {
|
||||
MemberRole.OWNER: {
|
||||
# All permissions
|
||||
Permission.CREATE_TRANSACTION,
|
||||
Permission.EDIT_OWN_TRANSACTION,
|
||||
Permission.EDIT_ANY_TRANSACTION,
|
||||
Permission.DELETE_OWN_TRANSACTION,
|
||||
Permission.DELETE_ANY_TRANSACTION,
|
||||
Permission.APPROVE_TRANSACTION,
|
||||
Permission.CREATE_WALLET,
|
||||
Permission.EDIT_WALLET,
|
||||
Permission.DELETE_WALLET,
|
||||
Permission.VIEW_WALLET_BALANCE,
|
||||
Permission.CREATE_BUDGET,
|
||||
Permission.EDIT_BUDGET,
|
||||
Permission.DELETE_BUDGET,
|
||||
Permission.CREATE_GOAL,
|
||||
Permission.EDIT_GOAL,
|
||||
Permission.DELETE_GOAL,
|
||||
Permission.CREATE_CATEGORY,
|
||||
Permission.EDIT_CATEGORY,
|
||||
Permission.DELETE_CATEGORY,
|
||||
Permission.INVITE_MEMBERS,
|
||||
Permission.EDIT_MEMBER_ROLE,
|
||||
Permission.REMOVE_MEMBER,
|
||||
Permission.EDIT_FAMILY_SETTINGS,
|
||||
Permission.DELETE_FAMILY,
|
||||
Permission.VIEW_AUDIT_LOG,
|
||||
Permission.EXPORT_DATA,
|
||||
},
|
||||
|
||||
MemberRole.ADULT: {
|
||||
# Can manage finances and invite others
|
||||
Permission.CREATE_TRANSACTION,
|
||||
Permission.EDIT_OWN_TRANSACTION,
|
||||
Permission.DELETE_OWN_TRANSACTION,
|
||||
Permission.APPROVE_TRANSACTION,
|
||||
Permission.CREATE_WALLET,
|
||||
Permission.EDIT_WALLET,
|
||||
Permission.VIEW_WALLET_BALANCE,
|
||||
Permission.CREATE_BUDGET,
|
||||
Permission.EDIT_BUDGET,
|
||||
Permission.CREATE_GOAL,
|
||||
Permission.EDIT_GOAL,
|
||||
Permission.CREATE_CATEGORY,
|
||||
Permission.INVITE_MEMBERS,
|
||||
Permission.VIEW_AUDIT_LOG,
|
||||
Permission.EXPORT_DATA,
|
||||
},
|
||||
|
||||
MemberRole.MEMBER: {
|
||||
# Can create/view transactions
|
||||
Permission.CREATE_TRANSACTION,
|
||||
Permission.EDIT_OWN_TRANSACTION,
|
||||
Permission.DELETE_OWN_TRANSACTION,
|
||||
Permission.VIEW_WALLET_BALANCE,
|
||||
Permission.VIEW_AUDIT_LOG,
|
||||
},
|
||||
|
||||
MemberRole.CHILD: {
|
||||
# Limited read access
|
||||
Permission.CREATE_TRANSACTION, # Limited to own
|
||||
Permission.VIEW_WALLET_BALANCE,
|
||||
Permission.VIEW_AUDIT_LOG,
|
||||
},
|
||||
|
||||
MemberRole.READ_ONLY: {
|
||||
# Audit/observer only
|
||||
Permission.VIEW_WALLET_BALANCE,
|
||||
Permission.VIEW_AUDIT_LOG,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(role: MemberRole) -> Set[Permission]:
|
||||
"""Get permissions for a role"""
|
||||
return RBACEngine.ROLE_PERMISSIONS.get(role, set())
|
||||
|
||||
@staticmethod
|
||||
def has_permission(user_context: UserContext, permission: Permission) -> bool:
|
||||
"""Check if user has specific permission"""
|
||||
return permission in user_context.permissions
|
||||
|
||||
@staticmethod
|
||||
def check_permission(
|
||||
user_context: UserContext,
|
||||
required_permission: Permission,
|
||||
raise_exception: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
Check permission and optionally raise exception.
|
||||
|
||||
Raises:
|
||||
- PermissionError if raise_exception=True and user lacks permission
|
||||
"""
|
||||
has_perm = RBACEngine.has_permission(user_context, required_permission)
|
||||
|
||||
if not has_perm and raise_exception:
|
||||
raise PermissionError(
|
||||
f"User {user_context.user_id} lacks permission: {required_permission.value}"
|
||||
)
|
||||
|
||||
return has_perm
|
||||
|
||||
@staticmethod
|
||||
def check_family_access(
|
||||
user_context: UserContext,
|
||||
requested_family_id: int,
|
||||
raise_exception: bool = True,
|
||||
) -> bool:
|
||||
"""Verify user has access to requested family"""
|
||||
has_access = requested_family_id in user_context.family_ids
|
||||
|
||||
if not has_access and raise_exception:
|
||||
raise PermissionError(
|
||||
f"User {user_context.user_id} cannot access family {requested_family_id}"
|
||||
)
|
||||
|
||||
return has_access
|
||||
|
||||
@staticmethod
|
||||
def check_resource_ownership(
|
||||
user_context: UserContext,
|
||||
owner_id: int,
|
||||
raise_exception: bool = True,
|
||||
) -> bool:
|
||||
"""Check if user is owner of resource"""
|
||||
is_owner = user_context.user_id == owner_id
|
||||
|
||||
if not is_owner and raise_exception:
|
||||
raise PermissionError(
|
||||
f"User {user_context.user_id} is not owner of resource (owner: {owner_id})"
|
||||
)
|
||||
|
||||
return is_owner
|
||||
|
||||
|
||||
# Policy definitions (for advanced use)
|
||||
POLICIES = {
|
||||
"transaction_approval_required": {
|
||||
"conditions": ["amount > 500", "role != owner"],
|
||||
"action": "require_approval"
|
||||
},
|
||||
"restrict_child_budget": {
|
||||
"conditions": ["role == child"],
|
||||
"action": "limit_to_100_per_day"
|
||||
},
|
||||
}
|
||||
Reference in New Issue
Block a user