150 lines
4.5 KiB
Python
150 lines
4.5 KiB
Python
"""
|
|
JWT Token Management - Access & Refresh Token Handling
|
|
"""
|
|
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Dict, Optional, Union

import jwt
from pydantic import BaseModel

from app.core.config import settings
|
|
|
|
|
|
class TokenType(str, Enum):
    """Kinds of JWT this module issues.

    Members are ``str`` subclasses, so each one compares equal to its
    plain string value.
    """

    ACCESS = "access"
    REFRESH = "refresh"
    SERVICE = "service"  # For bot/workers
|
|
|
|
|
|
class TokenPayload(BaseModel):
    """JWT Token Payload Structure.

    Validated view of the claims carried by a decoded token. Callers that
    need a user id should check ``type`` first: for service tokens ``sub``
    is a string, not an integer.
    """

    # Subject. An integer user id for access/refresh tokens, or the string
    # "service:<name>" for service tokens. Declaring this as plain ``int``
    # made every service token fail validation.
    sub: Union[int, str]
    type: TokenType
    device_id: Optional[str] = None
    scope: str = "default"  # For granular permissions
    family_ids: list[int] = []  # Accessible families
    iat: int  # issued at (Unix timestamp, seconds)
    exp: int  # expiration (Unix timestamp, seconds)
