API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -2,6 +2,7 @@ from collections import deque
from copy import copy
from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import lru_cache
from typing import (
Any,
Callable,
@@ -15,6 +16,7 @@ from typing import (
Tuple,
Type,
Union,
cast,
)
from fastapi.exceptions import RequestErrorModel
@@ -24,7 +26,8 @@ from pydantic.version import VERSION as PYDANTIC_VERSION
from starlette.datastructures import UploadFile
from typing_extensions import Annotated, Literal, get_args, get_origin
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
sequence_annotation_to_type = {
@@ -43,6 +46,8 @@ sequence_annotation_to_type = {
sequence_types = tuple(sequence_annotation_to_type.keys())
Url: Type[Any]
if PYDANTIC_V2:
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import TypeAdapter
@@ -68,7 +73,7 @@ if PYDANTIC_V2:
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401
)
Required = PydanticUndefined
RequiredParam = PydanticUndefined
Undefined = PydanticUndefined
UndefinedType = PydanticUndefinedType
evaluate_forwardref = eval_type_lenient
@@ -127,7 +132,7 @@ if PYDANTIC_V2:
)
except ValidationError as exc:
return None, _regenerate_error_with_loc(
errors=exc.errors(), loc_prefix=loc
errors=exc.errors(include_url=False), loc_prefix=loc
)
def serialize(
@@ -227,6 +232,10 @@ if PYDANTIC_V2:
field_mapping, definitions = schema_generator.generate_definitions(
inputs=inputs
)
for item_def in cast(Dict[str, Dict[str, Any]], definitions).values():
if "description" in item_def:
item_description = cast(str, item_def["description"]).split("\f")[0]
item_def["description"] = item_description
return field_mapping, definitions # type: ignore[return-value]
def is_scalar_field(field: ModelField) -> bool:
@@ -249,7 +258,12 @@ if PYDANTIC_V2:
return is_bytes_sequence_annotation(field.type_)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return type(field_info).from_annotation(annotation)
cls = type(field_info)
merged_field_info = cls.from_annotation(annotation)
new_field_info = copy(field_info)
new_field_info.metadata = merged_field_info.metadata
new_field_info.annotation = merged_field_info.annotation
return new_field_info
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
origin_type = (
@@ -261,7 +275,7 @@ if PYDANTIC_V2:
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors()[0]
).errors(include_url=False)[0]
error["input"] = None
return error # type: ignore[return-value]
@@ -272,6 +286,12 @@ if PYDANTIC_V2:
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
return BodyModel
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return [
ModelField(field_info=field_info, name=name)
for name, field_info in model.model_fields.items()
]
else:
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic import AnyUrl as Url # noqa: F401
@@ -299,9 +319,10 @@ else:
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
ModelField as ModelField, # noqa: F401
)
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
Required as Required, # noqa: F401
)
# Keeping old "Required" functionality from Pydantic V1, without
# shadowing typing.Required.
RequiredParam: Any = Ellipsis # type: ignore[no-redef]
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
Undefined as Undefined,
)
@@ -372,9 +393,10 @@ else:
)
definitions.update(m_definitions)
model_name = model_name_map[model]
definitions[model_name] = m_schema
for m_schema in definitions.values():
if "description" in m_schema:
m_schema["description"] = m_schema["description"].split("\f")[0]
definitions[model_name] = m_schema
return definitions
def is_pv1_scalar_field(field: ModelField) -> bool:
@@ -506,6 +528,9 @@ else:
BodyModel.__fields__[f.name] = f # type: ignore[index]
return BodyModel
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return list(model.__fields__.values()) # type: ignore[attr-defined]
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
@@ -525,6 +550,12 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if field_annotation_is_sequence(arg):
return True
return False
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
get_origin(annotation)
)
@@ -627,3 +658,8 @@ def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
for sub_annotation in get_args(annotation)
)
@lru_cache
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return get_model_fields(model)