This commit is contained in:
@@ -2,6 +2,7 @@ from collections import deque
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@@ -15,6 +16,7 @@ from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from fastapi.exceptions import RequestErrorModel
|
||||
@@ -24,7 +26,8 @@ from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from starlette.datastructures import UploadFile
|
||||
from typing_extensions import Annotated, Literal, get_args, get_origin
|
||||
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
|
||||
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
|
||||
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
|
||||
|
||||
|
||||
sequence_annotation_to_type = {
|
||||
@@ -43,6 +46,8 @@ sequence_annotation_to_type = {
|
||||
|
||||
sequence_types = tuple(sequence_annotation_to_type.keys())
|
||||
|
||||
Url: Type[Any]
|
||||
|
||||
if PYDANTIC_V2:
|
||||
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||||
from pydantic import TypeAdapter
|
||||
@@ -68,7 +73,7 @@ if PYDANTIC_V2:
|
||||
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401
|
||||
)
|
||||
|
||||
Required = PydanticUndefined
|
||||
RequiredParam = PydanticUndefined
|
||||
Undefined = PydanticUndefined
|
||||
UndefinedType = PydanticUndefinedType
|
||||
evaluate_forwardref = eval_type_lenient
|
||||
@@ -127,7 +132,7 @@ if PYDANTIC_V2:
|
||||
)
|
||||
except ValidationError as exc:
|
||||
return None, _regenerate_error_with_loc(
|
||||
errors=exc.errors(), loc_prefix=loc
|
||||
errors=exc.errors(include_url=False), loc_prefix=loc
|
||||
)
|
||||
|
||||
def serialize(
|
||||
@@ -227,6 +232,10 @@ if PYDANTIC_V2:
|
||||
field_mapping, definitions = schema_generator.generate_definitions(
|
||||
inputs=inputs
|
||||
)
|
||||
for item_def in cast(Dict[str, Dict[str, Any]], definitions).values():
|
||||
if "description" in item_def:
|
||||
item_description = cast(str, item_def["description"]).split("\f")[0]
|
||||
item_def["description"] = item_description
|
||||
return field_mapping, definitions # type: ignore[return-value]
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
@@ -249,7 +258,12 @@ if PYDANTIC_V2:
|
||||
return is_bytes_sequence_annotation(field.type_)
|
||||
|
||||
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
|
||||
return type(field_info).from_annotation(annotation)
|
||||
cls = type(field_info)
|
||||
merged_field_info = cls.from_annotation(annotation)
|
||||
new_field_info = copy(field_info)
|
||||
new_field_info.metadata = merged_field_info.metadata
|
||||
new_field_info.annotation = merged_field_info.annotation
|
||||
return new_field_info
|
||||
|
||||
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
|
||||
origin_type = (
|
||||
@@ -261,7 +275,7 @@ if PYDANTIC_V2:
|
||||
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
|
||||
error = ValidationError.from_exception_data(
|
||||
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
|
||||
).errors()[0]
|
||||
).errors(include_url=False)[0]
|
||||
error["input"] = None
|
||||
return error # type: ignore[return-value]
|
||||
|
||||
@@ -272,6 +286,12 @@ if PYDANTIC_V2:
|
||||
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
|
||||
return BodyModel
|
||||
|
||||
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
return [
|
||||
ModelField(field_info=field_info, name=name)
|
||||
for name, field_info in model.model_fields.items()
|
||||
]
|
||||
|
||||
else:
|
||||
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
|
||||
from pydantic import AnyUrl as Url # noqa: F401
|
||||
@@ -299,9 +319,10 @@ else:
|
||||
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||||
ModelField as ModelField, # noqa: F401
|
||||
)
|
||||
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||||
Required as Required, # noqa: F401
|
||||
)
|
||||
|
||||
# Keeping old "Required" functionality from Pydantic V1, without
|
||||
# shadowing typing.Required.
|
||||
RequiredParam: Any = Ellipsis # type: ignore[no-redef]
|
||||
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||||
Undefined as Undefined,
|
||||
)
|
||||
@@ -372,9 +393,10 @@ else:
|
||||
)
|
||||
definitions.update(m_definitions)
|
||||
model_name = model_name_map[model]
|
||||
definitions[model_name] = m_schema
|
||||
for m_schema in definitions.values():
|
||||
if "description" in m_schema:
|
||||
m_schema["description"] = m_schema["description"].split("\f")[0]
|
||||
definitions[model_name] = m_schema
|
||||
return definitions
|
||||
|
||||
def is_pv1_scalar_field(field: ModelField) -> bool:
|
||||
@@ -506,6 +528,9 @@ else:
|
||||
BodyModel.__fields__[f.name] = f # type: ignore[index]
|
||||
return BodyModel
|
||||
|
||||
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
return list(model.__fields__.values()) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _regenerate_error_with_loc(
|
||||
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
|
||||
@@ -525,6 +550,12 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
|
||||
|
||||
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if field_annotation_is_sequence(arg):
|
||||
return True
|
||||
return False
|
||||
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
|
||||
get_origin(annotation)
|
||||
)
|
||||
@@ -627,3 +658,8 @@ def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
|
||||
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
|
||||
return get_model_fields(model)
|
||||
|
||||
Reference in New Issue
Block a user