|
|
|
|
|
|
class JWTManager:
    """
    JWT token generation, validation, and management.

    Issues access and refresh tokens for users and service tokens for
    bots/workers, and verifies tokens it has signed.

    Algorithms:
        - Production: RS256 (asymmetric) - more secure, scalable
        - MVP: HS256 (symmetric) - simpler setup

    Only HS256 with a single shared secret is implemented here.
    """

    # Default token lifetimes. These are hard-coded; every ``create_*``
    # method accepts an ``expires_delta`` argument to override them per call.
    ACCESS_TOKEN_EXPIRE_MINUTES = 15  # Short-lived
    REFRESH_TOKEN_EXPIRE_DAYS = 30  # Long-lived
    SERVICE_TOKEN_EXPIRE_HOURS = 8760  # 365 days

    def __init__(self, secret_key: Optional[str] = None):
        """Store the signing key and algorithm.

        Args:
            secret_key: Key used to sign and verify tokens. When omitted or
                empty, ``settings.jwt_secret_key`` is used.
        """
        self.secret_key = secret_key or settings.jwt_secret_key
        self.algorithm = "HS256"

    def create_access_token(
        self,
        user_id: int,
        device_id: Optional[str] = None,
        family_ids: Optional[list[int]] = None,
        expires_delta: Optional[timedelta] = None,
    ) -> str:
        """Generate short-lived access token.

        Args:
            user_id: Stored in the ``sub`` claim.
            device_id: Optional device identifier stored in ``device_id``.
            family_ids: Family ids stored in ``family_ids``; ``None`` is
                encoded as an empty list.
            expires_delta: Lifetime override; defaults to
                ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

        Returns:
            The signed, encoded token.
        """
        if expires_delta is None:
            expires_delta = timedelta(minutes=self.ACCESS_TOKEN_EXPIRE_MINUTES)

        return self._create_token(
            user_id=user_id,
            token_type=TokenType.ACCESS,
            expires_delta=expires_delta,
            device_id=device_id,
            family_ids=family_ids,
        )

    def create_refresh_token(
        self,
        user_id: int,
        device_id: Optional[str] = None,
        expires_delta: Optional[timedelta] = None,
    ) -> str:
        """Generate long-lived refresh token.

        Args:
            user_id: Stored in the ``sub`` claim.
            device_id: Optional device identifier stored in ``device_id``.
            expires_delta: Lifetime override; defaults to
                ``REFRESH_TOKEN_EXPIRE_DAYS`` days.

        Returns:
            The signed, encoded token. Its ``family_ids`` claim is always an
            empty list.
        """
        if expires_delta is None:
            expires_delta = timedelta(days=self.REFRESH_TOKEN_EXPIRE_DAYS)

        return self._create_token(
            user_id=user_id,
            token_type=TokenType.REFRESH,
            expires_delta=expires_delta,
            device_id=device_id,
        )

    def create_service_token(
        self,
        service_name: str,
        expires_delta: Optional[timedelta] = None,
    ) -> str:
        """Generate service-to-service token (e.g., for bot).

        Args:
            service_name: Embedded in the ``sub`` claim as
                ``"service:<service_name>"``.
            expires_delta: Lifetime override; defaults to
                ``SERVICE_TOKEN_EXPIRE_HOURS`` hours.

        Returns:
            The signed, encoded token. It carries only ``sub``, ``type``,
            ``iat`` and ``exp`` claims.
        """
        if expires_delta is None:
            expires_delta = timedelta(hours=self.SERVICE_TOKEN_EXPIRE_HOURS)

        claims = {
            "sub": f"service:{service_name}",
            # Store the plain string value, consistent with _create_token.
            "type": TokenType.SERVICE.value,
        }
        return self._sign(claims, expires_delta)

    def _create_token(
        self,
        user_id: int,
        token_type: TokenType,
        expires_delta: timedelta,
        device_id: Optional[str] = None,
        family_ids: Optional[list[int]] = None,
    ) -> str:
        """Internal token creation for user (access/refresh) tokens.

        Builds the user claim set and delegates timestamping and signing to
        ``_sign``.
        """
        # NOTE(review): ``sub`` is emitted as an integer. RFC 7519 defines it
        # as a string and newer PyJWT releases appear to validate that on
        # decode - confirm against the pinned PyJWT version.
        claims = {
            "sub": user_id,
            "type": token_type.value,
            "device_id": device_id,
            "family_ids": family_ids or [],
        }
        return self._sign(claims, expires_delta)

    def _sign(self, claims: Dict[str, Any], expires_delta: timedelta) -> str:
        """Add ``iat``/``exp`` to ``claims`` and return the signed token."""
        # Use an aware UTC datetime. Calling .timestamp() on the naive value
        # returned by datetime.utcnow() interprets it as local time, which
        # shifts iat/exp by the host's UTC offset on non-UTC machines.
        issued_at = datetime.now(timezone.utc)
        payload = {
            **claims,
            "iat": int(issued_at.timestamp()),
            "exp": int((issued_at + expires_delta).timestamp()),
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token: str) -> TokenPayload:
        """
        Verify token signature and expiration.

        Args:
            token: Encoded JWT to check against this manager's key and
                algorithm.

        Returns:
            The decoded claims wrapped in a ``TokenPayload``.

        Raises:
            ValueError: ``"Token has expired"`` when PyJWT reports
                ``jwt.ExpiredSignatureError``; ``"Invalid token"`` for any
                other ``jwt.InvalidTokenError``. The PyJWT exception is
                chained as ``__cause__``.

        Claims that decode but do not fit ``TokenPayload`` raise pydantic's
        validation error instead (presumably a ``ValueError`` subclass -
        confirm for the installed pydantic version).
        """
        try:
            claims = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError as exc:
            raise ValueError("Token has expired") from exc
        except jwt.InvalidTokenError as exc:
            raise ValueError("Invalid token") from exc

        return TokenPayload(**claims)

    def decode_token(self, token: str) -> Dict[str, Any]:
        """Decode token without verification (for debugging only).

        The signature is NOT checked, so the returned claims are untrusted
        and must never be used for authentication or authorization
        decisions.
        """
        return jwt.decode(token, options={"verify_signature": False})
|
|
|
|
|
|
# Singleton instance shared by importers of this module. It is built at
# import time with no explicit key, so it signs and verifies with
# ``settings.jwt_secret_key``.
jwt_manager = JWTManager()
|