381 lines
13 KiB
Python
381 lines
13 KiB
Python
import re
|
||
import time
|
||
from datetime import date
|
||
from decimal import Decimal
|
||
|
||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
||
from pydantic import BaseModel
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.api.deps import get_current_telegram_user
|
||
from app.db.session import get_session
|
||
from app.models.user import User
|
||
from app.services.admin_notifications import create_admin_notification
|
||
from app.services.ocr_provider import get_ocr_provider
|
||
from app.services.rate_limit import check_rate_limit
|
||
from app.services.uploads import SAFE_IMAGE_TYPES, SAFE_TEXT_TYPES, validate_upload
|
||
|
||
router = APIRouter(prefix="/ocr", tags=["ocr"])
|
||
MAX_OCR_FILE_BYTES = 8 * 1024 * 1024
|
||
|
||
|
||
class ReceiptSuggestion(BaseModel):
    # Response payload of the receipt-parsing endpoints. Every extracted field
    # is optional and stays None when the parser found nothing for it.
    entry_date: date | None = None  # calendar date detected in the receipt text
    total_cost: Decimal | None = None  # amount paid; may be derived as liters * price
    liters: Decimal | None = None  # fuel volume
    price_per_liter: Decimal | None = None  # unit price; may be derived as total / liters
    station: str | None = None  # canonical station brand name, if one was recognised
    category: str | None = None  # "fuel" when liters or a price was found, else None
    confidence: float  # heuristic score, 0 when nothing was parsed, capped at 0.95
    message: str  # user-facing status text
|
||
|
||
|
||
class OCRCandidateRead(BaseModel):
    # One value proposed by the OCR provider.
    type: str  # candidate kind; the endpoints filter on "license_plate" and "vin"
    value: str  # recognised text of the candidate
    confidence: float  # provider-reported score for this candidate
|
||
|
||
|
||
class OCRResultRead(BaseModel):
    # Response payload of the OCR recognition endpoints.
    recognized_text: str  # full text returned by the provider ("" when OCR failed)
    candidates: list[OCRCandidateRead]  # structured candidates, possibly filtered by type
    provider: str = "heuristic"  # provider name; endpoints report "error" when OCR failed
|
||
|
||
|
||
async def validate_ocr_upload(
    *,
    session: AsyncSession,
    current_user: User,
    content: bytes,
    filename: str | None,
    content_type: str | None,
) -> str:
    """Validate an OCR upload and alert admins when it is rejected.

    Delegates to ``validate_upload`` with the OCR size limit and the combined
    image/text allow-list, and returns whatever that helper returns.

    Raises:
        HTTPException: re-raised unchanged from ``validate_upload`` after an
            ``upload_blocked`` admin notification has been created and the
            session committed.
    """
    try:
        verdict = validate_upload(
            content=content,
            filename=filename,
            content_type=content_type,
            max_bytes=MAX_OCR_FILE_BYTES,
            allowed_types=SAFE_IMAGE_TYPES | SAFE_TEXT_TYPES,
        )
    except HTTPException as exc:
        alert_body = f"OCR upload blocked: {filename or '-'}\nReason: {exc.detail}"
        # The key embeds the current minute, so it stays identical for the same
        # user/file/status within one minute (presumably used for de-duplication
        # by create_admin_notification - confirm there).
        dedupe_key = (
            f"upload_blocked:{current_user.id}:{filename or 'upload'}:{exc.status_code}:"
            f"{int(time.time() // 60)}"
        )
        alert_metadata = {
            "filename": filename,
            "content_type": content_type,
            "status_code": exc.status_code,
            "detail": exc.detail,
        }
        await create_admin_notification(
            session,
            event_type="upload_blocked",
            title="Upload blocked",
            body=alert_body,
            entity_type="user",
            entity_id=current_user.id,
            severity="warning",
            idempotency_key=dedupe_key,
            metadata=alert_metadata,
        )
        await session.commit()
        raise
    return verdict
