This commit is contained in:
@@ -1,21 +1,29 @@
|
||||
import re
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
|
||||
from fastapi import APIRouter, Depends, File, UploadFile
|
||||
from fastapi import APIRouter, Depends, File, Request, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_telegram_user
|
||||
from app.db.session import get_session
|
||||
from app.models.user import User
|
||||
from app.services.ocr_provider import get_ocr_provider
|
||||
from app.services.rate_limit import check_rate_limit
|
||||
from app.services.uploads import SAFE_IMAGE_TYPES, SAFE_TEXT_TYPES, validate_upload
|
||||
|
||||
router = APIRouter(prefix="/ocr", tags=["ocr"])
|
||||
MAX_OCR_FILE_BYTES = 8 * 1024 * 1024
|
||||
|
||||
|
||||
class ReceiptSuggestion(BaseModel):
|
||||
entry_date: date | None = None
|
||||
total_cost: Decimal | None = None
|
||||
liters: Decimal | None = None
|
||||
price_per_liter: Decimal | None = None
|
||||
station: str | None = None
|
||||
category: str | None = None
|
||||
confidence: float
|
||||
message: str
|
||||
|
||||
@@ -34,10 +42,20 @@ class OCRResultRead(BaseModel):
|
||||
|
||||
@router.post("/parse-text-receipt", response_model=ReceiptSuggestion)
|
||||
async def parse_text_receipt(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
current_user: User = Depends(get_current_telegram_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> ReceiptSuggestion:
|
||||
await check_rate_limit(scope="ocr", limit=10, window_seconds=60, request=request, user=current_user, session=session)
|
||||
content = await file.read()
|
||||
validate_upload(
|
||||
content=content,
|
||||
filename=file.filename,
|
||||
content_type=file.content_type,
|
||||
max_bytes=MAX_OCR_FILE_BYTES,
|
||||
allowed_types=SAFE_IMAGE_TYPES | SAFE_TEXT_TYPES,
|
||||
)
|
||||
content_type = (file.content_type or "").lower()
|
||||
if content_type.startswith("image/") or content_type == "application/pdf":
|
||||
result = await get_ocr_provider().recognize(content, file.filename)
|
||||
@@ -62,6 +80,7 @@ def parse_receipt_text(text: str) -> ReceiptSuggestion:
|
||||
numbers = [Decimal(item) for item in re.findall(r"\d+(?:\.\d+)?", compact)]
|
||||
|
||||
station = detect_station(compact)
|
||||
entry_date = detect_date(compact)
|
||||
liters = find_liters(compact, numbers)
|
||||
price = find_price_per_liter(compact, numbers)
|
||||
total = find_total(compact, numbers, liters, price)
|
||||
@@ -80,10 +99,12 @@ def parse_receipt_text(text: str) -> ReceiptSuggestion:
|
||||
confidence = max(0, min(float(confidence), 0.95))
|
||||
|
||||
return ReceiptSuggestion(
|
||||
entry_date=entry_date,
|
||||
total_cost=total,
|
||||
liters=liters,
|
||||
price_per_liter=price,
|
||||
station=station,
|
||||
category="fuel" if liters or price else None,
|
||||
confidence=round(confidence, 2) if numbers else 0,
|
||||
message=(
|
||||
"Разобрал текст чека и заполнил форму. Проверь значения перед сохранением."
|
||||
@@ -95,18 +116,25 @@ def parse_receipt_text(text: str) -> ReceiptSuggestion:
|
||||
|
||||
@router.post("/fuel-receipt", response_model=ReceiptSuggestion, deprecated=True)
|
||||
async def scan_fuel_receipt(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
current_user: User = Depends(get_current_telegram_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> ReceiptSuggestion:
|
||||
return await parse_text_receipt(file, current_user)
|
||||
return await parse_text_receipt(request, file, current_user, session)
|
||||
|
||||
|
||||
@router.post("/license-plate", response_model=OCRResultRead)
|
||||
async def recognize_license_plate(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
current_user: User = Depends(get_current_telegram_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> OCRResultRead:
|
||||
result = await get_ocr_provider().recognize(await file.read(), file.filename)
|
||||
await check_rate_limit(scope="ocr_license_plate", limit=8, window_seconds=60, request=request, user=current_user, session=session)
|
||||
content = await file.read()
|
||||
validate_upload(content=content, filename=file.filename, content_type=file.content_type, max_bytes=MAX_OCR_FILE_BYTES, allowed_types=SAFE_IMAGE_TYPES | SAFE_TEXT_TYPES)
|
||||
result = await get_ocr_provider().recognize(content, file.filename)
|
||||
return OCRResultRead(
|
||||
recognized_text=result.recognized_text,
|
||||
candidates=[OCRCandidateRead(**item.__dict__) for item in result.candidates if item.type == "license_plate"],
|
||||
@@ -116,10 +144,15 @@ async def recognize_license_plate(
|
||||
|
||||
@router.post("/vin", response_model=OCRResultRead)
|
||||
async def recognize_vin(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
current_user: User = Depends(get_current_telegram_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> OCRResultRead:
|
||||
result = await get_ocr_provider().recognize(await file.read(), file.filename)
|
||||
await check_rate_limit(scope="ocr_vin", limit=8, window_seconds=60, request=request, user=current_user, session=session)
|
||||
content = await file.read()
|
||||
validate_upload(content=content, filename=file.filename, content_type=file.content_type, max_bytes=MAX_OCR_FILE_BYTES, allowed_types=SAFE_IMAGE_TYPES | SAFE_TEXT_TYPES)
|
||||
result = await get_ocr_provider().recognize(content, file.filename)
|
||||
return OCRResultRead(
|
||||
recognized_text=result.recognized_text,
|
||||
candidates=[OCRCandidateRead(**item.__dict__) for item in result.candidates if item.type == "vin"],
|
||||
@@ -129,10 +162,15 @@ async def recognize_vin(
|
||||
|
||||
@router.post("/service-document", response_model=OCRResultRead)
|
||||
async def recognize_service_document(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
current_user: User = Depends(get_current_telegram_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> OCRResultRead:
|
||||
result = await get_ocr_provider().recognize(await file.read(), file.filename)
|
||||
await check_rate_limit(scope="ocr_service_document", limit=8, window_seconds=60, request=request, user=current_user, session=session)
|
||||
content = await file.read()
|
||||
validate_upload(content=content, filename=file.filename, content_type=file.content_type, max_bytes=MAX_OCR_FILE_BYTES, allowed_types=SAFE_IMAGE_TYPES | SAFE_TEXT_TYPES)
|
||||
result = await get_ocr_provider().recognize(content, file.filename)
|
||||
return OCRResultRead(
|
||||
recognized_text=result.recognized_text,
|
||||
candidates=[OCRCandidateRead(**item.__dict__) for item in result.candidates],
|
||||
@@ -158,6 +196,24 @@ def detect_station(text: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def detect_date(text: str) -> date | None:
|
||||
for pattern in (
|
||||
r"\b(\d{4})[-/.](\d{1,2})[-/.](\d{1,2})\b",
|
||||
r"\b(\d{1,2})[-/.](\d{1,2})[-/.](\d{4})\b",
|
||||
):
|
||||
match = re.search(pattern, text)
|
||||
if not match:
|
||||
continue
|
||||
first, second, third = [int(item) for item in match.groups()]
|
||||
try:
|
||||
if first > 1900:
|
||||
return date(first, second, third)
|
||||
return date(third, second, first)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def decimal_from_match(match: re.Match[str] | None) -> Decimal | None:
|
||||
if not match:
|
||||
return None
|
||||
@@ -183,9 +239,9 @@ def find_price_per_liter(text: str, numbers: list[Decimal]) -> Decimal | None:
|
||||
]
|
||||
for pattern in patterns:
|
||||
value = decimal_from_match(re.search(pattern, text, re.IGNORECASE))
|
||||
if value and Decimal("10") <= value <= Decimal("500"):
|
||||
if value and Decimal("0.1") <= value <= Decimal("500"):
|
||||
return value
|
||||
candidates = [item for item in numbers if Decimal("10") <= item <= Decimal("500")]
|
||||
candidates = [item for item in numbers if Decimal("0.1") <= item <= Decimal("500")]
|
||||
return candidates[-1] if candidates else None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user