Files
chat/shared/auth.py
Andrew K. Choi e5aa933cf9
All checks were successful
continuous-integration/drone/push Build is passing
bcrypt fix
2025-09-26 07:07:48 +09:00

98 lines
3.2 KiB
Python

"""
Authentication utilities for all services.
This module provides common authentication functionality to avoid circular imports.
"""
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jwt.exceptions import InvalidTokenError
from passlib.context import CryptContext

from shared.config import settings
# Quiet passlib's bcrypt version warnings: only ERROR-level records from the
# "passlib" logger hierarchy get through.
logging.getLogger(name="passlib").setLevel(level=logging.ERROR)

# Shared password-hashing context. bcrypt is the only configured scheme;
# deprecated="auto" tells passlib to treat any non-default scheme as deprecated.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])

# Bearer-token extractor, used as the FastAPI dependency of
# get_current_user_from_token.
security = HTTPBearer()
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash. Handle bcrypt compatibility issues."""
try:
# Truncate password to 72 bytes for consistency
password_bytes = plain_password.encode('utf-8')
if len(password_bytes) > 72:
plain_password = password_bytes[:72].decode('utf-8', errors='ignore')
return pwd_context.verify(plain_password, hashed_password)
except Exception as e:
logging.error(f"Error verifying password: {e}")
return False
def get_password_hash(password: str) -> str:
"""Get password hash. Truncate password to 72 bytes if necessary for bcrypt compatibility."""
try:
# bcrypt has a 72-byte limit, so truncate if necessary
password_bytes = password.encode('utf-8')
if len(password_bytes) > 72:
password = password_bytes[:72].decode('utf-8', errors='ignore')
return pwd_context.hash(password)
except Exception as e:
# Handle bcrypt compatibility issues
logging.error(f"Error hashing password: {e}")
raise ValueError("Password hashing failed. Please use a shorter password.")
# Lifetime used when the caller does not pass expires_delta.
_DEFAULT_ACCESS_TOKEN_TTL = timedelta(minutes=15)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed. The caller's dict is not mutated; a shallow
            copy is encoded with an added ``exp`` claim.
            NOTE(review): RFC 7519 defines ``sub`` as a string, and recent
            PyJWT releases are reported to reject a non-string ``sub`` on
            decode. Confirm callers pass ``str(user_id)``.
        expires_delta: Token lifetime. ``None`` selects the 15-minute default.
            Any other value, including ``timedelta(0)``, is used as given.

    Returns:
        The encoded token, signed with ``settings.SECRET_KEY`` using
        ``settings.ALGORITHM``.
    """
    to_encode = data.copy()
    # Compare against None explicitly: timedelta(0) is falsy and must not be
    # mistaken for "no lifetime given".
    lifetime = _DEFAULT_ACCESS_TOKEN_TTL if expires_delta is None else expires_delta
    # Timezone-aware "now" replaces the deprecated naive datetime.utcnow().
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def verify_token(token: str) -> Optional[dict]:
    """Decode a JWT and extract the authenticated user's identity.

    The token is decoded with ``settings.SECRET_KEY``, and only
    ``settings.ALGORITHM`` is accepted.

    Returns:
        ``{"user_id": <int>, "email": <payload "email" or None>}`` when the
        token decodes and its ``sub`` claim converts to ``int``. Returns
        ``None`` when PyJWT rejects the token (``InvalidTokenError``), when
        ``sub`` is missing, or when ``sub`` is not an integer id.
    """
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
    except InvalidTokenError:
        return None

    subject = payload.get("sub")
    if subject is None:
        return None
    try:
        user_id = int(subject)
    except (TypeError, ValueError):
        # A correctly signed token with a non-numeric subject is not a valid
        # user token. Return None so callers answer 401 instead of letting the
        # conversion error escape.
        return None
    return {"user_id": user_id, "email": payload.get("email")}
async def get_current_user_from_token(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> dict:
    """FastAPI dependency that resolves the caller from its Bearer token.

    Returns:
        The user dict produced by ``verify_token``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            ``verify_token`` rejects the token.
    """
    resolved = verify_token(credentials.credentials)
    if resolved is not None:
        return resolved

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )