All checks were successful
continuous-integration/drone/push Build is passing
99 lines
3.3 KiB
Python
99 lines
3.3 KiB
Python
"""
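Authentication utilities for all services.

This module centralizes shared authentication helpers (password hashing,
JWT creation and validation, and the FastAPI bearer-token dependency) so
that services can import them from one place and avoid circular imports.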
"""
|
|
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional
|
|
|
|
import jwt
|
|
from fastapi import Depends, HTTPException, status
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
from jwt.exceptions import InvalidTokenError
|
|
from passlib.context import CryptContext
|
|
|
|
from shared.config import settings
|
|
|
|
# passlib emits noisy bcrypt version warnings; only let errors through.
logging.getLogger("passlib").setLevel(logging.ERROR)

# Shared password-hashing context (bcrypt only).
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

# Bearer-token scheme used as a FastAPI dependency.
security = HTTPBearer()
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash.

    A password whose UTF-8 encoding is longer than 72 bytes is cut down to
    its first 72 bytes (a partial trailing character is dropped) before it
    is handed to the hashing context.

    Args:
        plain_password: The password as supplied by the user.
        hashed_password: The stored hash to compare against.

    Returns:
        True when the password matches the hash; False when it does not or
        when any exception is raised during verification (the exception is
        logged, not propagated).
    """
    try:
        encoded = plain_password.encode('utf-8')
        # Keep the input as-is unless it exceeds the 72-byte limit.
        candidate = (
            plain_password
            if len(encoded) <= 72
            else encoded[:72].decode('utf-8', errors='ignore')
        )
        return pwd_context.verify(candidate, hashed_password)
    except Exception as e:
        # Deliberate best-effort: a verification error counts as a mismatch.
        logging.error(f"Error verifying password: {e}")
        return False
|
|
|
|
|
|
def get_password_hash(password: str) -> str:
    """Hash a password with the module's bcrypt context.

    A password whose UTF-8 encoding exceeds 72 bytes is truncated to its
    first 72 bytes (a partial trailing character is dropped) and a warning
    is logged. The limit must be identical to the one verify_password
    applies: hashing a shorter prefix than the one that is later verified
    makes every over-long password impossible to authenticate.

    Args:
        password: The plaintext password to hash.

    Returns:
        The hash string produced by the password context.

    Raises:
        ValueError: If encoding or hashing fails for any reason; the
            original exception is attached as the cause.
    """
    try:
        password_bytes = password.encode('utf-8')
        if len(password_bytes) > 72:
            logging.warning("Password exceeds bcrypt limit of 72 bytes. Truncating.")
            # Truncate to 72 bytes (not fewer) so hashing and verification
            # operate on exactly the same prefix.
            password = password_bytes[:72].decode('utf-8', errors='ignore')
        return pwd_context.hash(password)
    except Exception as e:
        # Surface a single, caller-friendly error type while keeping the
        # underlying failure available via exception chaining.
        logging.error(f"Error hashing password: {e}")
        raise ValueError("Password hashing failed. Please use a shorter password.") from e
|
|
|
|
|
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    The claims in ``data`` are copied (the caller's dict is not modified),
    an ``exp`` claim is added, and the result is signed with
    ``settings.SECRET_KEY`` using ``settings.ALGORITHM``.

    Args:
        data: Claims to embed in the token.
        expires_delta: Lifetime of the token. When omitted (``None``) the
            token expires after 15 minutes. An explicit zero or negative
            delta is honoured rather than replaced by the default.

    Returns:
        The encoded JWT.
    """
    # Function-scope import keeps this block self-contained; the module
    # header only imports ``datetime`` and ``timedelta``.
    from datetime import timezone

    # Compare against None explicitly: timedelta(0) is falsy, so a plain
    # truthiness test would silently swap it for the 15-minute default.
    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)

    to_encode = data.copy()
    # Timezone-aware UTC "now" replaces the deprecated naive utcnow().
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime

    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
|
|
|
|
|
def verify_token(token: str) -> Optional[dict]:
    """Verify and decode a JWT, returning the user identity it carries.

    Args:
        token: The encoded JWT.

    Returns:
        ``{"user_id": <int>, "email": <email claim or None>}`` when the
        token decodes successfully and its ``sub`` claim is present and
        convertible to an integer; ``None`` otherwise.
    """
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
    except InvalidTokenError:
        return None

    subject = payload.get("sub")
    if subject is None:
        return None

    try:
        user_id = int(subject)
    except (TypeError, ValueError):
        # A correctly signed token with a non-numeric subject is still not
        # usable; report it as invalid instead of letting the conversion
        # error escape to the caller.
        return None

    return {"user_id": user_id, "email": payload.get("email")}
|
|
|
|
|
|
async def get_current_user_from_token(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> dict:
    """Resolve the current user from the request's bearer token.

    Args:
        credentials: Bearer credentials supplied through the ``security``
            dependency.

    Returns:
        The dict produced by ``verify_token`` for the supplied token.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            ``verify_token`` returns ``None``.
    """
    user_data = verify_token(credentials.credentials)
    if user_data is not None:
        return user_data

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
|