|
||
|
||
|
||
async def recognize_with_alert(
    *,
    session: AsyncSession,
    current_user: User,
    content: bytes,
    filename: str | None,
    scope: str,
):
    """Run the configured OCR provider, turning any failure into an admin alert.

    Returns the provider's result on success. If the provider (or obtaining
    it) raises, an ``ocr_failed`` admin notification tagged with ``scope`` is
    created, the session is committed and ``None`` is returned instead of
    propagating the error.
    """
    try:
        provider = get_ocr_provider()
        return await provider.recognize(content, filename)
    except Exception as exc:  # noqa: BLE001 - OCR must fail gracefully and alert admins
        alert_body = f"Scope: {scope}\nFile: {filename or '-'}\nError: {type(exc).__name__}"
        # Minute-granular key: identical for one user and scope within a minute.
        dedupe_key = f"ocr_failed:{scope}:{current_user.id}:{int(time.time() // 60)}"
        alert_metadata = {"scope": scope, "filename": filename, "error_type": type(exc).__name__}
        await create_admin_notification(
            session,
            event_type="ocr_failed",
            title="OCR provider failed",
            body=alert_body,
            entity_type="user",
            entity_id=current_user.id,
            severity="error",
            idempotency_key=dedupe_key,
            metadata=alert_metadata,
        )
        await session.commit()
        return None
