"""Request rate limiting, backed by Redis when configured, else in-process memory.

Each call is checked against every identifier derived from the caller
(user id, Telegram id, client IP); exceeding the limit for any of them
rejects the request with HTTP 429.
"""

from __future__ import annotations

import logging
import time
from collections import defaultdict, deque
from collections.abc import Hashable
from typing import NoReturn

from fastapi import HTTPException, Request, status
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.models.user import User

logger = logging.getLogger(__name__)

BucketKey = tuple[str, Hashable]

# In-process sliding-window log: (scope, identifier) -> monotonic timestamps of
# accepted requests, oldest first. Lives in module memory of this process only.
_buckets: dict[BucketKey, deque[float]] = defaultdict(deque)
# scope -> monotonic time of the last stale-bucket sweep for that scope.
_last_sweep: dict[str, float] = {}

_redis_client = None
# Set once the optional ``redis`` package is found to be missing, so the failing
# import is not retried (and re-searched on sys.path) on every request.
_redis_unavailable = False


def reset_rate_limit_state() -> None:
    """Forget all in-process rate-limit history (Redis counters are untouched)."""
    _buckets.clear()
    _last_sweep.clear()


def _collect_identifiers(request: Request | None, user: User | None) -> list[Hashable]:
    """Return the identifiers a request is limited under ("anonymous" if none)."""
    identifiers: list[Hashable] = []
    if user is not None:
        identifiers.append(f"user:{user.id}")
        identifiers.append(f"telegram:{user.telegram_id}")
    if request is not None and request.client is not None:
        identifiers.append(f"ip:{request.client.host}")
    return identifiers or ["anonymous"]


def _sweep_stale_buckets(scope: str, now: float, window_seconds: int) -> None:
    """Delete ``scope`` buckets whose newest timestamp has already left the window.

    Such a bucket would be emptied on its next access anyway, so deleting it does
    not change any limiting decision; it only stops identifiers that never come
    back (e.g. one-off client IPs) from accumulating forever. Runs at most once
    per ``window_seconds`` for a given scope.
    """
    last = _last_sweep.get(scope)
    if last is not None and now - last < window_seconds:
        return
    _last_sweep[scope] = now
    stale = [
        key
        for key, bucket in _buckets.items()
        if key[0] == scope and (not bucket or now - bucket[-1] > window_seconds)
    ]
    for key in stale:
        del _buckets[key]


async def check_rate_limit(
    *,
    scope: str,
    limit: int,
    window_seconds: int,
    request: Request | None = None,
    user: User | None = None,
    session: AsyncSession | None = None,
) -> None:
    """Raise HTTP 429 if the caller exceeded ``limit`` requests per ``window_seconds``.

    Args:
        scope: Name of the limited action; counters are kept per scope.
        limit: Maximum accepted requests per identifier within one window.
        window_seconds: Window length in seconds; must be positive.
        request: Incoming request; its client host becomes an ``ip:`` identifier.
        user: Authenticated user; adds ``user:`` and ``telegram:`` identifiers.
        session: If given, an audit-log row is added to it when the limit is hit.

    Raises:
        ValueError: If ``window_seconds`` is not positive.
        HTTPException: Status 429 when any identifier is over the limit.
    """
    if window_seconds <= 0:
        raise ValueError("window_seconds must be positive")

    identifiers = _collect_identifiers(request, user)

    # Use Redis only when a client can actually be created; otherwise fall back
    # to the in-process limiter instead of silently allowing every request.
    if settings.redis_url and await get_redis_client() is not None:
        allowed = await check_redis_rate_limit(scope, identifiers, limit, window_seconds)
        if not allowed:
            await log_rate_limit_event(session, scope=scope, identifier="redis")
            raise_rate_limit(scope, window_seconds)
        return

    now = time.monotonic()
    _sweep_stale_buckets(scope, now, window_seconds)
    for identifier in identifiers:
        bucket = _buckets[(scope, identifier)]
        # Drop timestamps that have slid out of the window.
        while bucket and now - bucket[0] > window_seconds:
            bucket.popleft()
        if len(bucket) >= limit:
            await log_rate_limit_event(session, scope=scope, identifier=str(identifier))
            raise_rate_limit(scope, window_seconds)
    # Record the request only after every identifier passed, so a rejected
    # request does not consume quota.
    for identifier in identifiers:
        _buckets[(scope, identifier)].append(now)


def raise_rate_limit(scope: str, window_seconds: int) -> NoReturn:
    """Raise the HTTP 429 error returned to rate-limited clients.

    ``retry_after_seconds`` and the ``Retry-After`` header report the full window
    length, which is an upper bound on how long the client has to wait.
    """
    raise HTTPException(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        detail={
            "code": "rate_limit_exceeded",
            "message": "Слишком много запросов. Попробуйте чуть позже.",
            "scope": scope,
            "retry_after_seconds": window_seconds,
        },
        headers={"Retry-After": str(window_seconds)},
    )


async def get_redis_client():
    """Return the shared asyncio Redis client, or None if ``redis`` is not installed.

    The client is created lazily from ``settings.redis_url`` and then reused.
    """
    global _redis_client, _redis_unavailable
    if _redis_client is not None:
        return _redis_client
    if _redis_unavailable:
        return None
    try:
        from redis.asyncio import Redis
    except ImportError:
        _redis_unavailable = True
        logger.warning(
            "redis_url is configured but the 'redis' package is not installed; "
            "falling back to in-process rate limiting"
        )
        return None
    _redis_client = Redis.from_url(settings.redis_url, encoding="utf-8", decode_responses=True)
    return _redis_client


async def check_redis_rate_limit(
    scope: str,
    identifiers: list[Hashable],
    limit: int,
    window_seconds: int,
) -> bool:
    """Count this request in Redis fixed windows; return False if any counter exceeds ``limit``.

    Every identifier's counter is incremented, including for requests that end
    up rejected. Returns True (allow) when no Redis client is available.
    """
    client = await get_redis_client()
    if client is None:
        return True
    # Wall-clock time (not monotonic) so all processes agree on the window index.
    window_index = int(time.time() // window_seconds)
    keys = [f"rl:{scope}:{identifier}:{window_index}" for identifier in identifiers]
    pipe = client.pipeline()
    for key in keys:
        pipe.incr(key)
        # Outlive the window so the counter cleans itself up.
        pipe.expire(key, window_seconds * 2)
    results = await pipe.execute()
    # Results alternate: INCR count, EXPIRE flag, ... -> keep the counts.
    return all(int(count) <= limit for count in results[::2])


async def log_rate_limit_event(
    session: AsyncSession | None,
    *,
    scope: str,
    identifier: str,
) -> None:
    """Add an ``AuditLog`` row for a rate-limit hit to ``session`` (no-op without one).

    The row is only added to the session, not flushed or committed here.
    NOTE(review): callers raise HTTP 429 right after this; confirm the session
    dependency commits on error, otherwise the audit row may be discarded.
    """
    if session is None:
        return
    # Imported lazily, presumably to avoid an import cycle with app.models.car.
    from app.models.car import AuditLog

    session.add(
        AuditLog(
            actor_user_id=None,
            actor_role="system",
            action="rate_limit.exceeded",
            target_type=scope,
            target_id=identifier[:80],
            metadata_json={"scope": scope, "identifier": identifier},
        )
    )