Files
drivers_bot/app/api/ocr.py
VPN SaaS Dev 99bc9aa6a1
Some checks failed
ci / test (push) Has been cancelled
complete admin notifications data explorer
2026-05-19 19:02:16 +09:00

507 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import re
import time
from datetime import date
from decimal import Decimal

from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_telegram_user
from app.db.session import get_session
from app.models.car import OCRResult
from app.models.user import User
from app.services.admin_notifications import create_admin_notification
from app.services.ocr_provider import OcrResult, get_ocr_provider
from app.services.rate_limit import check_rate_limit
from app.services.uploads import SAFE_IMAGE_TYPES, SAFE_TEXT_TYPES, validate_upload
# Every route in this module is mounted under /ocr and tagged "ocr".
router = APIRouter(tags=["ocr"], prefix="/ocr")
# Byte ceiling handed to validate_upload for every OCR upload (8 MiB).
MAX_OCR_FILE_BYTES = 8 * 2**20
# Pre-filled fuel-expense form returned by the receipt-parsing endpoints.
# Every extracted field is optional: None means the value was not found.
class ReceiptSuggestion(BaseModel):
    # Receipt date as found by detect_date.
    entry_date: date | None = None
    # Total amount of the purchase.
    total_cost: Decimal | None = None
    # Dispensed fuel volume (litres).
    liters: Decimal | None = None
    # Unit price per litre.
    price_per_liter: Decimal | None = None
    # Canonical station brand (e.g. "Lukoil") as returned by detect_station.
    station: str | None = None
    # "fuel" when litres or a per-litre price were recognised, otherwise None.
    category: str | None = None
    # Heuristic score; parse_receipt_text keeps it within [0, 0.95] and the
    # manual-input fallback in parse_text_receipt sends 0.
    confidence: float
    # User-facing hint (in Russian) describing the parsing outcome.
    message: str
# One recognised value proposed by the OCR provider.
class OCRCandidateRead(BaseModel):
    # Candidate kind; the endpoints filter on values such as "license_plate" and "vin".
    type: str
    # Recognised text of the candidate.
    value: str
    # Provider-reported confidence score (scale defined by the provider).
    confidence: float
# Response body of the image-recognition endpoints (/license-plate, /vin, /service-document).
class OCRResultRead(BaseModel):
    # Full text recognised by the provider ("" when recognition failed).
    recognized_text: str
    # Candidates extracted from the text, possibly filtered by type per endpoint.
    candidates: list[OCRCandidateRead]
    # Name of the OCR provider; the endpoints send "error" when the provider raised.
    provider: str = "heuristic"
def ocr_candidates_json(result: OcrResult | None) -> list[dict] | None:
    """Serialise the candidates of *result* into JSON-ready dicts.

    Each candidate becomes ``{"type": ..., "value": ..., "confidence": ...}``
    (keys in that order). Returns ``None`` when there is no result at all.
    """
    fields = ("type", "value", "confidence")
    return (
        None
        if result is None
        else [{name: getattr(cand, name) for name in fields} for cand in result.candidates]
    )
def ocr_confidence(result: OcrResult | None) -> Decimal | None:
    """Return the best candidate confidence of *result* as a 4-dp ``Decimal``.

    ``None`` is returned when there is no result or it has no candidates.
    The score is rounded to 4 decimal places and converted through ``str`` so
    the ``Decimal`` carries the short decimal form rather than the binary float.
    """
    scores = [] if result is None else [cand.confidence for cand in result.candidates]
    if not scores:
        return None
    best = max(scores)
    return Decimal(str(round(best, 4)))
