201 lines
5.9 KiB
Python
201 lines
5.9 KiB
Python
from __future__ import annotations
|
|
|
|
import time
|
|
from collections import defaultdict, deque
|
|
from collections.abc import Hashable
|
|
|
|
from fastapi import HTTPException, Request, status
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.models.user import User
|
|
|
|
BucketKey = tuple[str, Hashable]
|
|
|
|
_buckets: dict[BucketKey, deque[float]] = defaultdict(deque)
|
|
_redis_client = None
|
|
|
|
|
|
def reset_rate_limit_state() -> None:
    """Forget every in-memory rate-limit bucket.

    Only the process-local ``_buckets`` mapping is emptied; the cached Redis
    client and any counters stored in Redis are left untouched.
    """
    _buckets.clear()
|
|
|
|
|
|
async def check_rate_limit(
    *,
    scope: str,
    limit: int,
    window_seconds: int,
    request: Request | None = None,
    user: User | None = None,
    session: AsyncSession | None = None,
) -> None:
    """Enforce ``limit`` requests per ``window_seconds`` for the current caller.

    The caller is tracked under every identifier that can be derived from the
    arguments: ``user:<id>`` and ``telegram:<telegram_id>`` when ``user`` is
    given, ``ip:<host>`` when ``request`` exposes a client address, and the
    single identifier ``"anonymous"`` when none of those is available.

    When ``settings.redis_url`` is set, the decision is delegated to
    ``check_redis_rate_limit``. Otherwise, per-process deques of monotonic
    timestamps in ``_buckets`` are used, keyed by ``(scope, identifier)``.

    Args:
        scope: Name of the limited operation; part of every bucket key.
        limit: Maximum number of hits allowed inside one window.
        window_seconds: Length of the window in seconds.
        request: Incoming request, used for the client IP and for logging.
        user: Authenticated user, if any.
        session: Optional database session forwarded to the event logger.

    Raises:
        HTTPException: With status 429 (via ``raise_rate_limit``) when any
            identifier has reached the limit. The violation is passed to
            ``log_rate_limit_event`` before the exception is raised.
    """
    subjects: list[Hashable] = []
    if user is not None:
        subjects.extend((f"user:{user.id}", f"telegram:{user.telegram_id}"))
    if request is not None and request.client is not None:
        subjects.append(f"ip:{request.client.host}")
    if not subjects:
        subjects.append("anonymous")

    # Redis-backed path: the shared counters decide; local buckets are unused.
    if settings.redis_url:
        if await check_redis_rate_limit(scope, subjects, limit, window_seconds):
            return
        await log_rate_limit_event(
            session,
            scope=scope,
            identifier="redis",
            user=user,
            request=request,
        )
        raise_rate_limit(scope, window_seconds)
        return

    # In-memory path: first verify every identifier, then record the hit.
    # Nothing is recorded if any identifier is already at its limit.
    stamp = time.monotonic()
    for subject in subjects:
        window = _buckets[(scope, subject)]
        # Drop timestamps that have aged out of the window.
        while window and stamp - window[0] > window_seconds:
            window.popleft()
        if len(window) < limit:
            continue
        await log_rate_limit_event(
            session,
            scope=scope,
            identifier=str(subject),
            user=user,
            request=request,
        )
        raise_rate_limit(scope, window_seconds)

    for subject in subjects:
        _buckets[(scope, subject)].append(stamp)
|
|
|
|
|
|
def raise_rate_limit(scope: str, window_seconds: int) -> None:
    """Abort the current request with an HTTP 429 "too many requests" error.

    This function never returns normally; it always raises.

    Args:
        scope: Name of the rate-limited operation, echoed in the error detail.
        window_seconds: Length of the rate-limit window. It is reported to the
            client as the retry delay, both in the JSON detail and in the
            standard ``Retry-After`` response header.

    Raises:
        HTTPException: Always, with status code 429.
    """
    raise HTTPException(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        detail={
            "code": "rate_limit_exceeded",
            "message": "Слишком много запросов. Попробуйте чуть позже.",
            "scope": scope,
            "retry_after_seconds": window_seconds,
        },
        # Generic HTTP clients and proxies read the header, not the JSON body.
        headers={"Retry-After": str(window_seconds)},
    )
|
|
|
|
|
|
async def get_redis_client():
    """Return the shared async Redis client, creating it on first use.

    The client is built from ``settings.redis_url`` and cached in the
    module-level ``_redis_client`` so later calls reuse the same object.

    Returns:
        The cached client, or ``None`` when ``redis.asyncio`` cannot be
        imported (nothing is cached in that case).
    """
    global _redis_client
    if _redis_client is None:
        try:
            from redis.asyncio import Redis
        except ImportError:
            return None
        _redis_client = Redis.from_url(
            settings.redis_url,
            encoding="utf-8",
            decode_responses=True,
        )
    return _redis_client
