import re
import time
from datetime import date
from decimal import Decimal

from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_telegram_user
from app.db.session import get_session
from app.models.user import User
from app.services.admin_notifications import create_admin_notification
from app.services.ocr_provider import get_ocr_provider
from app.services.rate_limit import check_rate_limit
from app.services.uploads import SAFE_IMAGE_TYPES, SAFE_TEXT_TYPES, validate_upload

router = APIRouter(prefix="/ocr", tags=["ocr"])

# Upper bound on the size of any OCR upload (8 MiB).
MAX_OCR_FILE_BYTES = 8 * 1024 * 1024

# Currency markers that may follow an amount on a receipt.
_CURRENCY_PATTERN = r"(?:rub|₽|руб|krw|₩)"
# Text that turns the amount right before it into a per-liter price, e.g. "/л",
# "руб/л", "₽/l", "рублей за литр". `\w*` lets "руб" also cover "рублей".
_PER_LITER_SUFFIX_PATTERN = (
    r"\s*(?:" + _CURRENCY_PATTERN + r"\w*\.?)?\s*(?:/|за)\s*(?:l|литр|л)\b"
)
_PER_LITER_SUFFIX_RE = re.compile(_PER_LITER_SUFFIX_PATTERN, re.IGNORECASE)


# Suggested values for the fuel-entry form, produced by parse_receipt_text.
# Every field except confidence/message is optional because the heuristics may
# not find it. (Comments rather than a docstring: pydantic would publish a class
# docstring as the OpenAPI schema description.)
class ReceiptSuggestion(BaseModel):
    entry_date: date | None = None
    total_cost: Decimal | None = None
    liters: Decimal | None = None
    price_per_liter: Decimal | None = None
    station: str | None = None
    category: str | None = None
    confidence: float
    message: str


# One recognized value (e.g. a plate or VIN) as reported by the OCR provider.
class OCRCandidateRead(BaseModel):
    type: str
    value: str
    confidence: float


# Response of the generic OCR endpoints. provider is "error" when the provider failed.
class OCRResultRead(BaseModel):
    recognized_text: str
    candidates: list[OCRCandidateRead]
    provider: str = "heuristic"


async def _read_upload(file: UploadFile) -> bytes:
    """Read an upload without buffering more than the OCR size limit.

    One byte past ``MAX_OCR_FILE_BYTES`` is read, so an oversized file is still
    longer than the limit when it reaches ``validate_upload``. The rest of the
    file is never pulled into memory.
    """
    # NOTE(review): assumes validate_upload rejects content longer than max_bytes; confirm.
    return await file.read(MAX_OCR_FILE_BYTES + 1)


async def validate_ocr_upload(
    *,
    session: AsyncSession,
    current_user: User,
    content: bytes,
    filename: str | None,
    content_type: str | None,
) -> str:
    """Validate an OCR upload (size limit, image/text types) via ``validate_upload``.

    Returns:
        Whatever ``validate_upload`` returns for an accepted file.

    Raises:
        HTTPException: re-raised from ``validate_upload``. Before re-raising, an
            "upload_blocked" admin notification is recorded and the session is
            committed. The idempotency key includes the current minute bucket,
            so repeated rejections of the same file within a minute share one key.
    """
    try:
        return validate_upload(
            content=content,
            filename=filename,
            content_type=content_type,
            max_bytes=MAX_OCR_FILE_BYTES,
            allowed_types=SAFE_IMAGE_TYPES | SAFE_TEXT_TYPES,
        )
    except HTTPException as exc:
        await create_admin_notification(
            session,
            event_type="upload_blocked",
            title="Upload blocked",
            body=f"OCR upload blocked: {filename or '-'}\nReason: {exc.detail}",
            entity_type="user",
            entity_id=current_user.id,
            severity="warning",
            idempotency_key=(
                f"upload_blocked:{current_user.id}:{filename or 'upload'}:{exc.status_code}:"
                f"{int(time.time() // 60)}"
            ),
            metadata={
                "filename": filename,
                "content_type": content_type,
                "status_code": exc.status_code,
                "detail": exc.detail,
            },
        )
        await session.commit()
        raise