async def save_ocr_result(
    session: AsyncSession,
    *,
    current_user: User,
    scope: str,
    filename: str | None,
    content_type: str | None,
    status: str,
    result: OcrResult | None = None,
    recognized_text: str | None = None,
    provider: str | None = None,
    error: str | None = None,
) -> OCRResult:
    """Add an ``OCRResult`` row for *current_user* to *session* and flush it.

    When *result* is given, its provider and recognised text are stored (the
    *provider* / *recognized_text* arguments are then ignored), together with
    its serialised candidates and best candidate confidence. Without a result
    the two arguments are stored as-is and confidence/candidates stay ``None``.

    The row is flushed but not committed; committing is the caller's job.

    Returns:
        The newly added ``OCRResult`` instance.
    """
    if result is None:
        stored_provider, stored_text = provider, recognized_text
    else:
        stored_provider, stored_text = result.provider, result.recognized_text

    record = OCRResult(
        user_id=current_user.id,
        scope=scope,
        filename=filename,
        content_type=content_type,
        status=status,
        provider=stored_provider,
        confidence=ocr_confidence(result),
        recognized_text=stored_text,
        candidates_json=ocr_candidates_json(result),
        error=error,
    )
    session.add(record)
    await session.flush()
    return record
async def validate_ocr_upload(
    *,
    session: AsyncSession,
    current_user: User,
    content: bytes,
    filename: str | None,
    content_type: str | None,
) -> str:
    """Validate an OCR upload, auditing and alerting when it is rejected.

    Delegates to ``validate_upload`` with the OCR size cap and the union of the
    safe image and text types, returning whatever string it returns.

    Raises:
        HTTPException: re-raised unchanged from ``validate_upload`` after a
            ``"blocked"`` OCR record and an ``upload_blocked`` admin
            notification have been written and committed.
    """
    try:
        return validate_upload(
            content=content,
            filename=filename,
            content_type=content_type,
            max_bytes=MAX_OCR_FILE_BYTES,
            allowed_types=SAFE_IMAGE_TYPES | SAFE_TEXT_TYPES,
        )
    except HTTPException as rejection:
        reason = rejection.detail
        await save_ocr_result(
            session,
            current_user=current_user,
            scope="upload_validation",
            filename=filename,
            content_type=content_type,
            status="blocked",
            error=str(reason),
        )
        # Idempotency key is bucketed by user, file name, status code and minute.
        minute_bucket = int(time.time() // 60)
        await create_admin_notification(
            session,
            event_type="upload_blocked",
            title="Upload blocked",
            body=f"OCR upload blocked: {filename or '-'}\nReason: {reason}",
            entity_type="user",
            entity_id=current_user.id,
            severity="warning",
            idempotency_key=(
                f"upload_blocked:{current_user.id}:{filename or 'upload'}:{rejection.status_code}:"
                f"{minute_bucket}"
            ),
            metadata={
                "filename": filename,
                "content_type": content_type,
                "status_code": rejection.status_code,
                "detail": reason,
            },
        )
        # Persist the audit trail before the error propagates to the client.
        await session.commit()
        raise
async def recognize_with_alert(
    *,
    session: AsyncSession,
    current_user: User,
    content: bytes,
    filename: str | None,
    scope: str,
) -> OcrResult | None:
    """Run the configured OCR provider, degrading to ``None`` on failure.

    If the provider raises anything, the exception is logged with its traceback,
    a ``"failed"`` OCR record is stored and an ``ocr_failed`` admin notification
    is created (idempotency key bucketed per scope, user and minute). The session
    is then committed and ``None`` is returned so callers can degrade gracefully
    instead of propagating the error.

    Returns:
        The provider's result, or ``None`` if recognition raised.
    """
    try:
        return await get_ocr_provider().recognize(content, filename)
    except Exception as exc:  # noqa: BLE001 - OCR must fail gracefully and alert admins
        error_type = type(exc).__name__
        # The DB record and the admin alert only keep the exception class name;
        # log the full traceback so the failure can actually be diagnosed.
        logging.getLogger(__name__).exception(
            "OCR provider failed (scope=%s, user_id=%s, file=%s)",
            scope,
            current_user.id,
            filename,
        )
        await save_ocr_result(
            session,
            current_user=current_user,
            scope=scope,
            filename=filename,
            content_type=None,
            status="failed",
            error=error_type,
        )
        await create_admin_notification(
            session,
            event_type="ocr_failed",
            title="OCR provider failed",
            body=f"Scope: {scope}\nFile: {filename or '-'}\nError: {error_type}",
            entity_type="user",
            entity_id=current_user.id,
            severity="error",
            idempotency_key=f"ocr_failed:{scope}:{current_user.id}:{int(time.time() // 60)}",
            metadata={"scope": scope, "filename": filename, "error_type": error_type},
        )
        await session.commit()
        return None
# POST /ocr/parse-text-receipt: turn an uploaded receipt into a pre-filled
# ReceiptSuggestion. Images and PDFs go through the OCR provider; any other
# accepted upload is decoded as UTF-8 text and parsed directly. Every path
# stores an OCR audit record and commits before answering.
@router.post("/parse-text-receipt", response_model=ReceiptSuggestion)
async def parse_text_receipt(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> ReceiptSuggestion:
    await check_rate_limit(
        scope="ocr",
        limit=10,
        window_seconds=60,
        request=request,
        user=current_user,
        session=session,
    )
    payload = await file.read()
    await validate_ocr_upload(
        session=session,
        current_user=current_user,
        filename=file.filename,
        content_type=file.content_type,
        content=payload,
    )
    # NOTE(review): validate_ocr_upload's returned type string is ignored and the
    # branch below uses the client-declared content type — confirm intended.
    declared_type = (file.content_type or "").lower()
    needs_ocr = declared_type.startswith("image/") or declared_type == "application/pdf"

    if not needs_ocr:
        # NOTE(review): the filename is prepended to the text, so any digits in it
        # also feed the number heuristics — confirm this is intended.
        plain_text = " ".join([file.filename or "", payload.decode("utf-8", errors="ignore")])
        await save_ocr_result(
            session,
            current_user=current_user,
            scope="parse_text_receipt",
            filename=file.filename,
            content_type=file.content_type,
            status="preview",
            recognized_text=plain_text,
            provider="text",
        )
        await session.commit()
        return parse_receipt_text(plain_text)

    ocr = await recognize_with_alert(
        session=session,
        current_user=current_user,
        content=payload,
        filename=file.filename,
        scope="parse_text_receipt",
    )
    # A provider failure (None) was already recorded by recognize_with_alert;
    # any real result is stored as a preview whether or not it contains text.
    if ocr is not None:
        await save_ocr_result(
            session,
            current_user=current_user,
            scope="parse_text_receipt",
            filename=file.filename,
            content_type=file.content_type,
            status="preview",
            result=ocr,
        )
    await session.commit()
    if not ocr or not ocr.recognized_text:
        # Nothing usable recognised: ask the user to fill the form manually.
        return ReceiptSuggestion(
            confidence=0,
            message="Не удалось уверенно распознать чек. Открылся ручной ввод: проверьте дату, сумму, литры и цену.",
        )
    return parse_receipt_text(ocr.recognized_text)
# Date ("2026-05-19", "19.05.2026") and clock-time ("12:34", "12:34:56") tokens.
# The number regex in parse_receipt_text would otherwise split them into
# fragments such as 19.05 / 2026 / 12 / 34, which the fallback heuristics can
# mistake for litres, the price per litre or the total.
_DATE_OR_TIME_RE = re.compile(
    r"\b\d{4}[-/.]\d{1,2}[-/.]\d{1,2}\b"
    r"|\b\d{1,2}[-/.]\d{1,2}[-/.]\d{4}\b"
    r"|\b\d{1,2}:\d{2}(?::\d{2})?\b"
)


def parse_receipt_text(text: str) -> ReceiptSuggestion:
    """Turn free-form receipt text into a pre-filled fuel-expense suggestion.

    Non-breaking spaces become spaces, commas are treated as decimal points and
    whitespace is collapsed before the ``detect_*`` / ``find_*`` heuristics run.
    A missing price per litre or total is derived from the other two values when
    possible, and a heuristic confidence in ``[0, 0.95]`` is attached (``0`` when
    the text contains no usable numbers).
    """
    normalized = text.replace("\xa0", " ").replace(",", ".")
    compact = re.sub(r"\s+", " ", normalized).strip()
    # Numbers come from the text with date/time tokens removed so that e.g. the
    # year cannot win the "largest number" total fallback; the date itself is
    # still detected from the full ``compact`` text below.
    numeric_text = _DATE_OR_TIME_RE.sub(" ", compact)
    numbers = [Decimal(item) for item in re.findall(r"\d+(?:\.\d+)?", numeric_text)]

    station = detect_station(compact)
    entry_date = detect_date(compact)
    liters = find_liters(compact, numbers)
    price = find_price_per_liter(compact, numbers)
    total = find_total(compact, numbers, liters, price)

    # Derive the missing value from the other two when possible.
    if total and liters and not price and liters > 0:
        price = (total / liters).quantize(Decimal("0.01"))
    if liters and price and not total:
        total = (liters * price).quantize(Decimal("0.01"))

    signals = sum(value is not None for value in (total, liters, price, station))
    confidence = min(0.88, 0.18 + signals * 0.17 + min(len(numbers), 12) * 0.015)
    if liters and price and total:
        expected = liters * price
        if expected:
            # Reward receipts whose total agrees with litres x price within 8 %,
            # penalise those that do not.
            delta = abs((total - expected) / expected)
            confidence += 0.1 if delta <= Decimal("0.08") else -0.08
    confidence = max(0, min(float(confidence), 0.95))

    return ReceiptSuggestion(
        entry_date=entry_date,
        total_cost=total,
        liters=liters,
        price_per_liter=price,
        station=station,
        category="fuel" if liters or price else None,
        confidence=round(confidence, 2) if numbers else 0,
        message=(
            "Разобрал текст чека и заполнил форму. Проверь значения перед сохранением."
            if numbers
            else "Не удалось разобрать текст чека. Загрузите текстовый чек или заполните поля вручную."
        ),
    )
# Deprecated alias (deprecated=True) of POST /ocr/parse-text-receipt: it
# delegates everything — rate limiting, validation, OCR and persistence —
# to parse_text_receipt so both routes behave identically.
@router.post("/fuel-receipt", response_model=ReceiptSuggestion, deprecated=True)
async def scan_fuel_receipt(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> ReceiptSuggestion:
    return await parse_text_receipt(
        request=request,
        file=file,
        current_user=current_user,
        session=session,
    )
async def _recognize_upload(
    *,
    request: Request,
    file: UploadFile,
    current_user: User,
    session: AsyncSession,
    scope: str,
    rate_limit_scope: str,
    candidate_type: str | None = None,
) -> OCRResultRead:
    """Shared pipeline behind the image-recognition endpoints below.

    Rate-limits the caller (8 requests per 60 s under *rate_limit_scope*), reads
    and validates the upload, runs OCR under *scope*, stores a ``"preview"``
    record and commits. Only candidates whose type equals *candidate_type* are
    returned; ``None`` returns every candidate.

    Returns:
        The recognised text and candidates, or an empty result with
        ``provider="error"`` when the OCR provider failed.
    """
    await check_rate_limit(
        scope=rate_limit_scope,
        limit=8,
        window_seconds=60,
        request=request,
        user=current_user,
        session=session,
    )
    content = await file.read()
    await validate_ocr_upload(
        session=session,
        current_user=current_user,
        content=content,
        filename=file.filename,
        content_type=file.content_type,
    )
    result = await recognize_with_alert(
        session=session,
        current_user=current_user,
        content=content,
        filename=file.filename,
        scope=scope,
    )
    if result is None:
        # recognize_with_alert has already recorded the failure and committed.
        return OCRResultRead(recognized_text="", candidates=[], provider="error")
    await save_ocr_result(
        session,
        current_user=current_user,
        scope=scope,
        filename=file.filename,
        content_type=file.content_type,
        status="preview",
        result=result,
    )
    await session.commit()
    return OCRResultRead(
        recognized_text=result.recognized_text,
        # Explicit fields (as in ocr_candidates_json) instead of ``__dict__``,
        # which breaks for slotted candidate objects.
        candidates=[
            OCRCandidateRead(type=item.type, value=item.value, confidence=item.confidence)
            for item in result.candidates
            if candidate_type is None or item.type == candidate_type
        ],
        provider=result.provider,
    )


# POST /ocr/license-plate: OCR an image and return only licence-plate candidates.
@router.post("/license-plate", response_model=OCRResultRead)
async def recognize_license_plate(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> OCRResultRead:
    return await _recognize_upload(
        request=request,
        file=file,
        current_user=current_user,
        session=session,
        scope="license_plate",
        rate_limit_scope="ocr_license_plate",
        candidate_type="license_plate",
    )


# POST /ocr/vin: OCR an image and return only VIN candidates.
@router.post("/vin", response_model=OCRResultRead)
async def recognize_vin(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> OCRResultRead:
    return await _recognize_upload(
        request=request,
        file=file,
        current_user=current_user,
        session=session,
        scope="vin",
        rate_limit_scope="ocr_vin",
        candidate_type="vin",
    )


# POST /ocr/service-document: OCR a service document and return all candidates.
@router.post("/service-document", response_model=OCRResultRead)
async def recognize_service_document(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> OCRResultRead:
    return await _recognize_upload(
        request=request,
        file=file,
        current_user=current_user,
        session=session,
        scope="service_document",
        rate_limit_scope="ocr_service_document",
        candidate_type=None,
    )
def detect_station(text: str) -> str | None:
stations = {
"shell": "Shell",
"lukoil": "Lukoil",
"лукойл": "Lukoil",
"gazprom": "Gazprom",
"газпром": "Gazprom",
"rosneft": "Rosneft",
"роснефть": "Rosneft",
"neste": "Neste",
}
lower = text.lower()
for needle, name in stations.items():
if needle in lower:
return name
return None
def detect_date(text: str) -> date | None:
for pattern in (
r"\b(\d{4})[-/.](\d{1,2})[-/.](\d{1,2})\b",
r"\b(\d{1,2})[-/.](\d{1,2})[-/.](\d{4})\b",
):
match = re.search(pattern, text)
if not match:
continue
first, second, third = [int(item) for item in match.groups()]
try:
if first > 1900:
return date(first, second, third)
return date(third, second, first)
except ValueError:
continue
return None
def decimal_from_match(match: re.Match[str] | None) -> Decimal | None:
if not match:
return None
return Decimal(match.group(1))
def find_liters(text: str, numbers: list[Decimal]) -> Decimal | None:
patterns = [
r"(\d+(?:\.\d+)?)\s*(?:l|литр|литра|литров|л)\b",
r"(?:volume|qty|кол-?во|количество|объем)\D{0,12}(\d+(?:\.\d+)?)",
]
for pattern in patterns:
value = decimal_from_match(re.search(pattern, text, re.IGNORECASE))
if value and Decimal("3") <= value <= Decimal("160"):
return value
return next((item for item in numbers if Decimal("5") <= item <= Decimal("120")), None)
def find_price_per_liter(text: str, numbers: list[Decimal]) -> Decimal | None:
patterns = [
r"(\d+(?:\.\d+)?)\s*(?:/|за)\s*(?:l|литр|л)\b",
r"(?:price|цена|ppu|руб/л|₽/л)\D{0,12}(\d+(?:\.\d+)?)",
]
for pattern in patterns:
value = decimal_from_match(re.search(pattern, text, re.IGNORECASE))
if value and Decimal("0.1") <= value <= Decimal("500"):
return value
candidates = [item for item in numbers if Decimal("0.1") <= item <= Decimal("500")]
return candidates[-1] if candidates else None
def find_total(
text: str,
numbers: list[Decimal],
liters: Decimal | None,
price: Decimal | None,
) -> Decimal | None:
patterns = [
r"(?:total|sum|amount|итого|сумма|к\s*оплате)\D{0,16}(\d+(?:\.\d+)?)",
r"(\d+(?:\.\d+)?)\s*(?:rub|₽|руб|krw|₩)",
]
for pattern in patterns:
value = decimal_from_match(re.search(pattern, text, re.IGNORECASE))
if value and value > Decimal("50"):
return value
ignored = {value for value in (liters, price) if value is not None}
candidates = [item for item in numbers if item > Decimal("50") and item not in ignored]
return max(candidates) if candidates else None