209 lines
7.5 KiB
Python
209 lines
7.5 KiB
Python
import re
|
||
from decimal import Decimal
|
||
|
||
from fastapi import APIRouter, Depends, File, UploadFile
|
||
from pydantic import BaseModel
|
||
|
||
from app.api.deps import get_current_telegram_user
|
||
from app.models.user import User
|
||
from app.services.ocr_provider import get_ocr_provider
|
||
|
||
# Every OCR endpoint in this module is mounted under "/ocr" and tagged "ocr".
router = APIRouter(tags=["ocr"], prefix="/ocr")
|
||
|
||
|
||
class ReceiptSuggestion(BaseModel):
    """Fuel-receipt values suggested to the client to prefill the entry form.

    Every extracted field is optional because the heuristic parser may not
    find it; the client is expected to review the values before saving.
    """

    # Total amount paid; either read from the text or derived as liters * price.
    total_cost: Decimal | None = None
    # Volume of fuel; read from the text.
    liters: Decimal | None = None
    # Unit price per liter; either read from the text or derived as total / liters.
    price_per_liter: Decimal | None = None
    # Canonical station brand (e.g. "Lukoil"), or None when no known brand is mentioned.
    station: str | None = None
    # Heuristic score computed by the parser (clamped to 0..0.95); 0 when nothing was parsed.
    confidence: float
    # User-facing status message (Russian) describing the parse outcome.
    message: str
|
||
|
||
|
||
class OCRCandidateRead(BaseModel):
    """One candidate value reported by the OCR provider."""

    # Candidate category; endpoints filter on values such as "license_plate" and "vin".
    type: str
    # The recognized text of this candidate.
    value: str
    # Confidence as reported by the OCR provider (scale defined by the provider).
    confidence: float
|
||
|
||
|
||
class OCRResultRead(BaseModel):
    """Raw OCR output together with the candidates relevant to the calling endpoint."""

    # Full text recognized in the uploaded file.
    recognized_text: str
    # Candidates returned to the client (possibly filtered by type by the endpoint).
    candidates: list[OCRCandidateRead]
    # Name of the OCR backend that produced the result.
    provider: str = "heuristic"
|
||
|
||
|
||
@router.post("/parse-text-receipt", response_model=ReceiptSuggestion)
async def parse_text_receipt(
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
) -> ReceiptSuggestion:
    """Extract fuel-receipt fields from an uploaded file.

    Images and PDFs are sent through the configured OCR provider; any other
    upload is treated as a plain-text receipt. For text uploads the file name
    is prepended to the content so hints in it are parsed as well.

    ``current_user`` is not read here; the dependency presumably enforces that
    the caller is an authenticated Telegram user (see ``get_current_telegram_user``).
    """
    content = await file.read()
    content_type = (file.content_type or "").lower()

    if content_type.startswith("image/") or content_type == "application/pdf":
        result = await get_ocr_provider().recognize(content, file.filename)
        if not result.recognized_text:
            return ReceiptSuggestion(
                confidence=0,
                message="Не удалось уверенно распознать чек. Открылся ручной ввод: проверьте дату, сумму, литры и цену.",
            )
        return parse_receipt_text(result.recognized_text)

    try:
        # Strict UTF-8 first; "-sig" also strips a leading BOM if present.
        decoded = content.decode("utf-8-sig")
    except UnicodeDecodeError:
        # Not valid UTF-8: decoding with errors="ignore" would silently drop
        # every Cyrillic byte (and with it keywords such as "итого" or
        # "лукойл"). Fall back to Windows-1251, a widespread single-byte
        # Cyrillic encoding.
        decoded = content.decode("cp1251", errors="ignore")

    text = " ".join([file.filename or "", decoded])
    return parse_receipt_text(text)
|
||
|
||
|
||
def parse_receipt_text(text: str) -> ReceiptSuggestion:
    """Heuristically extract fuel-receipt fields from free-form receipt text.

    All commas are turned into dots (decimal-comma receipts) and whitespace is
    collapsed. Then the station, liters, price per liter and total are
    searched for. A missing price or total is derived from the other two values.

    Confidence grows with the number of fields found and the amount of numbers
    seen. It is then adjusted by a cross-check ``total ~= liters * price``
    (agreement within 8% gives +0.1, otherwise -0.08). The cross-check is done
    only when all three values were read from the text. Once one of them is
    derived from the other two it agrees by construction and says nothing
    about extraction quality.

    Returns:
        A ``ReceiptSuggestion``; confidence is 0 when *text* contains no numbers.
    """
    normalized = text.replace("\xa0", " ").replace(",", ".")
    compact = re.sub(r"\s+", " ", normalized).strip()
    numbers = [Decimal(item) for item in re.findall(r"\d+(?:\.\d+)?", compact)]

    station = detect_station(compact)
    liters = find_liters(compact, numbers)
    price = find_price_per_liter(compact, numbers)
    total = find_total(compact, numbers, liters, price)

    # Cross-check only independently extracted values, i.e. before derivation.
    consistency_adjustment = 0.0
    if liters and price and total:
        expected = liters * price
        if expected:
            delta = abs((total - expected) / expected)
            consistency_adjustment = 0.1 if delta <= Decimal("0.08") else -0.08

    if total and liters and not price and liters > 0:
        price = (total / liters).quantize(Decimal("0.01"))
    if liters and price and not total:
        total = (liters * price).quantize(Decimal("0.01"))

    signals = sum(value is not None for value in (total, liters, price, station))
    confidence = min(0.88, 0.18 + signals * 0.17 + min(len(numbers), 12) * 0.015)
    confidence = max(0.0, min(confidence + consistency_adjustment, 0.95))

    return ReceiptSuggestion(
        total_cost=total,
        liters=liters,
        price_per_liter=price,
        station=station,
        confidence=round(confidence, 2) if numbers else 0,
        message=(
            "Разобрал текст чека и заполнил форму. Проверь значения перед сохранением."
            if numbers
            else "Не удалось разобрать текст чека. Загрузите текстовый чек или заполните поля вручную."
        ),
    )
