115 lines
3.7 KiB
Python
115 lines
3.7 KiB
Python
import asyncio
|
||
import re
|
||
from dataclasses import dataclass
|
||
from functools import lru_cache
|
||
from io import BytesIO
|
||
from typing import Protocol
|
||
|
||
from app.core.config import settings
|
||
from app.services.vehicle_identity import normalize_license_plate, validate_vin
|
||
|
||
|
||
@dataclass
class OcrCandidate:
    """A single identifier candidate extracted from recognized text."""

    # Candidate category; this module emits "vin" and "license_plate".
    type: str
    # Candidate text as stored by the code that built it.
    value: str
    # Score assigned by the code that produced the candidate.
    confidence: float
|
||
|
||
|
||
@dataclass
class OcrResult:
    """Outcome of one recognition pass: the text plus extracted candidates."""

    # Text the candidates were extracted from.
    recognized_text: str
    # Extracted identifier candidates.
    candidates: list[OcrCandidate]
    # Name of the provider credited with the result.
    provider: str = "heuristic"
|
||
|
||
|
||
class OCRProvider(Protocol):
    """Structural interface satisfied by every OCR backend in this module."""

    async def recognize(self, content: bytes, filename: str | None = None) -> OcrResult:
        """Recognize identifiers in ``content``; ``filename`` is an optional hint."""
        ...
|
||
|
||
|
||
class TextHeuristicOCRProvider:
    """Provider that performs no image OCR and only scans decodable text.

    The file name and the UTF-8-decodable part of the payload are joined into
    one string and passed to ``build_ocr_result``.
    """

    provider_name = "heuristic"

    async def recognize(self, content: bytes, filename: str | None = None) -> OcrResult:
        """Build a result from the file name plus the payload decoded as UTF-8.

        Bytes that are not valid UTF-8 are dropped rather than raising.
        """
        decoded = content.decode("utf-8", errors="ignore")
        label = filename if filename else ""
        haystack = f"{label} {decoded}"
        return build_ocr_result(haystack, provider=self.provider_name, base_confidence=0.62)
|
||
|
||
|
||
class TesseractOCRProvider:
    """Provider that runs Tesseract through ``pytesseract`` on image bytes.

    Recognition is best-effort: any failure to import the OCR libraries, open
    the image, or run Tesseract yields empty text, and empty text makes
    ``recognize`` fall back to ``TextHeuristicOCRProvider``.
    """

    provider_name = "tesseract"

    async def recognize(self, content: bytes, filename: str | None = None) -> OcrResult:
        """Recognize identifiers in ``content``.

        The blocking Tesseract call runs in a worker thread. When it produces
        only whitespace, the heuristic provider's result is returned instead,
        relabelled with this provider's name.
        """
        text = await asyncio.to_thread(self._recognize_sync, content)
        if text.strip():
            return build_ocr_result(text, provider=self.provider_name, base_confidence=0.78)

        fallback = await TextHeuristicOCRProvider().recognize(content, filename)
        fallback.provider = self.provider_name
        return fallback

    def _recognize_sync(self, content: bytes) -> str:
        """Return the text Tesseract reads from ``content``, or ``""`` on failure."""
        try:
            # Imported lazily so the module still loads without the OCR extras.
            import pytesseract
            from PIL import Image
        except ImportError:
            return ""

        try:
            image = Image.open(BytesIO(content))
        except Exception:
            # Not a decodable image; the caller falls back to the heuristic path.
            return ""

        # The context manager releases the image's resources once OCR is done.
        with image:
            try:
                return pytesseract.image_to_string(image, lang=settings.ocr_languages)
            except Exception:
                # The configured languages may be unavailable; retry with defaults.
                pass
            try:
                return pytesseract.image_to_string(image)
            except Exception:
                # The retry can fail too (e.g. no tesseract binary); stay
                # best-effort like the other failure paths instead of raising.
                return ""
|
||
|
||
|
||
class CompositeOCRProvider:
    """Facade that delegates to the provider chosen by ``settings.ocr_provider``."""

    def __init__(self) -> None:
        """Select the backend: "heuristic" (case-insensitive) or, otherwise, Tesseract."""
        configured = settings.ocr_provider.lower()
        if configured == "heuristic":
            chosen: OCRProvider = TextHeuristicOCRProvider()
        else:
            # Any other configured value selects the Tesseract backend.
            chosen = TesseractOCRProvider()
        self.primary: OCRProvider = chosen

    async def recognize(self, content: bytes, filename: str | None = None) -> OcrResult:
        """Forward the request unchanged to the selected backend."""
        result = await self.primary.recognize(content, filename)
        return result
|
||
|
||
|
||
def build_ocr_result(text: str, *, provider: str, base_confidence: float) -> OcrResult:
    """Extract VIN and license-plate candidates from raw recognized text.

    Whitespace (including non-breaking spaces) is collapsed to single spaces,
    and matching runs on the upper-cased result. VIN candidates come first and
    score ``base_confidence + 0.12`` capped at 0.95; plate candidates score
    ``base_confidence``. Duplicates are dropped and at most 12 candidates are
    returned.

    Args:
        text: Raw text to scan.
        provider: Provider name recorded on the result.
        base_confidence: Score given to plate candidates.

    Returns:
        An ``OcrResult`` holding the collapsed text and the candidates.
    """
    normalized_text = re.sub(r"\s+", " ", text.replace("\xa0", " ")).strip()
    haystack = normalized_text.upper()
    found: list[OcrCandidate] = []
    emitted: set[tuple[str, str]] = set()

    # 17-character tokens from the alphabet without I, O and Q.
    for match in re.finditer(r"\b[A-HJ-NPR-Z0-9]{17}\b", haystack):
        token = match.group(0)
        try:
            vin = validate_vin(token) or token
        except ValueError:
            # Tokens the validator rejects are skipped.
            continue
        marker = ("vin", vin)
        if marker in emitted:
            continue
        emitted.add(marker)
        found.append(OcrCandidate(type="vin", value=vin, confidence=min(base_confidence + 0.12, 0.95)))

    # Tried in order, so earlier patterns win a slot before the 12-item cap.
    plate_patterns = [
        # 2-3 digits, one Hangul syllable, 4 digits (presumably Korean plates).
        r"\b\d{2,3}\s*[가-힣]\s*\d{4}\b",
        # Letter, 3 digits, 2 letters, 2-3 digits in Latin or Cyrillic
        # (presumably Russian-style plates).
        r"\b[A-ZА-Я]{1}\s?\d{3}\s?[A-ZА-Я]{2}\s?\d{2,3}\b",
        # Generic catch-all: 5-11 alphanumerics, hyphens or spaces.
        r"\b[0-9A-ZА-Я가-힣][0-9A-ZА-Я가-힣\-\s]{4,10}\b",
    ]
    for pattern in plate_patterns:
        for match in re.finditer(pattern, haystack):
            plate = normalize_license_plate(match.group(0))
            if not plate or not 5 <= len(plate) <= 10:
                continue
            marker = ("license_plate", plate)
            if marker in emitted:
                continue
            emitted.add(marker)
            found.append(
                OcrCandidate(type="license_plate", value=plate, confidence=base_confidence)
            )

    return OcrResult(recognized_text=normalized_text, candidates=found[:12], provider=provider)
|
||
|
||
|
||
@lru_cache
def get_ocr_provider() -> OCRProvider:
    """Return the shared OCR provider, constructing it on the first call.

    ``lru_cache`` memoizes this zero-argument call, so later calls return the
    same ``CompositeOCRProvider`` instance.
    """
    shared: OCRProvider = CompositeOCRProvider()
    return shared
|