|
|
|
|
|
|
async def check_redis_rate_limit(
    scope: str,
    identifiers: list[Hashable],
    limit: int,
    window_seconds: int,
) -> bool:
    """Count one hit per identifier in Redis and report whether all are within limit.

    Counters are keyed as ``rl:<scope>:<identifier>:<window index>``, where the
    window index is the current wall-clock time divided by ``window_seconds``.
    Every counter is incremented on each call and given a TTL of twice the
    window length.

    Args:
        scope: Name of the limited operation.
        identifiers: Identifiers to count the hit against.
        limit: Maximum allowed count per identifier in the current window.
        window_seconds: Window length in seconds.

    Returns:
        ``True`` when no counter exceeds ``limit`` after the increment, or
        when no Redis client is available; ``False`` otherwise.
    """
    redis_conn = await get_redis_client()
    if redis_conn is None:
        return True

    window_index = int(time.time() // window_seconds)
    batch = redis_conn.pipeline()
    for identifier in identifiers:
        counter_key = f"rl:{scope}:{identifier}:{window_index}"
        batch.incr(counter_key)
        batch.expire(counter_key, window_seconds * 2)
    replies = await batch.execute()

    # Replies alternate INCR result, EXPIRE result; keep only the INCR counts.
    counts = [int(reply) for reply in replies[::2]]
    return all(count <= limit for count in counts)
|
|
|
|
|
|
async def log_rate_limit_event(
    session: AsyncSession | None,
    *,
    scope: str,
    identifier: str,
    user: User | None = None,
    request: Request | None = None,
) -> None:
    """Record a rate-limit violation through ``persist_rate_limit_event``.

    Request and user context (client host, user agent, user and Telegram ids)
    is collected into a metadata dict and handed to the persistence helper.

    Args:
        session: Database session to write with. When ``None``, a short-lived
            session is opened from ``async_session_factory`` for this event.
        scope: Name of the rate-limited operation.
        identifier: Identifier that triggered the limit.
        user: User who made the request, if known.
        request: Incoming request, used for the client IP and user agent.
    """
    client_host = None
    if request and request.client:
        client_host = request.client.host

    user_agent = None
    if request:
        user_agent = request.headers.get("user-agent")

    telegram_id, user_id = (user.telegram_id, user.id) if user else (None, None)
    metadata = {
        "scope": scope,
        "identifier": identifier,
        "telegram_id": telegram_id,
        "user_id": user_id,
        "ip": client_host,
    }

    # Same payload regardless of which session ends up doing the write.
    event_fields = {
        "scope": scope,
        "identifier": identifier,
        "user": user,
        "client_host": client_host,
        "user_agent": user_agent,
        "metadata": metadata,
    }

    if session is not None:
        await persist_rate_limit_event(session, **event_fields)
        return

    from app.db.session import async_session_factory

    async with async_session_factory() as event_session:
        await persist_rate_limit_event(event_session, **event_fields)
|
|
|
|
|
|
async def persist_rate_limit_event(
    event_session: AsyncSession,
    *,
    scope: str,
    identifier: str,
    user: User | None,
    client_host: str | None,
    user_agent: str | None,
    metadata: dict,
) -> None:
    """Write an audit-log row and an admin notification for a rate-limit hit.

    Persistence is best-effort: any failure while adding the row, creating the
    notification or committing is logged and the session is rolled back, so a
    storage problem does not replace the caller's rate-limit response.

    Args:
        event_session: Session used for the write; it is committed on success
            and rolled back on failure.
        scope: Name of the rate-limited operation (stored as ``target_type``).
        identifier: Identifier that hit the limit; truncated to 80 characters
            for ``target_id``.
        user: Acting user, or ``None`` for a "system" actor.
        client_host: Client IP address, if known.
        user_agent: Client user agent; truncated to 256 characters.
        metadata: Extra context stored on both the audit row and notification.
    """
    import logging

    from app.models.car import AuditLog
    from app.services.admin_notifications import create_admin_notification

    # The idempotency key changes once per 60-second bucket; presumably this
    # lets create_admin_notification collapse repeats within the same minute.
    bucket_seconds = 60
    time_bucket = int(time.time() // bucket_seconds)

    try:
        event_session.add(
            AuditLog(
                actor_user_id=user.id if user else None,
                actor_role=user.platform_role if user else "system",
                action="rate_limit.exceeded",
                target_type=scope,
                target_id=identifier[:80],
                metadata_json=metadata,
                ip=client_host,
                user_agent=user_agent[:256] if user_agent else None,
            )
        )
        await create_admin_notification(
            event_session,
            event_type="rate_limit_exceeded",
            title="Rate limit exceeded",
            body=f"Scope: {scope}\nIdentifier: {identifier}",
            entity_type="user" if user else "system",
            entity_id=user.id if user else scope,
            severity="warning",
            idempotency_key=f"rate_limit:{scope}:{identifier}:{time_bucket}",
            metadata=metadata,
        )
        await event_session.commit()
    except Exception:
        # Still best-effort, but leave a trace instead of failing silently.
        logging.getLogger(__name__).exception(
            "Failed to persist rate-limit event (scope=%s, identifier=%s)",
            scope,
            identifier,
        )
        await event_session.rollback()
|