|
||
|
||
|
||
@router.post("/fuel-receipt", response_model=ReceiptSuggestion, deprecated=True)
async def scan_fuel_receipt(
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
) -> ReceiptSuggestion:
    """Deprecated alias of ``/ocr/parse-text-receipt``; delegates to it unchanged."""
    suggestion = await parse_text_receipt(file=file, current_user=current_user)
    return suggestion
|
||
|
||
|
||
@router.post("/license-plate", response_model=OCRResultRead)
async def recognize_license_plate(
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
) -> OCRResultRead:
    """Run OCR on the upload and return only the license-plate candidates."""
    provider = get_ocr_provider()
    ocr = await provider.recognize(await file.read(), file.filename)
    plates = [
        OCRCandidateRead(**candidate.__dict__)
        for candidate in ocr.candidates
        if candidate.type == "license_plate"
    ]
    return OCRResultRead(
        recognized_text=ocr.recognized_text,
        candidates=plates,
        provider=ocr.provider,
    )
|
||
|
||
|
||
@router.post("/vin", response_model=OCRResultRead)
async def recognize_vin(
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
) -> OCRResultRead:
    """Run OCR on the upload and return only the VIN candidates."""
    provider = get_ocr_provider()
    ocr = await provider.recognize(await file.read(), file.filename)
    vins = [
        OCRCandidateRead(**candidate.__dict__)
        for candidate in ocr.candidates
        if candidate.type == "vin"
    ]
    return OCRResultRead(
        recognized_text=ocr.recognized_text,
        candidates=vins,
        provider=ocr.provider,
    )
|
||
|
||
|
||
@router.post("/service-document", response_model=OCRResultRead)
async def recognize_service_document(
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_telegram_user),
) -> OCRResultRead:
    """Run OCR on the upload and return every candidate, unfiltered."""
    provider = get_ocr_provider()
    ocr = await provider.recognize(await file.read(), file.filename)
    everything = [OCRCandidateRead(**candidate.__dict__) for candidate in ocr.candidates]
    return OCRResultRead(
        recognized_text=ocr.recognized_text,
        candidates=everything,
        provider=ocr.provider,
    )
|
||
|
||
|
||
def detect_station(text: str) -> str | None:
|
||
stations = {
|
||
"shell": "Shell",
|
||
"lukoil": "Lukoil",
|
||
"лукойл": "Lukoil",
|
||
"gazprom": "Gazprom",
|
||
"газпром": "Gazprom",
|
||
"rosneft": "Rosneft",
|
||
"роснефть": "Rosneft",
|
||
"neste": "Neste",
|
||
}
|
||
lower = text.lower()
|
||
for needle, name in stations.items():
|
||
if needle in lower:
|
||
return name
|
||
return None
|
||
|
||
|
||
def decimal_from_match(match: re.Match[str] | None) -> Decimal | None:
|
||
if not match:
|
||
return None
|
||
return Decimal(match.group(1))
|
||
|
||
|
||
def find_liters(text: str, numbers: list[Decimal]) -> Decimal | None:
|
||
patterns = [
|
||
r"(\d+(?:\.\d+)?)\s*(?:l|литр|литра|литров|л)\b",
|
||
r"(?:volume|qty|кол-?во|количество|объем)\D{0,12}(\d+(?:\.\d+)?)",
|
||
]
|
||
for pattern in patterns:
|
||
value = decimal_from_match(re.search(pattern, text, re.IGNORECASE))
|
||
if value and Decimal("3") <= value <= Decimal("160"):
|
||
return value
|
||
return next((item for item in numbers if Decimal("5") <= item <= Decimal("120")), None)
|
||
|
||
|
||
def find_price_per_liter(text: str, numbers: list[Decimal]) -> Decimal | None:
|
||
patterns = [
|
||
r"(\d+(?:\.\d+)?)\s*(?:/|за)\s*(?:l|литр|л)\b",
|
||
r"(?:price|цена|ppu|руб/л|₽/л)\D{0,12}(\d+(?:\.\d+)?)",
|
||
]
|
||
for pattern in patterns:
|
||
value = decimal_from_match(re.search(pattern, text, re.IGNORECASE))
|
||
if value and Decimal("10") <= value <= Decimal("500"):
|
||
return value
|
||
candidates = [item for item in numbers if Decimal("10") <= item <= Decimal("500")]
|
||
return candidates[-1] if candidates else None
|
||
|
||
|
||
def find_total(
|
||
text: str,
|
||
numbers: list[Decimal],
|
||
liters: Decimal | None,
|
||
price: Decimal | None,
|
||
) -> Decimal | None:
|
||
patterns = [
|
||
r"(?:total|sum|amount|итого|сумма|к\s*оплате)\D{0,16}(\d+(?:\.\d+)?)",
|
||
r"(\d+(?:\.\d+)?)\s*(?:rub|₽|руб|krw|₩)",
|
||
]
|
||
for pattern in patterns:
|
||
value = decimal_from_match(re.search(pattern, text, re.IGNORECASE))
|
||
if value and value > Decimal("50"):
|
||
return value
|
||
ignored = {value for value in (liters, price) if value is not None}
|
||
candidates = [item for item in numbers if item > Decimal("50") and item not in ignored]
|
||
return max(candidates) if candidates else None
|