|
||
|
||
|
||
@router.post("/parse-text-receipt", response_model=ReceiptSuggestion)
async def parse_text_receipt(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> ReceiptSuggestion:
    # Rate-limits the caller, validates the upload, then either parses the
    # upload as UTF-8 text or, for image/PDF content types, parses the text
    # returned by the OCR provider.
    await check_rate_limit(
        scope="ocr",
        limit=10,
        window_seconds=60,
        request=request,
        user=current_user,
        session=session,
    )
    content = await file.read()
    await validate_ocr_upload(
        session=session,
        current_user=current_user,
        content=content,
        filename=file.filename,
        content_type=file.content_type,
    )

    declared_type = (file.content_type or "").lower()
    needs_ocr = declared_type.startswith("image/") or declared_type == "application/pdf"
    if not needs_ocr:
        # Text upload: the file name is parsed together with the decoded body.
        decoded_body = content.decode("utf-8", errors="ignore")
        return parse_receipt_text(" ".join([file.filename or "", decoded_body]))

    result = await recognize_with_alert(
        session=session,
        current_user=current_user,
        content=content,
        filename=file.filename,
        scope="parse_text_receipt",
    )
    if result and result.recognized_text:
        return parse_receipt_text(result.recognized_text)
    # OCR failed or produced no text: return an empty suggestion.
    return ReceiptSuggestion(
        confidence=0,
        message="Не удалось уверенно распознать чек. Открылся ручной ввод: проверьте дату, сумму, литры и цену.",
    )
|
||
|
||
|
||
def parse_receipt_text(text: str) -> ReceiptSuggestion:
    """Heuristically extract fuel-receipt fields from free-form text.

    The text is normalised (non-breaking spaces to spaces, decimal commas to
    dots, whitespace runs collapsed) and then scanned for a station brand, a
    date, liters, price per liter and total. A missing price or total is
    derived from the other two values when possible.

    Date-shaped spans (``YYYY-MM-DD`` / ``DD.MM.YYYY`` with ``-``, ``/`` or
    ``.`` separators) are blanked out before numeric extraction. Otherwise a
    date such as ``12.05.2024`` would be re-read as the numbers ``12.05`` and
    ``2024`` and could be reported as liters, price or total.

    Args:
        text: Raw receipt text (OCR output or decoded upload).

    Returns:
        A ``ReceiptSuggestion``. ``confidence`` is 0 when the text contains no
        digits at all, otherwise a heuristic score capped at 0.95.
    """
    normalized = text.replace("\xa0", " ").replace(",", ".")
    compact = re.sub(r"\s+", " ", normalized).strip()
    # Gate for confidence/message: "did the text contain any number at all".
    # Evaluated on the unstripped text so a date-only receipt still counts.
    has_digits = re.search(r"\d", compact) is not None

    station = detect_station(compact)
    entry_date = detect_date(compact)

    # Same two layouts detect_date() recognises; replaced by a space so the
    # neighbouring tokens do not get glued together.
    numeric_text = re.sub(
        r"\b(?:\d{4}[-/.]\d{1,2}[-/.]\d{1,2}|\d{1,2}[-/.]\d{1,2}[-/.]\d{4})\b",
        " ",
        compact,
    )
    numbers = [Decimal(item) for item in re.findall(r"\d+(?:\.\d+)?", numeric_text)]

    liters = find_liters(numeric_text, numbers)
    price = find_price_per_liter(numeric_text, numbers)
    total = find_total(numeric_text, numbers, liters, price)

    # Fill in whichever of price/total is missing from the other two values.
    if total and liters and not price and liters > 0:
        price = (total / liters).quantize(Decimal("0.01"))
    if liters and price and not total:
        total = (liters * price).quantize(Decimal("0.01"))

    # Base score grows with the number of detected fields and (up to 12)
    # numbers seen in the text.
    signals = sum(value is not None for value in (total, liters, price, station))
    confidence = min(0.88, 0.18 + signals * 0.17 + min(len(numbers), 12) * 0.015)
    if liters and price and total:
        expected = liters * price
        if expected:
            # Reward receipts whose total matches liters * price within 8 %,
            # penalise those that do not.
            delta = abs((total - expected) / expected)
            confidence += 0.1 if delta <= Decimal("0.08") else -0.08
    confidence = max(0, min(float(confidence), 0.95))

    return ReceiptSuggestion(
        entry_date=entry_date,
        total_cost=total,
        liters=liters,
        price_per_liter=price,
        station=station,
        category="fuel" if liters or price else None,
        confidence=round(confidence, 2) if has_digits else 0,
        message=(
            "Разобрал текст чека и заполнил форму. Проверь значения перед сохранением."
            if has_digits
            else "Не удалось разобрать текст чека. Загрузите текстовый чек или заполните поля вручную."
        ),
    )
|
||
|
||
|
||
@router.post("/fuel-receipt", response_model=ReceiptSuggestion, deprecated=True)
async def scan_fuel_receipt(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> ReceiptSuggestion:
    # Deprecated alias: forwards everything to parse_text_receipt unchanged.
    suggestion = await parse_text_receipt(
        request=request,
        file=file,
        current_user=current_user,
        session=session,
    )
    return suggestion
|
||
|
||
|
||
@router.post("/license-plate", response_model=OCRResultRead)
async def recognize_license_plate(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> OCRResultRead:
    # Rate-limits, validates and OCRs the upload, then returns only the
    # candidates whose type is "license_plate".
    await check_rate_limit(
        scope="ocr_license_plate",
        limit=8,
        window_seconds=60,
        request=request,
        user=current_user,
        session=session,
    )
    content = await file.read()
    await validate_ocr_upload(
        session=session,
        current_user=current_user,
        content=content,
        filename=file.filename,
        content_type=file.content_type,
    )
    result = await recognize_with_alert(
        session=session,
        current_user=current_user,
        content=content,
        filename=file.filename,
        scope="license_plate",
    )
    if result is not None:
        return OCRResultRead(
            recognized_text=result.recognized_text,
            candidates=[
                OCRCandidateRead(**candidate.__dict__)
                for candidate in result.candidates
                if candidate.type == "license_plate"
            ],
            provider=result.provider,
        )
    # OCR failed (admins were already alerted): report an empty result.
    return OCRResultRead(recognized_text="", candidates=[], provider="error")
|
||
|
||
|
||
@router.post("/vin", response_model=OCRResultRead)
async def recognize_vin(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> OCRResultRead:
    # Rate-limits, validates and OCRs the upload, then returns only the
    # candidates whose type is "vin".
    await check_rate_limit(
        scope="ocr_vin",
        limit=8,
        window_seconds=60,
        request=request,
        user=current_user,
        session=session,
    )
    content = await file.read()
    await validate_ocr_upload(
        session=session,
        current_user=current_user,
        content=content,
        filename=file.filename,
        content_type=file.content_type,
    )
    result = await recognize_with_alert(
        session=session,
        current_user=current_user,
        content=content,
        filename=file.filename,
        scope="vin",
    )
    if result is not None:
        return OCRResultRead(
            recognized_text=result.recognized_text,
            candidates=[
                OCRCandidateRead(**candidate.__dict__)
                for candidate in result.candidates
                if candidate.type == "vin"
            ],
            provider=result.provider,
        )
    # OCR failed (admins were already alerted): report an empty result.
    return OCRResultRead(recognized_text="", candidates=[], provider="error")
|
||
|
||
|
||
@router.post("/service-document", response_model=OCRResultRead)
async def recognize_service_document(
    request: Request,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
    session: AsyncSession = Depends(get_session),
) -> OCRResultRead:
    # Rate-limits, validates and OCRs the upload, then returns every candidate
    # the provider produced (no type filter).
    await check_rate_limit(
        scope="ocr_service_document",
        limit=8,
        window_seconds=60,
        request=request,
        user=current_user,
        session=session,
    )
    content = await file.read()
    await validate_ocr_upload(
        session=session,
        current_user=current_user,
        content=content,
        filename=file.filename,
        content_type=file.content_type,
    )
    result = await recognize_with_alert(
        session=session,
        current_user=current_user,
        content=content,
        filename=file.filename,
        scope="service_document",
    )
    if result is not None:
        return OCRResultRead(
            recognized_text=result.recognized_text,
            candidates=[OCRCandidateRead(**candidate.__dict__) for candidate in result.candidates],
            provider=result.provider,
        )
    # OCR failed (admins were already alerted): report an empty result.
    return OCRResultRead(recognized_text="", candidates=[], provider="error")
|
||
|
||
|
||
def detect_station(text: str) -> str | None:
|
||
stations = {
|
||
"shell": "Shell",
|
||
"lukoil": "Lukoil",
|
||
"лукойл": "Lukoil",
|
||
"gazprom": "Gazprom",
|
||
"газпром": "Gazprom",
|
||
"rosneft": "Rosneft",
|
||
"роснефть": "Rosneft",
|
||
"neste": "Neste",
|
||
}
|
||
lower = text.lower()
|
||
for needle, name in stations.items():
|
||
if needle in lower:
|
||
return name
|
||
return None
|
||
|
||
|
||
def detect_date(text: str) -> date | None:
|
||
for pattern in (
|
||
r"\b(\d{4})[-/.](\d{1,2})[-/.](\d{1,2})\b",
|
||
r"\b(\d{1,2})[-/.](\d{1,2})[-/.](\d{4})\b",
|
||
):
|
||
match = re.search(pattern, text)
|
||
if not match:
|
||
continue
|
||
first, second, third = [int(item) for item in match.groups()]
|
||
try:
|
||
if first > 1900:
|
||
return date(first, second, third)
|
||
return date(third, second, first)
|
||
except ValueError:
|
||
continue
|
||
return None
|
||
|
||
|
||
def decimal_from_match(match: re.Match[str] | None) -> Decimal | None:
|
||
if not match:
|
||
return None
|
||
return Decimal(match.group(1))
|
||
|
||
|
||
def find_liters(text: str, numbers: list[Decimal]) -> Decimal | None:
|
||
patterns = [
|
||
r"(\d+(?:\.\d+)?)\s*(?:l|литр|литра|литров|л)\b",
|
||
r"(?:volume|qty|кол-?во|количество|объем)\D{0,12}(\d+(?:\.\d+)?)",
|
||
]
|
||
for pattern in patterns:
|
||
value = decimal_from_match(re.search(pattern, text, re.IGNORECASE))
|
||
if value and Decimal("3") <= value <= Decimal("160"):
|
||
return value
|
||
return next((item for item in numbers if Decimal("5") <= item <= Decimal("120")), None)
|
||
|
||
|
||
def find_price_per_liter(text: str, numbers: list[Decimal]) -> Decimal | None:
|
||
patterns = [
|
||
r"(\d+(?:\.\d+)?)\s*(?:/|за)\s*(?:l|литр|л)\b",
|
||
r"(?:price|цена|ppu|руб/л|₽/л)\D{0,12}(\d+(?:\.\d+)?)",
|
||
]
|
||
for pattern in patterns:
|
||
value = decimal_from_match(re.search(pattern, text, re.IGNORECASE))
|
||
if value and Decimal("0.1") <= value <= Decimal("500"):
|
||
return value
|
||
candidates = [item for item in numbers if Decimal("0.1") <= item <= Decimal("500")]
|
||
return candidates[-1] if candidates else None
|
||
|
||
|
||
def find_total(
|
||
text: str,
|
||
numbers: list[Decimal],
|
||
liters: Decimal | None,
|
||
price: Decimal | None,
|
||
) -> Decimal | None:
|
||
patterns = [
|
||
r"(?:total|sum|amount|итого|сумма|к\s*оплате)\D{0,16}(\d+(?:\.\d+)?)",
|
||
r"(\d+(?:\.\d+)?)\s*(?:rub|₽|руб|krw|₩)",
|
||
]
|
||
for pattern in patterns:
|
||
value = decimal_from_match(re.search(pattern, text, re.IGNORECASE))
|
||
if value and value > Decimal("50"):
|
||
return value
|
||
ignored = {value for value in (liters, price) if value is not None}
|
||
candidates = [item for item in numbers if item > Decimal("50") and item not in ignored]
|
||
return max(candidates) if candidates else None
|