Files
drivers_bot/app/services/rate_limit.py
2026-05-18 18:17:53 +09:00

201 lines
5.9 KiB
Python

from __future__ import annotations
import time
from collections import defaultdict, deque
from collections.abc import Hashable
from fastapi import HTTPException, Request, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.models.user import User
# Key of one in-memory bucket: (scope, identifier). check_rate_limit builds
# identifiers such as "user:<id>", "telegram:<telegram_id>", "ip:<host>" or
# "anonymous".
BucketKey = tuple[str, Hashable]
# Process-local fallback limiter state (not shared between worker processes):
# for each key, the time.monotonic() timestamps of accepted requests, oldest
# first. Expired timestamps are pruned lazily on the next check for that key.
_buckets: dict[BucketKey, deque[float]] = defaultdict(deque)
# Redis client cached by get_redis_client(); stays None until one is created.
_redis_client = None
def reset_rate_limit_state() -> None:
    """Drop every bucket of the in-memory (process-local) rate limiter.

    Counters stored in Redis and the cached Redis client are not touched.
    """
    _buckets.clear()
def _rate_limit_identifiers(
    user: User | None,
    request: Request | None,
) -> list[Hashable]:
    """Return the identifiers a request is counted against.

    Each identifier has its own counter, so a request is rejected as soon as
    any one of them is over the limit. Requests with no user and no client
    address all share the single "anonymous" identifier.
    """
    identifiers: list[Hashable] = []
    if user is not None:
        identifiers.append(f"user:{user.id}")
        identifiers.append(f"telegram:{user.telegram_id}")
    if request is not None and request.client is not None:
        identifiers.append(f"ip:{request.client.host}")
    return identifiers or ["anonymous"]


async def check_rate_limit(
    *,
    scope: str,
    limit: int,
    window_seconds: int,
    request: Request | None = None,
    user: User | None = None,
    session: AsyncSession | None = None,
) -> None:
    """Allow at most ``limit`` requests per ``window_seconds`` for ``scope``.

    When ``settings.redis_url`` is configured and a Redis client can be
    created, counting is delegated to ``check_redis_rate_limit`` (fixed
    windows shared through Redis). Otherwise a process-local sliding-window
    limiter stored in ``_buckets`` is used.

    Args:
        scope: Name of the limited action; part of every counter key.
        limit: Maximum number of requests per window for each identifier.
        window_seconds: Window length in seconds.
        request: If it has a client address, the client IP is limited too.
        user: If given, both its ``id`` and ``telegram_id`` are limited.
        session: Forwarded to ``log_rate_limit_event`` on rejection.

    Raises:
        HTTPException: 429 (via ``raise_rate_limit``) when any identifier is
            over the limit; the event is logged first.
    """
    identifiers = _rate_limit_identifiers(user, request)

    # Only use Redis when a client really exists: check_redis_rate_limit
    # returns True without a client (e.g. the redis package is missing),
    # which would otherwise switch rate limiting off entirely instead of
    # falling back to the in-memory limiter below.
    if settings.redis_url and (await get_redis_client()) is not None:
        allowed = await check_redis_rate_limit(scope, identifiers, limit, window_seconds)
        if not allowed:
            await log_rate_limit_event(
                session,
                scope=scope,
                identifier="redis",
                user=user,
                request=request,
            )
            raise_rate_limit(scope, window_seconds)
        return

    # NOTE(review): nothing here ever removes a (scope, identifier) entry from
    # _buckets, so memory grows with the number of distinct users/IPs seen;
    # consider periodic eviction of buckets whose newest hit is stale.
    now = time.monotonic()
    for identifier in identifiers:
        bucket = _buckets[(scope, identifier)]
        # Slide the window: discard timestamps older than window_seconds.
        while bucket and now - bucket[0] > window_seconds:
            bucket.popleft()
        if len(bucket) >= limit:
            await log_rate_limit_event(
                session,
                scope=scope,
                identifier=str(identifier),
                user=user,
                request=request,
            )
            raise_rate_limit(scope, window_seconds)
    # Record the hit only once every identifier has passed, so rejected
    # requests do not consume quota in the in-memory limiter.
    for identifier in identifiers:
        _buckets[(scope, identifier)].append(now)
def raise_rate_limit(scope: str, window_seconds: int) -> None:
    """Raise the HTTP 429 error used for every rate-limit rejection.

    The window length is reported both in the JSON body
    (``retry_after_seconds``) and in the standard ``Retry-After`` header, so
    generic HTTP clients can back off as well. The value is a conservative
    wait estimate, not the exact time until the next request is accepted.

    Raises:
        HTTPException: always, with status 429 Too Many Requests.
    """
    raise HTTPException(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        detail={
            "code": "rate_limit_exceeded",
            "message": "Слишком много запросов. Попробуйте чуть позже.",
            "scope": scope,
            "retry_after_seconds": window_seconds,
        },
        headers={"Retry-After": str(window_seconds)},
    )