async def recognize_with_alert(
    *,
    session: AsyncSession,
    current_user: User,
    content: bytes,
    filename: str | None,
    scope: str,
):
    """Run the configured OCR provider and degrade to ``None`` on failure.

    Returns:
        The provider's result object, or ``None`` if the provider raised. The
        callers read ``recognized_text``, ``candidates`` and ``provider`` from it.

    On any exception, an "ocr_failed" admin notification is recorded (error type
    only, deduplicated per scope/user/minute) and the session is committed.
    """
    try:
        return await get_ocr_provider().recognize(content, filename)
    except Exception as exc:  # noqa: BLE001 - OCR must fail gracefully and alert admins
        await create_admin_notification(
            session,
            event_type="ocr_failed",
            title="OCR provider failed",
            body=f"Scope: {scope}\nFile: {filename or '-'}\nError: {type(exc).__name__}",
            entity_type="user",
            entity_id=current_user.id,
            severity="error",
            idempotency_key=f"ocr_failed:{scope}:{current_user.id}:{int(time.time() // 60)}",
            metadata={"scope": scope, "filename": filename, "error_type": type(exc).__name__},
        )
        await session.commit()
        return None


# Parse an uploaded fuel receipt into form suggestions. Images and PDFs go through
# the OCR provider first; any other accepted upload is decoded as UTF-8 text.
# Rate limited under the "ocr" scope (10 requests / 60 s).
@router.post("/parse-text-receipt", response_model=ReceiptSuggestion)
async def parse_text_receipt(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> ReceiptSuggestion:
    await check_rate_limit(
        scope="ocr",
        limit=10,
        window_seconds=60,
        request=request,
        user=current_user,
        session=session,
    )
    content = await _read_upload(file)
    await validate_ocr_upload(
        session=session,
        current_user=current_user,
        filename=file.filename,
        content_type=file.content_type,
        content=content,
    )
    # Branching uses the client-declared content type.
    content_type = (file.content_type or "").lower()
    if content_type.startswith("image/") or content_type == "application/pdf":
        result = await recognize_with_alert(
            session=session,
            current_user=current_user,
            content=content,
            filename=file.filename,
            scope="parse_text_receipt",
        )
        if not result or not result.recognized_text:
            return ReceiptSuggestion(
                confidence=0,
                message=(
                    "Не удалось уверенно распознать чек. "
                    "Открылся ручной ввод: проверьте дату, сумму, литры и цену."
                ),
            )
        return parse_receipt_text(result.recognized_text)

    # The filename is parsed together with the body. Presumably this is so that
    # hints in the name (e.g. a date) can contribute; note that digits in the
    # name also become number candidates.
    text = " ".join([file.filename or "", content.decode("utf-8", errors="ignore")])
    return parse_receipt_text(text)


def parse_receipt_text(text: str) -> ReceiptSuggestion:
    """Heuristically extract fuel-receipt fields from free text.

    Commas are treated as decimal separators. NOTE(review): this means a
    thousands separator such as "1,250.00" is misread.

    Missing values are derived where possible: price = total / liters, or
    total = liters * price (both rounded to 0.01).

    Confidence is computed as:
      * 0.18 + 0.17 per found field (total, liters, price, station)
        + 0.015 per number seen (at most 12), capped at 0.88;
      * then +0.1 if total agrees with liters * price within 8%, else -0.08;
      * then clamped to [0, 0.95], rounded to 2 places;
      * and forced to 0 when the text contains no numbers.
    """
    normalized = text.replace("\xa0", " ").replace(",", ".")
    compact = re.sub(r"\s+", " ", normalized).strip()
    numbers = [Decimal(item) for item in re.findall(r"\d+(?:\.\d+)?", compact)]
    station = detect_station(compact)
    entry_date = detect_date(compact)
    liters = find_liters(compact, numbers)
    price = find_price_per_liter(compact, numbers)
    total = find_total(compact, numbers, liters, price)
    if total and liters and not price and liters > 0:
        price = (total / liters).quantize(Decimal("0.01"))
    if liters and price and not total:
        total = (liters * price).quantize(Decimal("0.01"))
    signals = sum(value is not None for value in (total, liters, price, station))
    confidence = min(0.88, 0.18 + signals * 0.17 + min(len(numbers), 12) * 0.015)
    if liters and price and total:
        expected = liters * price
        if expected:
            # Reward receipts whose three amounts are mutually consistent.
            delta = abs((total - expected) / expected)
            confidence += 0.1 if delta <= Decimal("0.08") else -0.08
    confidence = max(0, min(float(confidence), 0.95))
    return ReceiptSuggestion(
        entry_date=entry_date,
        total_cost=total,
        liters=liters,
        price_per_liter=price,
        station=station,
        category="fuel" if liters or price else None,
        confidence=round(confidence, 2) if numbers else 0,
        message=(
            "Разобрал текст чека и заполнил форму. Проверь значения перед сохранением."
            if numbers
            else "Не удалось разобрать текст чека. Загрузите текстовый чек или заполните поля вручную."
        ),
    )


# Deprecated alias of /parse-text-receipt. It shares that endpoint's "ocr" rate limit.
@router.post("/fuel-receipt", response_model=ReceiptSuggestion, deprecated=True)
async def scan_fuel_receipt(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> ReceiptSuggestion:
    return await parse_text_receipt(request, file, current_user, session)


async def _recognize_upload(
    *,
    request: Request,
    file: UploadFile,
    current_user: User,
    session: AsyncSession,
    scope: str,
    candidate_type: str | None = None,
) -> OCRResultRead:
    """Rate-limit, validate and OCR an upload for one of the generic OCR endpoints.

    Args:
        scope: OCR scope used in admin alerts. The rate-limit bucket is
            ``ocr_<scope>`` (8 requests / 60 s).
        candidate_type: Keep only provider candidates of this type; ``None`` keeps all.

    Returns:
        The recognized text and candidates. If the OCR provider failed, an empty
        result with provider ``"error"``.
    """
    await check_rate_limit(
        scope=f"ocr_{scope}",
        limit=8,
        window_seconds=60,
        request=request,
        user=current_user,
        session=session,
    )
    content = await _read_upload(file)
    await validate_ocr_upload(
        session=session,
        current_user=current_user,
        content=content,
        filename=file.filename,
        content_type=file.content_type,
    )
    result = await recognize_with_alert(
        session=session,
        current_user=current_user,
        content=content,
        filename=file.filename,
        scope=scope,
    )
    if result is None:
        return OCRResultRead(recognized_text="", candidates=[], provider="error")
    return OCRResultRead(
        recognized_text=result.recognized_text,
        candidates=[
            OCRCandidateRead(**item.__dict__)
            for item in result.candidates
            if candidate_type is None or item.type == candidate_type
        ],
        provider=result.provider,
    )


# OCR a photo and return only license-plate candidates
# (rate limit scope "ocr_license_plate").
@router.post("/license-plate", response_model=OCRResultRead)
async def recognize_license_plate(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> OCRResultRead:
    return await _recognize_upload(
        request=request,
        file=file,
        current_user=current_user,
        session=session,
        scope="license_plate",
        candidate_type="license_plate",
    )


# OCR a photo and return only VIN candidates (rate limit scope "ocr_vin").
@router.post("/vin", response_model=OCRResultRead)
async def recognize_vin(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> OCRResultRead:
    return await _recognize_upload(
        request=request,
        file=file,
        current_user=current_user,
        session=session,
        scope="vin",
        candidate_type="vin",
    )


# OCR a service document and return every candidate the provider found
# (rate limit scope "ocr_service_document").
@router.post("/service-document", response_model=OCRResultRead)
async def recognize_service_document(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> OCRResultRead:
    return await _recognize_upload(
        request=request,
        file=file,
        current_user=current_user,
        session=session,
        scope="service_document",
        candidate_type=None,
    )


def detect_station(text: str) -> str | None:
    """Return the canonical name of the first known station brand mentioned in ``text``.

    Matching is a case-insensitive substring test in the dict's insertion order.
    Latin and Cyrillic spellings map to the same name.
    """
    stations = {
        "shell": "Shell",
        "lukoil": "Lukoil",
        "лукойл": "Lukoil",
        "gazprom": "Gazprom",
        "газпром": "Gazprom",
        "rosneft": "Rosneft",
        "роснефть": "Rosneft",
        "neste": "Neste",
    }
    lower = text.lower()
    for needle, name in stations.items():
        if needle in lower:
            return name
    return None


def detect_date(text: str) -> date | None:
    """Return the first valid calendar date written as Y-M-D or D-M-Y in ``text``.

    Separators may be ``-``, ``/`` or ``.``. The year-first layout is searched
    before the day-first layout. Within each layout every match is tried in
    order, so an impossible date (e.g. month 13) no longer hides a later valid
    one. A leading number of 1900 or less is treated as the day, not the year.
    """
    for pattern in (
        r"\b(\d{4})[-/.](\d{1,2})[-/.](\d{1,2})\b",
        r"\b(\d{1,2})[-/.](\d{1,2})[-/.](\d{4})\b",
    ):
        for match in re.finditer(pattern, text):
            first, second, third = (int(item) for item in match.groups())
            try:
                if first > 1900:
                    return date(first, second, third)
                return date(third, second, first)
            except ValueError:
                continue
    return None


def decimal_from_match(match: re.Match[str] | None) -> Decimal | None:
    """Return capture group 1 of ``match`` as a Decimal, or ``None`` if there was no match."""
    if not match:
        return None
    return Decimal(match.group(1))


def find_liters(text: str, numbers: list[Decimal]) -> Decimal | None:
    """Guess the fuel volume in liters.

    First tries explicit forms: a number followed by a liter unit, or a
    volume/quantity keyword followed by a number. Such a value is accepted if it
    lies in 3..160. Otherwise it falls back to the first number in ``numbers``
    that lies in 5..120.
    """
    patterns = [
        r"(\d+(?:\.\d+)?)\s*(?:l|литр|литра|литров|л)\b",
        r"(?:volume|qty|кол-?во|количество|объем)\D{0,12}(\d+(?:\.\d+)?)",
    ]
    for pattern in patterns:
        value = decimal_from_match(re.search(pattern, text, re.IGNORECASE))
        if value and Decimal("3") <= value <= Decimal("160"):
            return value
    return next((item for item in numbers if Decimal("5") <= item <= Decimal("120")), None)


def find_price_per_liter(text: str, numbers: list[Decimal]) -> Decimal | None:
    """Guess the price per liter.

    First tries explicit forms, accepting a value in 0.1..500:
      * a number followed by a per-liter suffix. An optional currency unit is
        allowed in between, so "52.30/л", "52.30 руб/л" and "52 рублей за литр"
        all match;
      * a price keyword followed by a number.
    Otherwise it falls back to the LAST number in ``numbers`` within 0.1..500.
    """
    patterns = [
        r"(\d+(?:\.\d+)?)" + _PER_LITER_SUFFIX_PATTERN,
        r"(?:price|цена|ppu|руб/л|₽/л)\D{0,12}(\d+(?:\.\d+)?)",
    ]
    for pattern in patterns:
        value = decimal_from_match(re.search(pattern, text, re.IGNORECASE))
        if value and Decimal("0.1") <= value <= Decimal("500"):
            return value
    candidates = [item for item in numbers if Decimal("0.1") <= item <= Decimal("500")]
    return candidates[-1] if candidates else None


def _is_per_liter_amount(text: str, end: int) -> bool:
    """Tell whether the amount ending at index ``end`` of ``text`` is a per-liter price."""
    return _PER_LITER_SUFFIX_RE.match(text, end) is not None


def find_total(
    text: str,
    numbers: list[Decimal],
    liters: Decimal | None,
    price: Decimal | None,
) -> Decimal | None:
    """Guess the receipt total.

    Explicit forms are tried first: a total/sum keyword followed by a number, or
    a number followed by a currency unit. Every match is inspected, not just the
    first one. Amounts written as per-liter prices ("52.30 руб/л") are skipped,
    and the first remaining value above 50 wins.

    Otherwise it falls back to the largest number above 50 that is not equal to
    the detected liters or price.
    """
    patterns = [
        r"(?:total|sum|amount|итого|сумма|к\s*оплате)\D{0,16}(\d+(?:\.\d+)?)",
        r"(\d+(?:\.\d+)?)\s*" + _CURRENCY_PATTERN,
    ]
    for pattern in patterns:
        for match in re.finditer(pattern, text, re.IGNORECASE):
            if _is_per_liter_amount(text, match.end(1)):
                continue
            value = Decimal(match.group(1))
            if value > Decimal("50"):
                return value
    ignored = {value for value in (liters, price) if value is not None}
    candidates = [item for item in numbers if item > Decimal("50") and item not in ignored]
    return max(candidates) if candidates else None