async def get_redis_client():
    """Return the shared Redis client, building it on first use.

    The client is created from ``settings.redis_url`` with UTF-8 decoding of
    responses and cached in the module-level ``_redis_client``. Returns
    ``None`` when the optional ``redis`` package cannot be imported.
    """
    global _redis_client
    if _redis_client is None:
        try:
            from redis.asyncio import Redis
        except ImportError:
            return None
        _redis_client = Redis.from_url(
            settings.redis_url,
            encoding="utf-8",
            decode_responses=True,
        )
    return _redis_client
async def check_redis_rate_limit(
    scope: str,
    identifiers: list[Hashable],
    limit: int,
    window_seconds: int,
) -> bool:
    """Count one request against every identifier using Redis counters.

    Fixed-window scheme: each key embeds the index of the current wall-clock
    window of ``window_seconds``, is incremented, and expires after two
    windows. The request counts even when it ends up rejected.

    Returns:
        True if every counter is within ``limit``. Also True when no Redis
        client is available.
    """
    client = await get_redis_client()
    if client is None:
        return True

    window_index = int(time.time() // window_seconds)
    pipe = client.pipeline()
    for identifier in identifiers:
        redis_key = f"rl:{scope}:{identifier}:{window_index}"
        pipe.incr(redis_key)
        pipe.expire(redis_key, window_seconds * 2)
    replies = await pipe.execute()

    # Replies alternate INCR, EXPIRE per key; the INCR replies are the counts.
    hit_counts = [int(reply) for reply in replies[::2]]
    return all(hits <= limit for hits in hit_counts)
async def log_rate_limit_event(
    session: AsyncSession | None,
    *,
    scope: str,
    identifier: str,
    user: User | None = None,
    request: Request | None = None,
) -> None:
    """Record a rate-limit rejection via ``persist_rate_limit_event``.

    Collects the client IP, user agent and user ids into a metadata dict.
    Uses ``session`` when one is given; otherwise opens a short-lived session
    from ``app.db.session.async_session_factory`` just for this event.
    """
    if request and request.client:
        client_host = request.client.host
    else:
        client_host = None
    user_agent = request.headers.get("user-agent") if request else None

    persist_kwargs = dict(
        scope=scope,
        identifier=identifier,
        user=user,
        client_host=client_host,
        user_agent=user_agent,
        metadata={
            "scope": scope,
            "identifier": identifier,
            "telegram_id": user.telegram_id if user else None,
            "user_id": user.id if user else None,
            "ip": client_host,
        },
    )

    if session is not None:
        await persist_rate_limit_event(session, **persist_kwargs)
        return

    # Imported lazily, only when no session was supplied.
    from app.db.session import async_session_factory

    async with async_session_factory() as event_session:
        await persist_rate_limit_event(event_session, **persist_kwargs)
async def persist_rate_limit_event(
    event_session: AsyncSession,
    *,
    scope: str,
    identifier: str,
    user: User | None,
    client_host: str | None,
    user_agent: str | None,
    metadata: dict,
) -> None:
    """Store an ``AuditLog`` row and an admin notification for a rejection.

    Best-effort: errors raised while adding, notifying or committing are
    logged and the session is rolled back instead of propagating, so a broken
    audit trail does not replace the caller's 429 with a 500.

    NOTE(review): this commits/rolls back ``event_session`` itself. When it is
    the caller's request session, any changes the caller had pending are
    committed or discarded along with the audit row; confirm this is intended.
    """
    import logging

    from app.models.car import AuditLog
    from app.services.admin_notifications import create_admin_notification

    logger = logging.getLogger(__name__)
    try:
        event_session.add(
            AuditLog(
                actor_user_id=user.id if user else None,
                actor_role=user.platform_role if user else "system",
                action="rate_limit.exceeded",
                target_type=scope,
                target_id=identifier[:80],
                metadata_json=metadata,
                ip=client_host,
                user_agent=user_agent[:256] if user_agent else None,
            )
        )
        # The key changes once per minute; presumably create_admin_notification
        # uses it to collapse repeated hits into one notification per minute.
        minute_bucket = int(time.time() // 60)
        await create_admin_notification(
            event_session,
            event_type="rate_limit_exceeded",
            title="Rate limit exceeded",
            body=f"Scope: {scope}\nIdentifier: {identifier}",
            entity_type="user" if user else "system",
            entity_id=user.id if user else scope,
            severity="warning",
            idempotency_key=f"rate_limit:{scope}:{identifier}:{minute_bucket}",
            metadata=metadata,
        )
        await event_session.commit()
    except Exception:
        logger.exception(
            "Failed to persist rate limit event (scope=%s, identifier=%s)",
            scope,
            identifier,
        )
        try:
            await event_session.rollback()
        except Exception:
            logger.exception("Rollback after failed rate limit event persist failed")