@@ -1,25 +1,23 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import warnings
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from contextlib import contextmanager
|
||||
from re import Pattern
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Iterator,
|
||||
Literal,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import (
|
||||
Literal,
|
||||
Self,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..config import ConfigDict, ExtraValues, JsonEncoder, JsonSchemaExtraCallable
|
||||
from ..aliases import AliasGenerator
|
||||
from ..config import ConfigDict, ExtraValues, JsonDict, JsonEncoder, JsonSchemaExtraCallable
|
||||
from ..errors import PydanticUserError
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
from ..warnings import PydanticDeprecatedSince20, PydanticDeprecatedSince210
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
@@ -28,6 +26,7 @@ if not TYPE_CHECKING:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._internal._schema_generation_shared import GenerateSchema
|
||||
from ..fields import ComputedFieldInfo, FieldInfo
|
||||
|
||||
DEPRECATION_MESSAGE = 'Support for class-based `config` is deprecated, use ConfigDict instead.'
|
||||
|
||||
@@ -57,10 +56,12 @@ class ConfigWrapper:
|
||||
# whether to use the actual key provided in the data (e.g. alias or first alias for "field required" errors) instead of field_names
|
||||
# to construct error `loc`s, default `True`
|
||||
loc_by_alias: bool
|
||||
alias_generator: Callable[[str], str] | None
|
||||
alias_generator: Callable[[str], str] | AliasGenerator | None
|
||||
model_title_generator: Callable[[type], str] | None
|
||||
field_title_generator: Callable[[str, FieldInfo | ComputedFieldInfo], str] | None
|
||||
ignored_types: tuple[type, ...]
|
||||
allow_inf_nan: bool
|
||||
json_schema_extra: dict[str, object] | JsonSchemaExtraCallable | None
|
||||
json_schema_extra: JsonDict | JsonSchemaExtraCallable | None
|
||||
json_encoders: dict[type[object], JsonEncoder] | None
|
||||
|
||||
# new in V2
|
||||
@@ -68,11 +69,13 @@ class ConfigWrapper:
|
||||
# whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never'
|
||||
revalidate_instances: Literal['always', 'never', 'subclass-instances']
|
||||
ser_json_timedelta: Literal['iso8601', 'float']
|
||||
ser_json_bytes: Literal['utf8', 'base64']
|
||||
ser_json_bytes: Literal['utf8', 'base64', 'hex']
|
||||
val_json_bytes: Literal['utf8', 'base64', 'hex']
|
||||
ser_json_inf_nan: Literal['null', 'constants', 'strings']
|
||||
# whether to validate default values during validation, default False
|
||||
validate_default: bool
|
||||
validate_return: bool
|
||||
protected_namespaces: tuple[str, ...]
|
||||
protected_namespaces: tuple[str | Pattern[str], ...]
|
||||
hide_input_in_errors: bool
|
||||
defer_build: bool
|
||||
plugin_settings: dict[str, object] | None
|
||||
@@ -80,6 +83,13 @@ class ConfigWrapper:
|
||||
json_schema_serialization_defaults_required: bool
|
||||
json_schema_mode_override: Literal['validation', 'serialization', None]
|
||||
coerce_numbers_to_str: bool
|
||||
regex_engine: Literal['rust-regex', 'python-re']
|
||||
validation_error_cause: bool
|
||||
use_attribute_docstrings: bool
|
||||
cache_strings: bool | Literal['all', 'keys', 'none']
|
||||
validate_by_alias: bool
|
||||
validate_by_name: bool
|
||||
serialize_by_alias: bool
|
||||
|
||||
def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True):
|
||||
if check:
|
||||
@@ -113,13 +123,19 @@ class ConfigWrapper:
|
||||
config_class_from_namespace = namespace.get('Config')
|
||||
config_dict_from_namespace = namespace.get('model_config')
|
||||
|
||||
raw_annotations = namespace.get('__annotations__', {})
|
||||
if raw_annotations.get('model_config') and config_dict_from_namespace is None:
|
||||
raise PydanticUserError(
|
||||
'`model_config` cannot be used as a model field name. Use `model_config` for model configuration.',
|
||||
code='model-config-invalid-field-name',
|
||||
)
|
||||
|
||||
if config_class_from_namespace and config_dict_from_namespace:
|
||||
raise PydanticUserError('"Config" and "model_config" cannot be used together', code='config-both')
|
||||
|
||||
config_from_namespace = config_dict_from_namespace or prepare_config(config_class_from_namespace)
|
||||
|
||||
if config_from_namespace is not None:
|
||||
config_new.update(config_from_namespace)
|
||||
config_new.update(config_from_namespace)
|
||||
|
||||
for k in list(kwargs.keys()):
|
||||
if k in config_keys:
|
||||
@@ -128,7 +144,7 @@ class ConfigWrapper:
|
||||
return cls(config_new)
|
||||
|
||||
# we don't show `__getattr__` to type checkers so missing attributes cause errors
|
||||
if not TYPE_CHECKING:
|
||||
if not TYPE_CHECKING: # pragma: no branch
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
@@ -139,46 +155,77 @@ class ConfigWrapper:
|
||||
except KeyError:
|
||||
raise AttributeError(f'Config has no attribute {name!r}') from None
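# A minimal, self-contained sketch of the pattern above; `_Wrapper` is a hypothetical
# stand-in for ConfigWrapper, not the real implementation. Hiding `__getattr__` from
# type checkers makes unknown attributes a static error, while at runtime lookups
# still resolve against the underlying config dict.
from typing import TYPE_CHECKING, Any

class _Wrapper:
    def __init__(self, config: dict) -> None:
        self.config_dict = config

    if not TYPE_CHECKING:
        def __getattr__(self, name: str) -> Any:
            try:
                return self.config_dict[name]
            except KeyError:
                raise AttributeError(f'Config has no attribute {name!r}') from None

w = _Wrapper({'strict': True})
assert w.strict is True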
|
||||
|
||||
def core_config(self, obj: Any) -> core_schema.CoreConfig:
|
||||
"""Create a pydantic-core config, `obj` is just used to populate `title` if not set in config.
|
||||
|
||||
Pass `obj=None` if you do not want to attempt to infer the `title`.
|
||||
def core_config(self, title: str | None) -> core_schema.CoreConfig:
|
||||
"""Create a pydantic-core config.
|
||||
|
||||
We don't use getattr here since we don't want to populate with defaults.
|
||||
|
||||
Args:
|
||||
obj: An object used to populate `title` if not set in config.
|
||||
title: The title to use if not set in config.
|
||||
|
||||
Returns:
|
||||
A `CoreConfig` object created from config.
|
||||
"""
|
||||
config = self.config_dict
|
||||
|
||||
def dict_not_none(**kwargs: Any) -> Any:
|
||||
return {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
core_config = core_schema.CoreConfig(
|
||||
**dict_not_none(
|
||||
title=self.config_dict.get('title') or (obj and obj.__name__),
|
||||
extra_fields_behavior=self.config_dict.get('extra'),
|
||||
allow_inf_nan=self.config_dict.get('allow_inf_nan'),
|
||||
populate_by_name=self.config_dict.get('populate_by_name'),
|
||||
str_strip_whitespace=self.config_dict.get('str_strip_whitespace'),
|
||||
str_to_lower=self.config_dict.get('str_to_lower'),
|
||||
str_to_upper=self.config_dict.get('str_to_upper'),
|
||||
strict=self.config_dict.get('strict'),
|
||||
ser_json_timedelta=self.config_dict.get('ser_json_timedelta'),
|
||||
ser_json_bytes=self.config_dict.get('ser_json_bytes'),
|
||||
from_attributes=self.config_dict.get('from_attributes'),
|
||||
loc_by_alias=self.config_dict.get('loc_by_alias'),
|
||||
revalidate_instances=self.config_dict.get('revalidate_instances'),
|
||||
validate_default=self.config_dict.get('validate_default'),
|
||||
str_max_length=self.config_dict.get('str_max_length'),
|
||||
str_min_length=self.config_dict.get('str_min_length'),
|
||||
hide_input_in_errors=self.config_dict.get('hide_input_in_errors'),
|
||||
coerce_numbers_to_str=self.config_dict.get('coerce_numbers_to_str'),
|
||||
if config.get('schema_generator') is not None:
|
||||
warnings.warn(
|
||||
'The `schema_generator` setting has been deprecated since v2.10. This setting no longer has any effect.',
|
||||
PydanticDeprecatedSince210,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if (populate_by_name := config.get('populate_by_name')) is not None:
|
||||
# We include this patch for backwards compatibility purposes, but this config setting will be deprecated in v3.0, and likely removed in v4.0.
|
||||
# Thus, the above warning and this patch can be removed then as well.
|
||||
if config.get('validate_by_name') is None:
|
||||
config['validate_by_alias'] = True
|
||||
config['validate_by_name'] = populate_by_name
|
||||
|
||||
# We dynamically patch validate_by_name to be True if validate_by_alias is set to False
|
||||
# and validate_by_name is not explicitly set.
|
||||
if config.get('validate_by_alias') is False and config.get('validate_by_name') is None:
|
||||
config['validate_by_name'] = True
|
||||
|
||||
if (not config.get('validate_by_alias', True)) and (not config.get('validate_by_name', False)):
|
||||
raise PydanticUserError(
|
||||
'At least one of `validate_by_alias` or `validate_by_name` must be set to True.',
|
||||
code='validate-by-alias-and-name-false',
|
||||
)
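# A self-contained sketch of the alias/name precedence enforced above, using plain
# dicts in place of ConfigDict; the helper name is hypothetical, not pydantic API.
def _resolve_validation_aliasing(config: dict) -> dict:
    config = dict(config)
    populate_by_name = config.get('populate_by_name')
    if populate_by_name is not None and config.get('validate_by_name') is None:
        # the legacy flag maps onto the new validate_by_alias / validate_by_name pair
        config['validate_by_alias'] = True
        config['validate_by_name'] = populate_by_name
    if config.get('validate_by_alias') is False and config.get('validate_by_name') is None:
        # turning aliases off implies validating by name, unless set explicitly
        config['validate_by_name'] = True
    if not config.get('validate_by_alias', True) and not config.get('validate_by_name', False):
        raise ValueError('At least one of validate_by_alias or validate_by_name must be True')
    return config

assert _resolve_validation_aliasing({'populate_by_name': True})['validate_by_name'] is True
assert _resolve_validation_aliasing({'validate_by_alias': False})['validate_by_name'] is True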
|
||||
|
||||
return core_schema.CoreConfig(
|
||||
**{ # pyright: ignore[reportArgumentType]
|
||||
k: v
|
||||
for k, v in (
|
||||
('title', config.get('title') or title or None),
|
||||
('extra_fields_behavior', config.get('extra')),
|
||||
('allow_inf_nan', config.get('allow_inf_nan')),
|
||||
('str_strip_whitespace', config.get('str_strip_whitespace')),
|
||||
('str_to_lower', config.get('str_to_lower')),
|
||||
('str_to_upper', config.get('str_to_upper')),
|
||||
('strict', config.get('strict')),
|
||||
('ser_json_timedelta', config.get('ser_json_timedelta')),
|
||||
('ser_json_bytes', config.get('ser_json_bytes')),
|
||||
('val_json_bytes', config.get('val_json_bytes')),
|
||||
('ser_json_inf_nan', config.get('ser_json_inf_nan')),
|
||||
('from_attributes', config.get('from_attributes')),
|
||||
('loc_by_alias', config.get('loc_by_alias')),
|
||||
('revalidate_instances', config.get('revalidate_instances')),
|
||||
('validate_default', config.get('validate_default')),
|
||||
('str_max_length', config.get('str_max_length')),
|
||||
('str_min_length', config.get('str_min_length')),
|
||||
('hide_input_in_errors', config.get('hide_input_in_errors')),
|
||||
('coerce_numbers_to_str', config.get('coerce_numbers_to_str')),
|
||||
('regex_engine', config.get('regex_engine')),
|
||||
('validation_error_cause', config.get('validation_error_cause')),
|
||||
('cache_strings', config.get('cache_strings')),
|
||||
('validate_by_alias', config.get('validate_by_alias')),
|
||||
('validate_by_name', config.get('validate_by_name')),
|
||||
('serialize_by_alias', config.get('serialize_by_alias')),
|
||||
)
|
||||
if v is not None
|
||||
}
|
||||
)
|
||||
return core_config
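# The non-None filtering used when building CoreConfig above, in miniature: only
# explicitly set options are forwarded, so pydantic-core applies its own defaults
# for the rest (keys and values below are made up for illustration).
example_config = {'title': None, 'strict': True, 'extra': None, 'allow_inf_nan': False}
forwarded = {k: v for k, v in example_config.items() if v is not None}
assert forwarded == {'strict': True, 'allow_inf_nan': False}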
|
||||
|
||||
def __repr__(self):
|
||||
c = ', '.join(f'{k}={v!r}' for k, v in self.config_dict.items())
|
||||
@@ -195,22 +242,20 @@ class ConfigWrapperStack:
|
||||
def tail(self) -> ConfigWrapper:
|
||||
return self._config_wrapper_stack[-1]
|
||||
|
||||
def push(self, config_wrapper: ConfigWrapper | ConfigDict | None) -> ContextManager[None]:
|
||||
@contextmanager
|
||||
def push(self, config_wrapper: ConfigWrapper | ConfigDict | None):
|
||||
if config_wrapper is None:
|
||||
return nullcontext()
|
||||
yield
|
||||
return
|
||||
|
||||
if not isinstance(config_wrapper, ConfigWrapper):
|
||||
config_wrapper = ConfigWrapper(config_wrapper, check=False)
|
||||
|
||||
@contextmanager
|
||||
def _context_manager() -> Iterator[None]:
|
||||
self._config_wrapper_stack.append(config_wrapper)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._config_wrapper_stack.pop()
|
||||
|
||||
return _context_manager()
|
||||
self._config_wrapper_stack.append(config_wrapper)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._config_wrapper_stack.pop()
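# A self-contained, list-backed sketch of the push/pop behaviour above (illustrative
# only, not the real ConfigWrapperStack).
from contextlib import contextmanager

class _Stack:
    def __init__(self, default):
        self._items = [default]

    @property
    def tail(self):
        return self._items[-1]

    @contextmanager
    def push(self, item=None):
        if item is None:
            yield  # nothing to push; the current tail stays active
            return
        self._items.append(item)
        try:
            yield
        finally:
            self._items.pop()

stack = _Stack({'strict': False})
with stack.push({'strict': True}):
    assert stack.tail == {'strict': True}
assert stack.tail == {'strict': False}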
|
||||
|
||||
|
||||
config_defaults = ConfigDict(
|
||||
@@ -230,6 +275,8 @@ config_defaults = ConfigDict(
|
||||
from_attributes=False,
|
||||
loc_by_alias=True,
|
||||
alias_generator=None,
|
||||
model_title_generator=None,
|
||||
field_title_generator=None,
|
||||
ignored_types=(),
|
||||
allow_inf_nan=True,
|
||||
json_schema_extra=None,
|
||||
@@ -237,17 +284,26 @@ config_defaults = ConfigDict(
|
||||
revalidate_instances='never',
|
||||
ser_json_timedelta='iso8601',
|
||||
ser_json_bytes='utf8',
|
||||
val_json_bytes='utf8',
|
||||
ser_json_inf_nan='null',
|
||||
validate_default=False,
|
||||
validate_return=False,
|
||||
protected_namespaces=('model_',),
|
||||
protected_namespaces=('model_validate', 'model_dump'),
|
||||
hide_input_in_errors=False,
|
||||
json_encoders=None,
|
||||
defer_build=False,
|
||||
plugin_settings=None,
|
||||
schema_generator=None,
|
||||
plugin_settings=None,
|
||||
json_schema_serialization_defaults_required=False,
|
||||
json_schema_mode_override=None,
|
||||
coerce_numbers_to_str=False,
|
||||
regex_engine='rust-regex',
|
||||
validation_error_cause=False,
|
||||
use_attribute_docstrings=False,
|
||||
cache_strings=True,
|
||||
validate_by_alias=True,
|
||||
validate_by_name=False,
|
||||
serialize_by_alias=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -288,7 +344,7 @@ V2_REMOVED_KEYS = {
|
||||
'post_init_call',
|
||||
}
|
||||
V2_RENAMED_KEYS = {
|
||||
'allow_population_by_field_name': 'populate_by_name',
|
||||
'allow_population_by_field_name': 'validate_by_name',
|
||||
'anystr_lower': 'str_to_lower',
|
||||
'anystr_strip_whitespace': 'str_strip_whitespace',
|
||||
'anystr_upper': 'str_to_upper',
|
||||
|
||||
@@ -1,92 +1,97 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import typing
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||
from warnings import warn
|
||||
|
||||
import typing_extensions
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ._schema_generation_shared import (
|
||||
CoreSchemaOrField as CoreSchemaOrField,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
from ..config import JsonDict, JsonSchemaExtraCallable
|
||||
from ._schema_generation_shared import (
|
||||
GetJsonSchemaFunction,
|
||||
)
|
||||
|
||||
|
||||
class CoreMetadata(typing_extensions.TypedDict, total=False):
|
||||
class CoreMetadata(TypedDict, total=False):
|
||||
"""A `TypedDict` for holding the metadata dict of the schema.
|
||||
|
||||
Attributes:
|
||||
pydantic_js_functions: List of JSON schema functions.
|
||||
pydantic_js_functions: List of JSON schema functions that resolve refs during application.
|
||||
pydantic_js_annotation_functions: List of JSON schema functions that don't resolve refs during application.
|
||||
pydantic_js_prefer_positional_arguments: Whether JSON schema generator will
|
||||
prefer positional over keyword arguments for an 'arguments' schema.
|
||||
custom validation function. Only applies to before, plain, and wrap validators.
|
||||
pydantic_js_updates: key / value pair updates to apply to the JSON schema for a type.
|
||||
pydantic_js_extra: WIP, either key/value pair updates to apply to the JSON schema, or a custom callable.
|
||||
pydantic_internal_union_tag_key: Used internally by the `Tag` metadata to specify the tag used for a discriminated union.
|
||||
pydantic_internal_union_discriminator: Used internally to specify the discriminator value for a discriminated union
|
||||
when the discriminator was applied to a `'definition-ref'` schema, and that reference was missing at the time
|
||||
of the annotation application.
|
||||
|
||||
TODO: Perhaps we should move this structure to pydantic-core. At the moment, though,
|
||||
it's easier to iterate on if we leave it in pydantic until we feel there is a semi-stable API.
|
||||
|
||||
TODO: It's unfortunate how functionally oriented JSON schema generation is, especially that which occurs during
|
||||
the core schema generation process. It's inevitable that we need to store some json schema related information
|
||||
on core schemas, given that we generate JSON schemas directly from core schemas. That being said, debugging related
|
||||
issues is quite difficult when JSON schema information is disguised via dynamically defined functions.
|
||||
"""
|
||||
|
||||
pydantic_js_functions: list[GetJsonSchemaFunction]
|
||||
pydantic_js_annotation_functions: list[GetJsonSchemaFunction]
|
||||
|
||||
# If `pydantic_js_prefer_positional_arguments` is True, the JSON schema generator will
|
||||
# prefer positional over keyword arguments for an 'arguments' schema.
|
||||
pydantic_js_prefer_positional_arguments: bool | None
|
||||
|
||||
pydantic_typed_dict_cls: type[Any] | None # TODO: Consider moving this into the pydantic-core TypedDictSchema
|
||||
pydantic_js_prefer_positional_arguments: bool
|
||||
pydantic_js_updates: JsonDict
|
||||
pydantic_js_extra: JsonDict | JsonSchemaExtraCallable
|
||||
pydantic_internal_union_tag_key: str
|
||||
pydantic_internal_union_discriminator: str
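# Because CoreMetadata is a non-total TypedDict, at runtime it is just a plain dict
# attached to a schema's `metadata` key; the keys and values below are illustrative.
example_metadata = {
    'pydantic_js_updates': {'examples': [1, 2, 3]},
    'pydantic_internal_union_tag_key': 'kind',
}
example_schema = {'type': 'int', 'metadata': example_metadata}  # stand-in for a real core schema
assert example_schema['metadata']['pydantic_js_updates']['examples'] == [1, 2, 3]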
|
||||
|
||||
|
||||
class CoreMetadataHandler:
|
||||
"""Because the metadata field in pydantic_core is of type `Any`, we can't assume much about its contents.
|
||||
def update_core_metadata(
|
||||
core_metadata: Any,
|
||||
/,
|
||||
*,
|
||||
pydantic_js_functions: list[GetJsonSchemaFunction] | None = None,
|
||||
pydantic_js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
|
||||
pydantic_js_updates: JsonDict | None = None,
|
||||
pydantic_js_extra: JsonDict | JsonSchemaExtraCallable | None = None,
|
||||
) -> None:
|
||||
from ..json_schema import PydanticJsonSchemaWarning
|
||||
|
||||
This class is used to interact with the metadata field on a CoreSchema object in a consistent
|
||||
way throughout pydantic.
|
||||
"""Update CoreMetadata instance in place. When we make modifications in this function, they
|
||||
take effect on the `core_metadata` reference passed in as the first (and only) positional argument.
|
||||
|
||||
First, cast to `CoreMetadata`, then finish with a cast to `dict[str, Any]` for core schema compatibility.
|
||||
We do this here, instead of before / after each call to this function so that this typing hack
|
||||
can be easily removed if/when we move `CoreMetadata` to `pydantic-core`.
|
||||
|
||||
For parameter descriptions, see `CoreMetadata` above.
|
||||
"""
|
||||
core_metadata = cast(CoreMetadata, core_metadata)
|
||||
|
||||
__slots__ = ('_schema',)
|
||||
if pydantic_js_functions:
|
||||
core_metadata.setdefault('pydantic_js_functions', []).extend(pydantic_js_functions)
|
||||
|
||||
def __init__(self, schema: CoreSchemaOrField):
|
||||
self._schema = schema
|
||||
if pydantic_js_annotation_functions:
|
||||
core_metadata.setdefault('pydantic_js_annotation_functions', []).extend(pydantic_js_annotation_functions)
|
||||
|
||||
metadata = schema.get('metadata')
|
||||
if metadata is None:
|
||||
schema['metadata'] = CoreMetadata()
|
||||
elif not isinstance(metadata, dict):
|
||||
raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
|
||||
if pydantic_js_updates:
|
||||
if (existing_updates := core_metadata.get('pydantic_js_updates')) is not None:
|
||||
core_metadata['pydantic_js_updates'] = {**existing_updates, **pydantic_js_updates}
|
||||
else:
|
||||
core_metadata['pydantic_js_updates'] = pydantic_js_updates
|
||||
|
||||
@property
|
||||
def metadata(self) -> CoreMetadata:
|
||||
"""Retrieves the metadata dict from the schema, initializing it to a dict if it is None
|
||||
and raising an error if it is not a dict.
|
||||
"""
|
||||
metadata = self._schema.get('metadata')
|
||||
if metadata is None:
|
||||
self._schema['metadata'] = metadata = CoreMetadata()
|
||||
if not isinstance(metadata, dict):
|
||||
raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
|
||||
return metadata
|
||||
|
||||
|
||||
def build_metadata_dict(
|
||||
*, # force keyword arguments to make it easier to modify this signature in a backwards-compatible way
|
||||
js_functions: list[GetJsonSchemaFunction] | None = None,
|
||||
js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
|
||||
js_prefer_positional_arguments: bool | None = None,
|
||||
typed_dict_cls: type[Any] | None = None,
|
||||
initial_metadata: Any | None = None,
|
||||
) -> Any:
|
||||
"""Builds a dict to use as the metadata field of a CoreSchema object in a manner that is consistent
|
||||
with the CoreMetadataHandler class.
|
||||
"""
|
||||
if initial_metadata is not None and not isinstance(initial_metadata, dict):
|
||||
raise TypeError(f'CoreSchema metadata should be a dict; got {initial_metadata!r}.')
|
||||
|
||||
metadata = CoreMetadata(
|
||||
pydantic_js_functions=js_functions or [],
|
||||
pydantic_js_annotation_functions=js_annotation_functions or [],
|
||||
pydantic_js_prefer_positional_arguments=js_prefer_positional_arguments,
|
||||
pydantic_typed_dict_cls=typed_dict_cls,
|
||||
)
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
if initial_metadata is not None:
|
||||
metadata = {**initial_metadata, **metadata}
|
||||
|
||||
return metadata
|
||||
if pydantic_js_extra is not None:
|
||||
existing_pydantic_js_extra = core_metadata.get('pydantic_js_extra')
|
||||
if existing_pydantic_js_extra is None:
|
||||
core_metadata['pydantic_js_extra'] = pydantic_js_extra
|
||||
if isinstance(existing_pydantic_js_extra, dict):
|
||||
if isinstance(pydantic_js_extra, dict):
|
||||
core_metadata['pydantic_js_extra'] = {**existing_pydantic_js_extra, **pydantic_js_extra}
|
||||
if callable(pydantic_js_extra):
|
||||
warn(
|
||||
'Composing `dict` and `callable` type `json_schema_extra` is not supported. '
|
||||
'The `callable` type is being ignored. '
|
||||
"If you'd like support for this behavior, please open an issue on pydantic.",
|
||||
PydanticJsonSchemaWarning,
|
||||
)
|
||||
if callable(existing_pydantic_js_extra):
|
||||
# if ever there's a case of a callable, we'll just keep the last json schema extra spec
|
||||
core_metadata['pydantic_js_extra'] = pydantic_js_extra
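# A sketch of the merging rules described above, reduced to plain dicts: the
# setdefault/extend pattern accumulates JSON schema functions, while updates are
# dict-merged. The helper and names below are hypothetical, not pydantic API.
def _update(metadata: dict, *, js_functions=None, js_updates=None) -> None:
    if js_functions:
        metadata.setdefault('pydantic_js_functions', []).extend(js_functions)
    if js_updates:
        existing = metadata.get('pydantic_js_updates')
        metadata['pydantic_js_updates'] = {**existing, **js_updates} if existing else js_updates

example_metadata: dict = {}
_update(example_metadata, js_functions=[lambda schema, handler: handler(schema)], js_updates={'title': 'A'})
_update(example_metadata, js_updates={'description': 'B'})
assert example_metadata['pydantic_js_updates'] == {'title': 'A', 'description': 'B'}
assert len(example_metadata['pydantic_js_functions']) == 1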
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Hashable,
|
||||
TypeVar,
|
||||
Union,
|
||||
_GenericAlias, # type: ignore
|
||||
cast,
|
||||
)
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic_core import validate_core_schema as _validate_core_schema
|
||||
from typing_extensions import TypeAliasType, TypeGuard, get_args
|
||||
from typing_extensions import TypeGuard, get_args, get_origin
|
||||
from typing_inspection import typing_objects
|
||||
|
||||
from . import _repr
|
||||
from ._typing_extra import is_generic_alias
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from rich.console import Console
|
||||
|
||||
AnyFunctionSchema = Union[
|
||||
core_schema.AfterValidatorFunctionSchema,
|
||||
@@ -39,19 +37,7 @@ CoreSchemaOrField = Union[core_schema.CoreSchema, CoreSchemaField]
|
||||
|
||||
_CORE_SCHEMA_FIELD_TYPES = {'typed-dict-field', 'dataclass-field', 'model-field', 'computed-field'}
|
||||
_FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'}
|
||||
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'tuple-variable', 'set', 'frozenset'}
|
||||
|
||||
_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'
|
||||
|
||||
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY = 'pydantic.internal.needs_apply_discriminated_union'
|
||||
"""Used to mark a schema that has a discriminated union that needs to be checked for validity at the end of
|
||||
schema building because one of its members refers to a definition that was not yet defined when the union
|
||||
was first encountered.
|
||||
"""
|
||||
HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid'
|
||||
"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the
|
||||
schema was first encountered.
|
||||
"""
|
||||
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'set', 'frozenset'}
|
||||
|
||||
|
||||
def is_core_schema(
|
||||
@@ -74,28 +60,27 @@ def is_function_with_inner_schema(
|
||||
|
||||
def is_list_like_schema_with_items_schema(
|
||||
schema: CoreSchema,
|
||||
) -> TypeGuard[
|
||||
core_schema.ListSchema | core_schema.TupleVariableSchema | core_schema.SetSchema | core_schema.FrozenSetSchema
|
||||
]:
|
||||
) -> TypeGuard[core_schema.ListSchema | core_schema.SetSchema | core_schema.FrozenSetSchema]:
|
||||
return schema['type'] in _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES
|
||||
|
||||
|
||||
def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = None) -> str:
|
||||
def get_type_ref(type_: Any, args_override: tuple[type[Any], ...] | None = None) -> str:
|
||||
"""Produces the ref to be used for this type by pydantic_core's core schemas.
|
||||
|
||||
This `args_override` argument was added for the purpose of creating valid recursive references
|
||||
when creating generic models without needing to create a concrete class.
|
||||
"""
|
||||
origin = type_
|
||||
args = get_args(type_) if isinstance(type_, _GenericAlias) else (args_override or ())
|
||||
origin = get_origin(type_) or type_
|
||||
|
||||
args = get_args(type_) if is_generic_alias(type_) else (args_override or ())
|
||||
generic_metadata = getattr(type_, '__pydantic_generic_metadata__', None)
|
||||
if generic_metadata:
|
||||
origin = generic_metadata['origin'] or origin
|
||||
args = generic_metadata['args'] or args
|
||||
|
||||
module_name = getattr(origin, '__module__', '<No __module__>')
|
||||
if isinstance(origin, TypeAliasType):
|
||||
type_ref = f'{module_name}.{origin.__name__}'
|
||||
if typing_objects.is_typealiastype(origin):
|
||||
type_ref = f'{module_name}.{origin.__name__}:{id(origin)}'
|
||||
else:
|
||||
try:
|
||||
qualname = getattr(origin, '__qualname__', f'<No __qualname__: {origin}>')
|
||||
@@ -124,457 +109,74 @@ def get_ref(s: core_schema.CoreSchema) -> None | str:
|
||||
return s.get('ref', None)
|
||||
|
||||
|
||||
def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema.CoreSchema]:
|
||||
defs: dict[str, CoreSchema] = {}
|
||||
|
||||
def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
ref = get_ref(s)
|
||||
if ref:
|
||||
defs[ref] = s
|
||||
return recurse(s, _record_valid_refs)
|
||||
|
||||
walk_core_schema(schema, _record_valid_refs)
|
||||
|
||||
return defs
|
||||
|
||||
|
||||
def define_expected_missing_refs(
|
||||
schema: core_schema.CoreSchema, allowed_missing_refs: set[str]
|
||||
) -> core_schema.CoreSchema | None:
|
||||
if not allowed_missing_refs:
|
||||
# in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema
|
||||
# this is a common case (will be hit for all non-generic models), so it's worth optimizing for
|
||||
return None
|
||||
|
||||
refs = collect_definitions(schema).keys()
|
||||
|
||||
expected_missing_refs = allowed_missing_refs.difference(refs)
|
||||
if expected_missing_refs:
|
||||
definitions: list[core_schema.CoreSchema] = [
|
||||
# TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail
|
||||
# Issue: https://github.com/pydantic/pydantic-core/issues/619
|
||||
core_schema.none_schema(ref=ref, metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True})
|
||||
for ref in expected_missing_refs
|
||||
]
|
||||
return core_schema.definitions_schema(schema, definitions)
|
||||
return None
|
||||
|
||||
|
||||
def collect_invalid_schemas(schema: core_schema.CoreSchema) -> bool:
|
||||
invalid = False
|
||||
|
||||
def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
nonlocal invalid
|
||||
if 'metadata' in s:
|
||||
metadata = s['metadata']
|
||||
if HAS_INVALID_SCHEMAS_METADATA_KEY in metadata:
|
||||
invalid = metadata[HAS_INVALID_SCHEMAS_METADATA_KEY]
|
||||
return s
|
||||
return recurse(s, _is_schema_valid)
|
||||
|
||||
walk_core_schema(schema, _is_schema_valid)
|
||||
return invalid
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
Recurse = Callable[[core_schema.CoreSchema, 'Walk'], core_schema.CoreSchema]
|
||||
Walk = Callable[[core_schema.CoreSchema, Recurse], core_schema.CoreSchema]
|
||||
|
||||
# TODO: Should we move _WalkCoreSchema into pydantic_core proper?
|
||||
# Issue: https://github.com/pydantic/pydantic-core/issues/615
|
||||
|
||||
|
||||
class _WalkCoreSchema:
|
||||
def __init__(self):
|
||||
self._schema_type_to_method = self._build_schema_type_to_method()
|
||||
|
||||
def _build_schema_type_to_method(self) -> dict[core_schema.CoreSchemaType, Recurse]:
|
||||
mapping: dict[core_schema.CoreSchemaType, Recurse] = {}
|
||||
key: core_schema.CoreSchemaType
|
||||
for key in get_args(core_schema.CoreSchemaType):
|
||||
method_name = f"handle_{key.replace('-', '_')}_schema"
|
||||
mapping[key] = getattr(self, method_name, self._handle_other_schemas)
|
||||
return mapping
|
||||
|
||||
def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
return f(schema, self._walk)
|
||||
|
||||
def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema = self._schema_type_to_method[schema['type']](schema.copy(), f)
|
||||
ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore
|
||||
if ser_schema:
|
||||
schema['serialization'] = self._handle_ser_schemas(ser_schema, f)
|
||||
return schema
|
||||
|
||||
def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
sub_schema = schema.get('schema', None)
|
||||
if sub_schema is not None:
|
||||
schema['schema'] = self.walk(sub_schema, f) # type: ignore
|
||||
return schema
|
||||
|
||||
def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema:
|
||||
schema: core_schema.CoreSchema | None = ser_schema.get('schema', None)
|
||||
if schema is not None:
|
||||
ser_schema['schema'] = self.walk(schema, f) # type: ignore
|
||||
return_schema: core_schema.CoreSchema | None = ser_schema.get('return_schema', None)
|
||||
if return_schema is not None:
|
||||
ser_schema['return_schema'] = self.walk(return_schema, f) # type: ignore
|
||||
return ser_schema
|
||||
|
||||
def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
new_definitions: list[core_schema.CoreSchema] = []
|
||||
for definition in schema['definitions']:
|
||||
updated_definition = self.walk(definition, f)
|
||||
if 'ref' in updated_definition:
|
||||
# If the updated definition schema doesn't have a 'ref', it shouldn't go in the definitions
|
||||
# This is most likely to happen due to replacing something with a definition reference, in
|
||||
# which case it should certainly not go in the definitions list
|
||||
new_definitions.append(updated_definition)
|
||||
new_inner_schema = self.walk(schema['schema'], f)
|
||||
|
||||
if not new_definitions and len(schema) == 3:
|
||||
# This means we'd be returning a "trivial" definitions schema that just wrapped the inner schema
|
||||
return new_inner_schema
|
||||
|
||||
new_schema = schema.copy()
|
||||
new_schema['schema'] = new_inner_schema
|
||||
new_schema['definitions'] = new_definitions
|
||||
return new_schema
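# The "trivial definitions wrapper" rule above in isolation: when no definitions
# survive the walk, the wrapper is dropped and the inner schema is returned
# (plain dicts and a simplified check, illustrative only).
def unwrap_trivial_definitions(schema: dict) -> dict:
    if schema['type'] == 'definitions' and not schema['definitions']:
        return schema['schema']
    return schema

wrapped = {'type': 'definitions', 'schema': {'type': 'int'}, 'definitions': []}
assert unwrap_trivial_definitions(wrapped) == {'type': 'int'}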
|
||||
|
||||
def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_tuple_variable_schema(
|
||||
self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
|
||||
) -> core_schema.CoreSchema:
|
||||
schema = cast(core_schema.TupleVariableSchema, schema)
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_tuple_positional_schema(
|
||||
self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
|
||||
) -> core_schema.CoreSchema:
|
||||
schema = cast(core_schema.TuplePositionalSchema, schema)
|
||||
schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']]
|
||||
extras_schema = schema.get('extras_schema')
|
||||
if extras_schema is not None:
|
||||
schema['extras_schema'] = self.walk(extras_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
keys_schema = schema.get('keys_schema')
|
||||
if keys_schema is not None:
|
||||
schema['keys_schema'] = self.walk(keys_schema, f)
|
||||
values_schema = schema.get('values_schema')
|
||||
if values_schema:
|
||||
schema['values_schema'] = self.walk(values_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
if not is_function_with_inner_schema(schema):
|
||||
return schema
|
||||
schema['schema'] = self.walk(schema['schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_union_schema(self, schema: core_schema.UnionSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
new_choices: list[CoreSchema | tuple[CoreSchema, str]] = []
|
||||
for v in schema['choices']:
|
||||
if isinstance(v, tuple):
|
||||
new_choices.append((self.walk(v[0], f), v[1]))
|
||||
else:
|
||||
new_choices.append(self.walk(v, f))
|
||||
schema['choices'] = new_choices
|
||||
return schema
|
||||
|
||||
def handle_tagged_union_schema(self, schema: core_schema.TaggedUnionSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
new_choices: dict[Hashable, core_schema.CoreSchema] = {}
|
||||
for k, v in schema['choices'].items():
|
||||
new_choices[k] = v if isinstance(v, (str, int)) else self.walk(v, f)
|
||||
schema['choices'] = new_choices
|
||||
return schema
|
||||
|
||||
def handle_chain_schema(self, schema: core_schema.ChainSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['steps'] = [self.walk(v, f) for v in schema['steps']]
|
||||
return schema
|
||||
|
||||
def handle_lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['lax_schema'] = self.walk(schema['lax_schema'], f)
|
||||
schema['strict_schema'] = self.walk(schema['strict_schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['json_schema'] = self.walk(schema['json_schema'], f)
|
||||
schema['python_schema'] = self.walk(schema['python_schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
extras_schema = schema.get('extras_schema')
|
||||
if extras_schema is not None:
|
||||
schema['extras_schema'] = self.walk(extras_schema, f)
|
||||
replaced_fields: dict[str, core_schema.ModelField] = {}
|
||||
replaced_computed_fields: list[core_schema.ComputedField] = []
|
||||
for computed_field in schema.get('computed_fields', ()):
|
||||
replaced_field = computed_field.copy()
|
||||
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
|
||||
replaced_computed_fields.append(replaced_field)
|
||||
if replaced_computed_fields:
|
||||
schema['computed_fields'] = replaced_computed_fields
|
||||
for k, v in schema['fields'].items():
|
||||
replaced_field = v.copy()
|
||||
replaced_field['schema'] = self.walk(v['schema'], f)
|
||||
replaced_fields[k] = replaced_field
|
||||
schema['fields'] = replaced_fields
|
||||
return schema
|
||||
|
||||
def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
extras_schema = schema.get('extras_schema')
|
||||
if extras_schema is not None:
|
||||
schema['extras_schema'] = self.walk(extras_schema, f)
|
||||
replaced_computed_fields: list[core_schema.ComputedField] = []
|
||||
for computed_field in schema.get('computed_fields', ()):
|
||||
replaced_field = computed_field.copy()
|
||||
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
|
||||
replaced_computed_fields.append(replaced_field)
|
||||
if replaced_computed_fields:
|
||||
schema['computed_fields'] = replaced_computed_fields
|
||||
replaced_fields: dict[str, core_schema.TypedDictField] = {}
|
||||
for k, v in schema['fields'].items():
|
||||
replaced_field = v.copy()
|
||||
replaced_field['schema'] = self.walk(v['schema'], f)
|
||||
replaced_fields[k] = replaced_field
|
||||
schema['fields'] = replaced_fields
|
||||
return schema
|
||||
|
||||
def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
replaced_fields: list[core_schema.DataclassField] = []
|
||||
replaced_computed_fields: list[core_schema.ComputedField] = []
|
||||
for computed_field in schema.get('computed_fields', ()):
|
||||
replaced_field = computed_field.copy()
|
||||
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
|
||||
replaced_computed_fields.append(replaced_field)
|
||||
if replaced_computed_fields:
|
||||
schema['computed_fields'] = replaced_computed_fields
|
||||
for field in schema['fields']:
|
||||
replaced_field = field.copy()
|
||||
replaced_field['schema'] = self.walk(field['schema'], f)
|
||||
replaced_fields.append(replaced_field)
|
||||
schema['fields'] = replaced_fields
|
||||
return schema
|
||||
|
||||
def handle_arguments_schema(self, schema: core_schema.ArgumentsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
replaced_arguments_schema: list[core_schema.ArgumentsParameter] = []
|
||||
for param in schema['arguments_schema']:
|
||||
replaced_param = param.copy()
|
||||
replaced_param['schema'] = self.walk(param['schema'], f)
|
||||
replaced_arguments_schema.append(replaced_param)
|
||||
schema['arguments_schema'] = replaced_arguments_schema
|
||||
if 'var_args_schema' in schema:
|
||||
schema['var_args_schema'] = self.walk(schema['var_args_schema'], f)
|
||||
if 'var_kwargs_schema' in schema:
|
||||
schema['var_kwargs_schema'] = self.walk(schema['var_kwargs_schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_call_schema(self, schema: core_schema.CallSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['arguments_schema'] = self.walk(schema['arguments_schema'], f)
|
||||
if 'return_schema' in schema:
|
||||
schema['return_schema'] = self.walk(schema['return_schema'], f)
|
||||
return schema
|
||||
|
||||
|
||||
_dispatch = _WalkCoreSchema().walk
|
||||
|
||||
|
||||
def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
"""Recursively traverse a CoreSchema.
|
||||
|
||||
Args:
|
||||
schema (core_schema.CoreSchema): The CoreSchema to process, it will not be modified.
|
||||
f (Walk): A function to apply. This function takes two arguments:
|
||||
1. The current CoreSchema that is being processed
|
||||
(not the same one you passed into this function, one level down).
|
||||
2. The "next" `f` to call. This lets you for example use `f=functools.partial(some_method, some_context)`
|
||||
to pass data down the recursive calls without using globals or other mutable state.
|
||||
|
||||
Returns:
|
||||
core_schema.CoreSchema: A processed CoreSchema.
|
||||
"""
|
||||
return f(schema.copy(), _dispatch)
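# A self-contained illustration of the (schema, recurse) callback protocol documented
# above, demonstrated on nested plain dicts instead of real core schemas.
def _walk(schema: dict, f):
    def recurse(s: dict, g):
        new = dict(s)
        inner = new.get('schema')
        if isinstance(inner, dict):
            new['schema'] = g(inner, recurse)
        return new
    return f(dict(schema), recurse)

seen: list = []

def record_types(s: dict, recurse) -> dict:
    seen.append(s['type'])  # visit the current schema, then keep walking
    return recurse(s, record_types)

_walk({'type': 'nullable', 'schema': {'type': 'int'}}, record_types)
assert seen == ['nullable', 'int']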
|
||||
|
||||
|
||||
def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901
|
||||
definitions: dict[str, core_schema.CoreSchema] = {}
|
||||
ref_counts: dict[str, int] = defaultdict(int)
|
||||
involved_in_recursion: dict[str, bool] = {}
|
||||
current_recursion_ref_count: dict[str, int] = defaultdict(int)
|
||||
|
||||
def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
if s['type'] == 'definitions':
|
||||
for definition in s['definitions']:
|
||||
ref = get_ref(definition)
|
||||
assert ref is not None
|
||||
if ref not in definitions:
|
||||
definitions[ref] = definition
|
||||
recurse(definition, collect_refs)
|
||||
return recurse(s['schema'], collect_refs)
|
||||
else:
|
||||
ref = get_ref(s)
|
||||
if ref is not None:
|
||||
new = recurse(s, collect_refs)
|
||||
new_ref = get_ref(new)
|
||||
if new_ref:
|
||||
definitions[new_ref] = new
|
||||
return core_schema.definition_reference_schema(schema_ref=ref)
|
||||
else:
|
||||
return recurse(s, collect_refs)
|
||||
|
||||
schema = walk_core_schema(schema, collect_refs)
|
||||
|
||||
def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
if s['type'] != 'definition-ref':
|
||||
return recurse(s, count_refs)
|
||||
ref = s['schema_ref']
|
||||
ref_counts[ref] += 1
|
||||
|
||||
if ref_counts[ref] >= 2:
|
||||
# If this model is involved in a recursion, this should be detected
|
||||
# on its second encounter; we can safely stop the walk here.
|
||||
if current_recursion_ref_count[ref] != 0:
|
||||
involved_in_recursion[ref] = True
|
||||
return s
|
||||
|
||||
current_recursion_ref_count[ref] += 1
|
||||
recurse(definitions[ref], count_refs)
|
||||
current_recursion_ref_count[ref] -= 1
|
||||
return s
|
||||
|
||||
schema = walk_core_schema(schema, count_refs)
|
||||
|
||||
assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it'
|
||||
|
||||
def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool:
|
||||
if ref_counts[ref] > 1:
|
||||
return False
|
||||
if involved_in_recursion.get(ref, False):
|
||||
return False
|
||||
if 'serialization' in s:
|
||||
return False
|
||||
if 'metadata' in s:
|
||||
metadata = s['metadata']
|
||||
for k in (
|
||||
'pydantic_js_functions',
|
||||
'pydantic_js_annotation_functions',
|
||||
'pydantic.internal.union_discriminator',
|
||||
):
|
||||
if k in metadata:
|
||||
# we need to keep this as a ref
|
||||
return False
|
||||
return True
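# The inlining decision above reduced to its core: a definition referenced exactly
# once, not involved in recursion, and whose 'definition-ref' node carries no extra
# keys can replace that node directly (plain dicts, made-up ref name).
definitions = {'User-ref': {'type': 'model-fields', 'ref': 'User-ref', 'fields': {}}}
ref_counts = {'User-ref': 1}
node = {'type': 'definition-ref', 'schema_ref': 'User-ref'}

if ref_counts[node['schema_ref']] == 1 and 'serialization' not in node:
    node = definitions.pop(node['schema_ref'])  # inline the single-use definition

assert node['type'] == 'model-fields' and not definitions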
|
||||
|
||||
def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
if s['type'] == 'definition-ref':
|
||||
ref = s['schema_ref']
|
||||
# Check if the reference is only used once, not involved in recursion and does not have
|
||||
# any extra keys (like 'serialization')
|
||||
if can_be_inlined(s, ref):
|
||||
# Inline the reference by replacing the reference with the actual schema
|
||||
new = definitions.pop(ref)
|
||||
ref_counts[ref] -= 1 # because we just replaced it!
|
||||
# put all other keys that were on the def-ref schema into the inlined version
|
||||
# in particular this is needed for `serialization`
|
||||
if 'serialization' in s:
|
||||
new['serialization'] = s['serialization']
|
||||
s = recurse(new, inline_refs)
|
||||
return s
|
||||
else:
|
||||
return recurse(s, inline_refs)
|
||||
else:
|
||||
return recurse(s, inline_refs)
|
||||
|
||||
schema = walk_core_schema(schema, inline_refs)
|
||||
|
||||
def_values = [v for v in definitions.values() if ref_counts[v['ref']] > 0] # type: ignore
|
||||
|
||||
if def_values:
|
||||
schema = core_schema.definitions_schema(schema=schema, definitions=def_values)
|
||||
def validate_core_schema(schema: CoreSchema) -> CoreSchema:
|
||||
if os.getenv('PYDANTIC_VALIDATE_CORE_SCHEMAS'):
|
||||
return _validate_core_schema(schema)
|
||||
return schema
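# The environment-variable gate above in standalone form: validation runs only when
# the variable is set to a non-empty value (`maybe_validate` is a hypothetical
# stand-in, not the real function).
import os

def maybe_validate(schema: dict, validate) -> dict:
    if os.getenv('PYDANTIC_VALIDATE_CORE_SCHEMAS'):
        return validate(schema)
    return schema

os.environ.pop('PYDANTIC_VALIDATE_CORE_SCHEMAS', None)
assert maybe_validate({'type': 'int'}, validate=dict) == {'type': 'int'}
os.environ['PYDANTIC_VALIDATE_CORE_SCHEMAS'] = '1'
assert maybe_validate({'type': 'int'}, validate=lambda s: {**s, 'checked': True})['checked'] is True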
|
||||
|
||||
|
||||
def _strip_metadata(schema: CoreSchema) -> CoreSchema:
|
||||
def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema:
|
||||
s = s.copy()
|
||||
s.pop('metadata', None)
|
||||
if s['type'] == 'model-fields':
|
||||
s = s.copy()
|
||||
s['fields'] = {k: v.copy() for k, v in s['fields'].items()}
|
||||
for field_name, field_schema in s['fields'].items():
|
||||
field_schema.pop('metadata', None)
|
||||
s['fields'][field_name] = field_schema
|
||||
computed_fields = s.get('computed_fields', None)
|
||||
if computed_fields:
|
||||
s['computed_fields'] = [cf.copy() for cf in computed_fields]
|
||||
for cf in computed_fields:
|
||||
cf.pop('metadata', None)
|
||||
def _clean_schema_for_pretty_print(obj: Any, strip_metadata: bool = True) -> Any: # pragma: no cover
|
||||
"""A utility function to remove irrelevant information from a core schema."""
|
||||
if isinstance(obj, Mapping):
|
||||
new_dct = {}
|
||||
for k, v in obj.items():
|
||||
if k == 'metadata' and strip_metadata:
|
||||
new_metadata = {}
|
||||
|
||||
for meta_k, meta_v in v.items():
|
||||
if meta_k in ('pydantic_js_functions', 'pydantic_js_annotation_functions'):
|
||||
new_metadata['js_metadata'] = '<stripped>'
|
||||
else:
|
||||
new_metadata[meta_k] = _clean_schema_for_pretty_print(meta_v, strip_metadata=strip_metadata)
|
||||
|
||||
if list(new_metadata.keys()) == ['js_metadata']:
|
||||
new_metadata = {'<stripped>'}
|
||||
|
||||
new_dct[k] = new_metadata
|
||||
# Remove some defaults:
|
||||
elif k in ('custom_init', 'root_model') and not v:
|
||||
continue
|
||||
else:
|
||||
s.pop('computed_fields', None)
|
||||
elif s['type'] == 'model':
|
||||
# remove some defaults
|
||||
if s.get('custom_init', True) is False:
|
||||
s.pop('custom_init')
|
||||
if s.get('root_model', True) is False:
|
||||
s.pop('root_model')
|
||||
if {'title'}.issuperset(s.get('config', {}).keys()):
|
||||
s.pop('config', None)
|
||||
new_dct[k] = _clean_schema_for_pretty_print(v, strip_metadata=strip_metadata)
|
||||
|
||||
return recurse(s, strip_metadata)
|
||||
|
||||
return walk_core_schema(schema, strip_metadata)
|
||||
return new_dct
|
||||
elif isinstance(obj, Sequence) and not isinstance(obj, str):
|
||||
return [_clean_schema_for_pretty_print(v, strip_metadata=strip_metadata) for v in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
def pretty_print_core_schema(
|
||||
schema: CoreSchema,
|
||||
include_metadata: bool = False,
|
||||
) -> None:
|
||||
"""Pretty print a CoreSchema using rich.
|
||||
This is intended for debugging purposes.
|
||||
val: Any,
|
||||
*,
|
||||
console: Console | None = None,
|
||||
max_depth: int | None = None,
|
||||
strip_metadata: bool = True,
|
||||
) -> None: # pragma: no cover
|
||||
"""Pretty-print a core schema using the `rich` library.
|
||||
|
||||
Args:
|
||||
schema: The CoreSchema to print.
|
||||
include_metadata: Whether to include metadata in the output. Defaults to `False`.
|
||||
val: The core schema to print, or a Pydantic model/dataclass/type adapter
|
||||
(in which case the cached core schema is fetched and printed).
|
||||
console: A rich console to use when printing. Defaults to the global rich console instance.
|
||||
max_depth: The number of nesting levels which may be printed.
|
||||
strip_metadata: Whether to strip metadata in the output. If `True`, any known core metadata
|
||||
attributes will be stripped (but custom attributes are kept). Defaults to `True`.
|
||||
"""
|
||||
from rich import print # type: ignore # install it manually in your dev env
|
||||
# lazy import:
|
||||
from rich.pretty import pprint
|
||||
|
||||
if not include_metadata:
|
||||
schema = _strip_metadata(schema)
|
||||
# circ. imports:
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
|
||||
return print(schema)
|
||||
if (inspect.isclass(val) and issubclass(val, BaseModel)) or is_pydantic_dataclass(val):
|
||||
val = val.__pydantic_core_schema__
|
||||
if isinstance(val, TypeAdapter):
|
||||
val = val.core_schema
|
||||
cleaned_schema = _clean_schema_for_pretty_print(val, strip_metadata=strip_metadata)
|
||||
|
||||
pprint(cleaned_schema, console=console, max_depth=max_depth)
|
||||
|
||||
|
||||
def validate_core_schema(schema: CoreSchema) -> CoreSchema:
|
||||
if 'PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS' in os.environ:
|
||||
return schema
|
||||
return _validate_core_schema(schema)
|
||||
pps = pretty_print_core_schema
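# How the debugging helper above is meant to be used: it requires `rich`, lives in a
# private module, and accepts a model, pydantic dataclass, TypeAdapter, or raw core
# schema. The model below is made up for illustration.
from pydantic import BaseModel
from pydantic._internal._core_utils import pretty_print_core_schema

class User(BaseModel):
    id: int
    name: str = 'Jane'

pretty_print_core_schema(User, max_depth=3)  # prints the cached core schema via rich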
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
"""Private logic for creating pydantic dataclasses."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import typing
|
||||
import warnings
|
||||
from functools import partial, wraps
|
||||
from inspect import Parameter, Signature, signature
|
||||
from typing import Any, Callable, ClassVar
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from pydantic_core import (
|
||||
ArgsKwargs,
|
||||
PydanticUndefined,
|
||||
SchemaSerializer,
|
||||
SchemaValidator,
|
||||
core_schema,
|
||||
@@ -19,28 +17,22 @@ from pydantic_core import (
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from ..errors import PydanticUndefinedAnnotation
|
||||
from ..fields import FieldInfo
|
||||
from ..plugin._schema_validator import create_schema_validator
|
||||
from ..plugin._schema_validator import PluggableSchemaValidator, create_schema_validator
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
from . import _config, _decorators, _discriminated_union, _typing_extra
|
||||
from ._core_utils import collect_invalid_schemas, simplify_schema_references, validate_core_schema
|
||||
from . import _config, _decorators
|
||||
from ._fields import collect_dataclass_fields
|
||||
from ._generate_schema import GenerateSchema
|
||||
from ._generate_schema import GenerateSchema, InvalidSchemaError
|
||||
from ._generics import get_standard_typevars_map
|
||||
from ._mock_val_ser import set_dataclass_mock_validator
|
||||
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
|
||||
from ._utils import is_valid_identifier
|
||||
from ._mock_val_ser import set_dataclass_mocks
|
||||
from ._namespace_utils import NsResolver
|
||||
from ._signature import generate_pydantic_signature
|
||||
from ._utils import LazyClassAttribute
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance as StandardDataclass
|
||||
|
||||
from ..config import ConfigDict
|
||||
|
||||
class StandardDataclass(typing.Protocol):
|
||||
__dataclass_fields__: ClassVar[dict[str, Any]]
|
||||
__dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams`
|
||||
__post_init__: ClassVar[Callable[..., None]]
|
||||
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
pass
|
||||
from ..fields import FieldInfo
|
||||
|
||||
class PydanticDataclass(StandardDataclass, typing.Protocol):
|
||||
"""A protocol containing attributes only available once a class has been decorated as a Pydantic dataclass.
|
||||
@@ -61,7 +53,10 @@ if typing.TYPE_CHECKING:
|
||||
__pydantic_decorators__: ClassVar[_decorators.DecoratorInfos]
|
||||
__pydantic_fields__: ClassVar[dict[str, FieldInfo]]
|
||||
__pydantic_serializer__: ClassVar[SchemaSerializer]
|
||||
__pydantic_validator__: ClassVar[SchemaValidator]
|
||||
__pydantic_validator__: ClassVar[SchemaValidator | PluggableSchemaValidator]
|
||||
|
||||
@classmethod
|
||||
def __pydantic_fields_complete__(cls) -> bool: ...
|
||||
|
||||
else:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
@@ -69,15 +64,22 @@ else:
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
|
||||
|
||||
def set_dataclass_fields(cls: type[StandardDataclass], types_namespace: dict[str, Any] | None = None) -> None:
|
||||
def set_dataclass_fields(
|
||||
cls: type[StandardDataclass],
|
||||
ns_resolver: NsResolver | None = None,
|
||||
config_wrapper: _config.ConfigWrapper | None = None,
|
||||
) -> None:
|
||||
"""Collect and set `cls.__pydantic_fields__`.
|
||||
|
||||
Args:
|
||||
cls: The class.
|
||||
types_namespace: The types namespace, defaults to `None`.
|
||||
ns_resolver: Namespace resolver to use when getting dataclass annotations.
|
||||
config_wrapper: The config wrapper instance, defaults to `None`.
|
||||
"""
|
||||
typevars_map = get_standard_typevars_map(cls)
|
||||
fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map)
|
||||
fields = collect_dataclass_fields(
|
||||
cls, ns_resolver=ns_resolver, typevars_map=typevars_map, config_wrapper=config_wrapper
|
||||
)
|
||||
|
||||
cls.__pydantic_fields__ = fields # type: ignore
|
||||
|
||||
@@ -87,7 +89,8 @@ def complete_dataclass(
|
||||
config_wrapper: _config.ConfigWrapper,
|
||||
*,
|
||||
raise_errors: bool = True,
|
||||
types_namespace: dict[str, Any] | None,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
_force_build: bool = False,
|
||||
) -> bool:
|
||||
"""Finish building a pydantic dataclass.
|
||||
|
||||
@@ -99,7 +102,10 @@ def complete_dataclass(
|
||||
cls: The class.
|
||||
config_wrapper: The config wrapper instance.
|
||||
raise_errors: Whether to raise errors, defaults to `True`.
|
||||
types_namespace: The types namespace.
|
||||
ns_resolver: The namespace resolver instance to use when collecting dataclass fields
|
||||
and during schema building.
|
||||
_force_build: Whether to force building the dataclass, regardless of whether
|
||||
[`defer_build`][pydantic.config.ConfigDict.defer_build] is set.
|
||||
|
||||
Returns:
|
||||
`True` if building a pydantic dataclass is successfully completed, `False` otherwise.
|
||||
@@ -107,136 +113,94 @@ def complete_dataclass(
|
||||
Raises:
|
||||
PydanticUndefinedAnnotation: If `raise_errors` is `True` and there are undefined annotations.
|
||||
"""
|
||||
if hasattr(cls, '__post_init_post_parse__'):
|
||||
warnings.warn(
|
||||
'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning
|
||||
)
|
||||
|
||||
if types_namespace is None:
|
||||
types_namespace = _typing_extra.get_cls_types_namespace(cls)
|
||||
|
||||
set_dataclass_fields(cls, types_namespace)
|
||||
|
||||
typevars_map = get_standard_typevars_map(cls)
|
||||
gen_schema = GenerateSchema(
|
||||
config_wrapper,
|
||||
types_namespace,
|
||||
typevars_map,
|
||||
)
|
||||
|
||||
# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
|
||||
original_init = cls.__init__
|
||||
|
||||
# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied,
|
||||
# and so that the mock validator is used if building was deferred:
|
||||
def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
|
||||
__tracebackhide__ = True
|
||||
s = __dataclass_self__
|
||||
s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s)
|
||||
|
||||
__init__.__qualname__ = f'{cls.__qualname__}.__init__'
|
||||
sig = generate_dataclass_signature(cls)
|
||||
|
||||
cls.__init__ = __init__ # type: ignore
|
||||
cls.__signature__ = sig # type: ignore
|
||||
cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore
|
||||
|
||||
get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None)
|
||||
set_dataclass_fields(cls, ns_resolver, config_wrapper=config_wrapper)
|
||||
|
||||
if not _force_build and config_wrapper.defer_build:
|
||||
set_dataclass_mocks(cls)
|
||||
return False
|
||||
|
||||
if hasattr(cls, '__post_init_post_parse__'):
|
||||
warnings.warn(
|
||||
'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning
|
||||
)
|
||||
|
||||
typevars_map = get_standard_typevars_map(cls)
|
||||
gen_schema = GenerateSchema(
|
||||
config_wrapper,
|
||||
ns_resolver=ns_resolver,
|
||||
typevars_map=typevars_map,
|
||||
)
|
||||
|
||||
# set __signature__ attr only for the class, but not for its instances
|
||||
# (because instances can define `__call__`, and `inspect.signature` shouldn't
|
||||
# use the `__signature__` attribute and instead generate from `__call__`).
|
||||
cls.__signature__ = LazyClassAttribute(
|
||||
'__signature__',
|
||||
partial(
|
||||
generate_pydantic_signature,
|
||||
# It's important that we reference the `original_init` here,
|
||||
# as it is the one synthesized by the stdlib `dataclass` module:
|
||||
init=original_init,
|
||||
fields=cls.__pydantic_fields__, # type: ignore
|
||||
validate_by_name=config_wrapper.validate_by_name,
|
||||
extra=config_wrapper.extra,
|
||||
is_dataclass=True,
|
||||
),
|
||||
)
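# A rough sketch of a class-only lazy attribute like the one used for `__signature__`
# above (hypothetical descriptor, not the real LazyClassAttribute). `inspect.signature`
# checks `__signature__` on the object first, so the attribute must stay invisible on
# instances for callable instances to fall back to `__call__`.
class ClassOnlyLazyAttribute:
    def __init__(self, factory):
        self._factory = factory

    def __set_name__(self, owner, name):
        self._name = name

    def __get__(self, instance, owner=None):
        if instance is not None:
            # raise so the attribute is only visible on the class itself
            raise AttributeError(getattr(self, '_name', 'attribute'))
        return self._factory()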
|
||||
|
||||
try:
|
||||
if get_core_schema:
|
||||
schema = get_core_schema(
|
||||
cls,
|
||||
CallbackGetCoreSchemaHandler(
|
||||
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
|
||||
gen_schema,
|
||||
ref_mode='unpack',
|
||||
),
|
||||
)
|
||||
else:
|
||||
schema = gen_schema.generate_schema(cls, from_dunder_get_core_schema=False)
|
||||
schema = gen_schema.generate_schema(cls)
|
||||
except PydanticUndefinedAnnotation as e:
|
||||
if raise_errors:
|
||||
raise
|
||||
set_dataclass_mock_validator(cls, cls.__name__, f'`{e.name}`')
|
||||
set_dataclass_mocks(cls, f'`{e.name}`')
|
||||
return False
|
||||
|
||||
core_config = config_wrapper.core_config(cls)
|
||||
core_config = config_wrapper.core_config(title=cls.__name__)
|
||||
|
||||
schema = gen_schema.collect_definitions(schema)
|
||||
if collect_invalid_schemas(schema):
|
||||
set_dataclass_mock_validator(cls, cls.__name__, 'all referenced types')
|
||||
try:
|
||||
schema = gen_schema.clean_schema(schema)
|
||||
except InvalidSchemaError:
|
||||
set_dataclass_mocks(cls)
|
||||
return False
|
||||
|
||||
schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema))
|
||||
|
||||
# We are about to set all the remaining required properties expected for this cast;
|
||||
# __pydantic_decorators__ and __pydantic_fields__ should already be set
|
||||
cls = typing.cast('type[PydanticDataclass]', cls)
|
||||
# debug(schema)
|
||||
|
||||
cls.__pydantic_core_schema__ = schema = validate_core_schema(schema)
|
||||
cls.__pydantic_core_schema__ = schema
|
||||
cls.__pydantic_validator__ = validator = create_schema_validator(
|
||||
schema, core_config, config_wrapper.plugin_settings
|
||||
schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
|
||||
)
|
||||
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
|
||||
|
||||
if config_wrapper.validate_assignment:
|
||||
|
||||
@wraps(cls.__setattr__)
|
||||
def validated_setattr(instance: Any, __field: str, __value: str) -> None:
|
||||
validator.validate_assignment(instance, __field, __value)
|
||||
def validated_setattr(instance: Any, field: str, value: str, /) -> None:
|
||||
validator.validate_assignment(instance, field, value)
|
||||
|
||||
cls.__setattr__ = validated_setattr.__get__(None, cls) # type: ignore
|
||||
|
||||
cls.__pydantic_complete__ = True
|
||||
return True
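
A minimal usage sketch of the behaviour wired up above, using the public `pydantic.dataclasses` API (not this internal module): with `validate_assignment=True`, attribute assignment is routed through the wrapped `__setattr__` and validated.

```python
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass

@dataclass(config=ConfigDict(validate_assignment=True))
class Point:
    x: int

p = Point(x=1)
p.x = '2'        # goes through the assignment validator and is coerced
assert p.x == 2
```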
|
||||
|
||||
|
||||
def generate_dataclass_signature(cls: type[StandardDataclass]) -> Signature:
|
||||
"""Generate signature for a pydantic dataclass.
|
||||
|
||||
This implementation assumes we do not support custom `__init__`, which is currently true for pydantic dataclasses.
|
||||
If we change this eventually, we should make this function's logic more closely mirror that from
|
||||
`pydantic._internal._model_construction.generate_model_signature`.
|
||||
|
||||
Args:
|
||||
cls: The dataclass.
|
||||
|
||||
Returns:
|
||||
The signature.
|
||||
"""
|
||||
sig = signature(cls)
|
||||
final_params: dict[str, Parameter] = {}
|
||||
|
||||
for param in sig.parameters.values():
|
||||
param_default = param.default
|
||||
if isinstance(param_default, FieldInfo):
|
||||
annotation = param.annotation
|
||||
# Replace the annotation if appropriate
|
||||
# inspect does "clever" things to show annotations as strings because we have
|
||||
# `from __future__ import annotations` in main, we don't want that
|
||||
if annotation == 'Any':
|
||||
annotation = Any
|
||||
|
||||
# Replace the field name with the alias if present
|
||||
name = param.name
|
||||
alias = param_default.alias
|
||||
validation_alias = param_default.validation_alias
|
||||
if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias):
|
||||
name = alias
|
||||
elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias):
|
||||
name = validation_alias
|
||||
|
||||
# Replace the field default
|
||||
default = param_default.default
|
||||
if default is PydanticUndefined:
|
||||
if param_default.default_factory is PydanticUndefined:
|
||||
default = inspect.Signature.empty
|
||||
else:
|
||||
# this is used by dataclasses to indicate a factory exists:
|
||||
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
|
||||
|
||||
param = param.replace(annotation=annotation, name=name, default=default)
|
||||
final_params[param.name] = param
|
||||
|
||||
return Signature(parameters=list(final_params.values()), return_annotation=None)
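
A sketch of the effect described above, assuming the public `Field(alias=...)` API: when a field defines a valid-identifier alias (and no separate validation alias), the synthesized signature is expected to surface the alias rather than the field name.

```python
import inspect

from pydantic import Field
from pydantic.dataclasses import dataclass

@dataclass
class User:
    name: str = Field(alias='username')

print(inspect.signature(User))  # roughly: (username: str) -> None
```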
|
||||
|
||||
|
||||
def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
|
||||
"""Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass.
|
||||
|
||||
@@ -245,7 +209,7 @@ def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
|
||||
- `_cls` does not inherit from a processed pydantic dataclass (and thus have a `__pydantic_validator__`)
|
||||
- `_cls` does not have any annotations that are not dataclass fields
|
||||
e.g.
|
||||
```py
|
||||
```python
|
||||
import dataclasses
|
||||
|
||||
import pydantic.dataclasses
|
||||
|
||||
@@ -1,31 +1,30 @@
|
||||
"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import types
|
||||
from collections import deque
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial, partialmethod
|
||||
from functools import cached_property, partial, partialmethod
|
||||
from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterable, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, TypeVar, Union
|
||||
|
||||
from pydantic_core import PydanticUndefined, core_schema
|
||||
from typing_extensions import Literal, TypeAlias, is_typeddict
|
||||
from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema
|
||||
from typing_extensions import TypeAlias, is_typeddict
|
||||
|
||||
from ..errors import PydanticUserError
|
||||
from ..fields import ComputedFieldInfo
|
||||
from ._core_utils import get_type_ref
|
||||
from ._internal_dataclass import slots_true
|
||||
from ._namespace_utils import GlobalsNamespace, MappingNamespace
|
||||
from ._typing_extra import get_function_type_hints
|
||||
from ._utils import can_be_positional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..fields import ComputedFieldInfo
|
||||
from ..functional_validators import FieldValidatorModes
|
||||
|
||||
try:
|
||||
from functools import cached_property # type: ignore
|
||||
except ImportError:
|
||||
# python 3.7
|
||||
cached_property = None
|
||||
|
||||
|
||||
@dataclass(**slots_true)
|
||||
class ValidatorDecoratorInfo:
|
||||
@@ -61,6 +60,9 @@ class FieldValidatorDecoratorInfo:
|
||||
fields: A tuple of field names the validator should be called on.
|
||||
mode: The proposed validator mode.
|
||||
check_fields: Whether to check that the fields actually exist on the model.
|
||||
json_schema_input_type: The input type of the function. This is only used to generate
|
||||
the appropriate JSON Schema (in validation mode) and can only be specified
|
||||
when `mode` is either `'before'`, `'plain'` or `'wrap'`.
|
||||
"""
|
||||
|
||||
decorator_repr: ClassVar[str] = '@field_validator'
|
||||
@@ -68,6 +70,7 @@ class FieldValidatorDecoratorInfo:
|
||||
fields: tuple[str, ...]
|
||||
mode: FieldValidatorModes
|
||||
check_fields: bool | None
|
||||
json_schema_input_type: Any
|
||||
|
||||
|
||||
@dataclass(**slots_true)
|
||||
@@ -132,7 +135,7 @@ class ModelValidatorDecoratorInfo:
|
||||
while building the pydantic-core schema.
|
||||
|
||||
Attributes:
|
||||
decorator_repr: A class variable representing the decorator string, '@model_serializer'.
|
||||
decorator_repr: A class variable representing the decorator string, '@model_validator'.
|
||||
mode: The proposed serializer mode.
|
||||
"""
|
||||
|
||||
@@ -140,7 +143,7 @@ class ModelValidatorDecoratorInfo:
|
||||
mode: Literal['wrap', 'before', 'after']
|
||||
|
||||
|
||||
DecoratorInfo = Union[
|
||||
DecoratorInfo: TypeAlias = """Union[
|
||||
ValidatorDecoratorInfo,
|
||||
FieldValidatorDecoratorInfo,
|
||||
RootValidatorDecoratorInfo,
|
||||
@@ -148,7 +151,7 @@ DecoratorInfo = Union[
|
||||
ModelSerializerDecoratorInfo,
|
||||
ModelValidatorDecoratorInfo,
|
||||
ComputedFieldInfo,
|
||||
]
|
||||
]"""
|
||||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
DecoratedType: TypeAlias = (
|
||||
@@ -183,6 +186,12 @@ class PydanticDescriptorProxy(Generic[ReturnType]):
|
||||
|
||||
def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
|
||||
self.wrapped = getattr(self.wrapped, name)(func)
|
||||
if isinstance(self.wrapped, property):
|
||||
# update ComputedFieldInfo.wrapped_property
|
||||
from ..fields import ComputedFieldInfo
|
||||
|
||||
if isinstance(self.decorator_info, ComputedFieldInfo):
|
||||
self.decorator_info.wrapped_property = self.wrapped
|
||||
return self
|
||||
|
||||
def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
|
||||
@@ -194,11 +203,11 @@ class PydanticDescriptorProxy(Generic[ReturnType]):
|
||||
|
||||
def __set_name__(self, instance: Any, name: str) -> None:
|
||||
if hasattr(self.wrapped, '__set_name__'):
|
||||
self.wrapped.__set_name__(instance, name)
|
||||
self.wrapped.__set_name__(instance, name) # pyright: ignore[reportFunctionMemberAccess]
|
||||
|
||||
def __getattr__(self, __name: str) -> Any:
|
||||
def __getattr__(self, name: str, /) -> Any:
|
||||
"""Forward checks for __isabstractmethod__ and such."""
|
||||
return getattr(self.wrapped, __name)
|
||||
return getattr(self.wrapped, name)
|
||||
|
||||
|
||||
DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)
|
||||
@@ -488,6 +497,8 @@ class DecoratorInfos:
|
||||
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
|
||||
)
|
||||
else:
|
||||
from ..fields import ComputedFieldInfo
|
||||
|
||||
isinstance(var_value, ComputedFieldInfo)
|
||||
res.computed_fields[var_name] = Decorator.build(
|
||||
model_dc, cls_var_name=var_name, shim=None, info=info
|
||||
@@ -498,7 +509,7 @@ class DecoratorInfos:
|
||||
# so then we don't need to re-process the type, which means we can discard our descriptor wrappers
|
||||
# and replace them with the thing they are wrapping (see the other setattr call below)
|
||||
# which allows validator class methods to also function as regular class methods
|
||||
setattr(model_dc, '__pydantic_decorators__', res)
|
||||
model_dc.__pydantic_decorators__ = res
|
||||
for name, value in to_replace:
|
||||
setattr(model_dc, name, value)
|
||||
return res
|
||||
@@ -518,12 +529,11 @@ def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes)
|
||||
"""
|
||||
try:
|
||||
sig = signature(validator)
|
||||
except ValueError:
|
||||
# builtins and some C extensions don't have signatures
|
||||
# assume that they don't take an info argument and only take a single argument
|
||||
# e.g. `str.strip` or `datetime.datetime`
|
||||
except (ValueError, TypeError):
|
||||
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
|
||||
# In this case, we assume no info argument is present:
|
||||
return False
|
||||
n_positional = count_positional_params(sig)
|
||||
n_positional = count_positional_required_params(sig)
|
||||
if mode == 'wrap':
|
||||
if n_positional == 3:
|
||||
return True
|
||||
@@ -542,9 +552,7 @@ def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes)
|
||||
)
|
||||
|
||||
|
||||
def inspect_field_serializer(
|
||||
serializer: Callable[..., Any], mode: Literal['plain', 'wrap'], computed_field: bool = False
|
||||
) -> tuple[bool, bool]:
|
||||
def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
|
||||
"""Look at a field serializer function and determine if it is a field serializer,
|
||||
and whether it takes an info argument.
|
||||
|
||||
@@ -553,18 +561,21 @@ def inspect_field_serializer(
|
||||
Args:
|
||||
serializer: The serializer function to inspect.
|
||||
mode: The serializer mode, either 'plain' or 'wrap'.
|
||||
computed_field: When serializer is applied on computed_field. It doesn't require
|
||||
info signature.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_field_serializer, info_arg).
|
||||
"""
|
||||
sig = signature(serializer)
|
||||
try:
|
||||
sig = signature(serializer)
|
||||
except (ValueError, TypeError):
|
||||
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
|
||||
# In this case, we assume no info argument is present and this is not a method:
|
||||
return (False, False)
|
||||
|
||||
first = next(iter(sig.parameters.values()), None)
|
||||
is_field_serializer = first is not None and first.name == 'self'
|
||||
|
||||
n_positional = count_positional_params(sig)
|
||||
n_positional = count_positional_required_params(sig)
|
||||
if is_field_serializer:
|
||||
# -1 to correct for self parameter
|
||||
info_arg = _serializer_info_arg(mode, n_positional - 1)
|
||||
@@ -576,13 +587,8 @@ def inspect_field_serializer(
|
||||
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
|
||||
code='field-serializer-signature',
|
||||
)
|
||||
if info_arg and computed_field:
|
||||
raise PydanticUserError(
|
||||
'field_serializer on computed_field does not use info signature', code='field-serializer-signature'
|
||||
)
|
||||
|
||||
else:
|
||||
return is_field_serializer, info_arg
|
||||
return is_field_serializer, info_arg
|
||||
|
||||
|
||||
def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
|
||||
@@ -597,8 +603,13 @@ def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['
|
||||
Returns:
|
||||
info_arg
|
||||
"""
|
||||
sig = signature(serializer)
|
||||
info_arg = _serializer_info_arg(mode, count_positional_params(sig))
|
||||
try:
|
||||
sig = signature(serializer)
|
||||
except (ValueError, TypeError):
|
||||
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
|
||||
# In this case, we assume no info argument is present:
|
||||
return False
|
||||
info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
|
||||
if info_arg is None:
|
||||
raise PydanticUserError(
|
||||
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
|
||||
@@ -626,7 +637,7 @@ def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plai
|
||||
)
|
||||
|
||||
sig = signature(serializer)
|
||||
info_arg = _serializer_info_arg(mode, count_positional_params(sig))
|
||||
info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
|
||||
if info_arg is None:
|
||||
raise PydanticUserError(
|
||||
f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
|
||||
@@ -639,18 +650,18 @@ def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plai
|
||||
def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
|
||||
if mode == 'plain':
|
||||
if n_positional == 1:
|
||||
# (__input_value: Any) -> Any
|
||||
# (input_value: Any, /) -> Any
|
||||
return False
|
||||
elif n_positional == 2:
|
||||
# (__model: Any, __input_value: Any) -> Any
|
||||
# (model: Any, input_value: Any, /) -> Any
|
||||
return True
|
||||
else:
|
||||
assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
|
||||
if n_positional == 2:
|
||||
# (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any
|
||||
# (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any
|
||||
return False
|
||||
elif n_positional == 3:
|
||||
# (__input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: SerializationInfo) -> Any
|
||||
# (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any
|
||||
return True
|
||||
|
||||
return None
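
The positional-argument counts handled above correspond to the serializer signatures accepted by the public `@field_serializer` decorator; the following is an illustrative sketch (not this internal code path) of the plain-mode case without an info argument.

```python
from pydantic import BaseModel, field_serializer

class Event(BaseModel):
    when: int

    @field_serializer('when')            # (self, value) -> one value arg, no info arg
    def ser_when(self, value: int) -> str:
        return str(value)

print(Event(when=3).model_dump())  # {'when': '3'}
```

Adding a third parameter (conventionally `info`) would switch the serializer to the "with info" variant detected above.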
|
||||
@@ -711,34 +722,25 @@ def unwrap_wrapped_function(
|
||||
unwrap_class_static_method: bool = True,
|
||||
) -> Any:
|
||||
"""Recursively unwraps a wrapped function until the underlying function is reached.
|
||||
This handles property, functools.partial, functools.partialmethod, staticmethod and classmethod.
|
||||
This handles property, functools.partial, functools.partialmethod, staticmethod, and classmethod.
|
||||
|
||||
Args:
|
||||
func: The function to unwrap.
|
||||
unwrap_partial: If True (default), unwrap partial and partialmethod decorators, otherwise don't.
|
||||
decorators.
|
||||
unwrap_partial: If True (default), unwrap partial and partialmethod decorators.
|
||||
unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
|
||||
decorators. If False, only unwrap partial and partialmethod decorators.
|
||||
|
||||
Returns:
|
||||
The underlying function of the wrapped function.
|
||||
"""
|
||||
all: set[Any] = {property}
|
||||
# Define the types we want to check against as a single tuple.
|
||||
unwrap_types = (
|
||||
(property, cached_property)
|
||||
+ ((partial, partialmethod) if unwrap_partial else ())
|
||||
+ ((staticmethod, classmethod) if unwrap_class_static_method else ())
|
||||
)
|
||||
|
||||
if unwrap_partial:
|
||||
all.update({partial, partialmethod})
|
||||
|
||||
try:
|
||||
from functools import cached_property # type: ignore
|
||||
except ImportError:
|
||||
cached_property = type('', (), {})
|
||||
else:
|
||||
all.add(cached_property)
|
||||
|
||||
if unwrap_class_static_method:
|
||||
all.update({staticmethod, classmethod})
|
||||
|
||||
while isinstance(func, tuple(all)):
|
||||
while isinstance(func, unwrap_types):
|
||||
if unwrap_class_static_method and isinstance(func, (classmethod, staticmethod)):
|
||||
func = func.__func__
|
||||
elif isinstance(func, (partial, partialmethod)):
|
||||
@@ -753,38 +755,72 @@ def unwrap_wrapped_function(
|
||||
return func
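
A standalone sketch of the unwrapping loop described above (not the library function itself): peel `classmethod`/`staticmethod` and `functools.partial` wrappers until the underlying function is reached.

```python
from functools import partial

def add(a: int, b: int) -> int:
    return a + b

wrapped = staticmethod(partial(add, 1))

func = wrapped
while isinstance(func, (staticmethod, classmethod, partial)):
    if isinstance(func, (staticmethod, classmethod)):
        func = func.__func__   # unwrap the descriptor
    else:
        func = func.func       # unwrap the partial

assert func is add
```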
|
||||
|
||||
|
||||
def get_function_return_type(
|
||||
func: Any, explicit_return_type: Any, types_namespace: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Get the function return type.
|
||||
_function_like = (
|
||||
partial,
|
||||
partialmethod,
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
types.MethodType,
|
||||
types.WrapperDescriptorType,
|
||||
types.MethodWrapperType,
|
||||
types.MemberDescriptorType,
|
||||
)
|
||||
|
||||
It gets the return type from the type annotation if `explicit_return_type` is `None`.
|
||||
Otherwise, it returns `explicit_return_type`.
|
||||
|
||||
def get_callable_return_type(
|
||||
callable_obj: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
) -> Any | PydanticUndefinedType:
|
||||
"""Get the callable return type.
|
||||
|
||||
Args:
|
||||
func: The function to get its return type.
|
||||
explicit_return_type: The explicit return type.
|
||||
types_namespace: The types namespace, defaults to `None`.
|
||||
callable_obj: The callable to analyze.
|
||||
globalns: The globals namespace to use during type annotation evaluation.
|
||||
localns: The locals namespace to use during type annotation evaluation.
|
||||
|
||||
Returns:
|
||||
The function return type.
|
||||
"""
|
||||
if explicit_return_type is PydanticUndefined:
|
||||
# try to get it from the type annotation
|
||||
hints = get_function_type_hints(
|
||||
unwrap_wrapped_function(func), include_keys={'return'}, types_namespace=types_namespace
|
||||
)
|
||||
return hints.get('return', PydanticUndefined)
|
||||
else:
|
||||
return explicit_return_type
|
||||
if isinstance(callable_obj, type):
|
||||
# types are callables, and we assume the return type
|
||||
# is the type itself (e.g. `int()` results in an instance of `int`).
|
||||
return callable_obj
|
||||
|
||||
if not isinstance(callable_obj, _function_like):
|
||||
call_func = getattr(type(callable_obj), '__call__', None) # noqa: B004
|
||||
if call_func is not None:
|
||||
callable_obj = call_func
|
||||
|
||||
hints = get_function_type_hints(
|
||||
unwrap_wrapped_function(callable_obj),
|
||||
include_keys={'return'},
|
||||
globalns=globalns,
|
||||
localns=localns,
|
||||
)
|
||||
return hints.get('return', PydanticUndefined)
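
The underlying idea is reading the `return` entry of the resolved type hints; a minimal sketch using the standard library (the internal `get_function_type_hints` helper is assumed to behave similarly for the `'return'` key):

```python
from typing import get_type_hints

def parse(raw: str) -> int:
    return int(raw)

hints = get_type_hints(parse)
print(hints.get('return'))  # <class 'int'>
```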
|
||||
|
||||
|
||||
def count_positional_params(sig: Signature) -> int:
|
||||
return sum(1 for param in sig.parameters.values() if can_be_positional(param))
|
||||
def count_positional_required_params(sig: Signature) -> int:
|
||||
"""Get the number of positional (required) arguments of a signature.
|
||||
|
||||
This function should only be used to inspect signatures of validation and serialization functions.
|
||||
The first argument (the value being serialized or validated) is counted as a required argument
|
||||
even if a default value exists.
|
||||
|
||||
def can_be_positional(param: Parameter) -> bool:
|
||||
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
|
||||
Returns:
|
||||
The number of positional arguments of a signature.
|
||||
"""
|
||||
parameters = list(sig.parameters.values())
|
||||
return sum(
|
||||
1
|
||||
for param in parameters
|
||||
if can_be_positional(param)
|
||||
# First argument is the value being validated/serialized, and can have a default value
|
||||
# (e.g. `float`, which has signature `(x=0, /)`). We assume other parameters (the info arg
|
||||
# for instance) should be required, and thus without any default value.
|
||||
and (param.default is Parameter.empty or param is parameters[0])
|
||||
)
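
A standalone sketch of the counting rule above: the first parameter counts even when it has a default, while later defaulted or keyword-only parameters do not.

```python
from inspect import Parameter, signature

def serializer(value: int = 0, info=None, *, strict: bool = False) -> int:
    return value

params = list(signature(serializer).parameters.values())
count = sum(
    1
    for p in params
    if p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
    and (p.default is Parameter.empty or p is params[0])
)
print(count)  # 1: `value` counts despite its default, `info` does not, `strict` is keyword-only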
|
||||
|
||||
|
||||
def ensure_property(f: Any) -> Any:
|
||||
|
||||
@@ -1,49 +1,45 @@
|
||||
"""Logic for V1 validators, e.g. `@validator` and `@root_validator`."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from inspect import Parameter, signature
|
||||
from typing import Any, Dict, Tuple, Union, cast
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from ..errors import PydanticUserError
|
||||
from ._decorators import can_be_positional
|
||||
from ._utils import can_be_positional
|
||||
|
||||
|
||||
class V1OnlyValueValidator(Protocol):
|
||||
"""A simple validator, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any) -> Any:
|
||||
...
|
||||
def __call__(self, __value: Any) -> Any: ...
|
||||
|
||||
|
||||
class V1ValidatorWithValues(Protocol):
|
||||
"""A validator with `values` argument, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any, values: dict[str, Any]) -> Any:
|
||||
...
|
||||
def __call__(self, __value: Any, values: dict[str, Any]) -> Any: ...
|
||||
|
||||
|
||||
class V1ValidatorWithValuesKwOnly(Protocol):
|
||||
"""A validator with keyword only `values` argument, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any:
|
||||
...
|
||||
def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any: ...
|
||||
|
||||
|
||||
class V1ValidatorWithKwargs(Protocol):
|
||||
"""A validator with `kwargs` argument, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any, **kwargs: Any) -> Any:
|
||||
...
|
||||
def __call__(self, __value: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
|
||||
class V1ValidatorWithValuesAndKwargs(Protocol):
|
||||
"""A validator with `values` and `kwargs` arguments, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any:
|
||||
...
|
||||
def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any: ...
|
||||
|
||||
|
||||
V1Validator = Union[
|
||||
@@ -109,23 +105,21 @@ def make_generic_v1_field_validator(validator: V1Validator) -> core_schema.WithI
|
||||
return wrapper2
|
||||
|
||||
|
||||
RootValidatorValues = Dict[str, Any]
|
||||
RootValidatorValues = dict[str, Any]
|
||||
# technically tuple[model_dict, model_extra, fields_set] | tuple[dataclass_dict, init_vars]
|
||||
RootValidatorFieldsTuple = Tuple[Any, ...]
|
||||
RootValidatorFieldsTuple = tuple[Any, ...]
|
||||
|
||||
|
||||
class V1RootValidatorFunction(Protocol):
|
||||
"""A simple root validator, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __values: RootValidatorValues) -> RootValidatorValues:
|
||||
...
|
||||
def __call__(self, __values: RootValidatorValues) -> RootValidatorValues: ...
|
||||
|
||||
|
||||
class V2CoreBeforeRootValidator(Protocol):
|
||||
"""V2 validator with mode='before'."""
|
||||
|
||||
def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues:
|
||||
...
|
||||
def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues: ...
|
||||
|
||||
|
||||
class V2CoreAfterRootValidator(Protocol):
|
||||
@@ -133,8 +127,7 @@ class V2CoreAfterRootValidator(Protocol):
|
||||
|
||||
def __call__(
|
||||
self, __fields_tuple: RootValidatorFieldsTuple, __info: core_schema.ValidationInfo
|
||||
) -> RootValidatorFieldsTuple:
|
||||
...
|
||||
) -> RootValidatorFieldsTuple: ...
|
||||
|
||||
|
||||
def make_v1_generic_root_validator(
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from typing import Any, Hashable, Sequence
|
||||
from collections.abc import Hashable, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
from ..errors import PydanticUserError
|
||||
from . import _core_utils
|
||||
from ._core_utils import (
|
||||
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
|
||||
CoreSchemaField,
|
||||
collect_definitions,
|
||||
simplify_schema_references,
|
||||
)
|
||||
|
||||
CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'
|
||||
if TYPE_CHECKING:
|
||||
from ..types import Discriminator
|
||||
from ._core_metadata import CoreMetadata
|
||||
|
||||
|
||||
class MissingDefinitionForUnionRef(Exception):
|
||||
@@ -26,39 +26,15 @@ class MissingDefinitionForUnionRef(Exception):
|
||||
super().__init__(f'Missing definition for ref {self.ref!r}')
|
||||
|
||||
|
||||
def set_discriminator(schema: CoreSchema, discriminator: Any) -> None:
|
||||
schema.setdefault('metadata', {})
|
||||
metadata = schema.get('metadata')
|
||||
assert metadata is not None
|
||||
metadata[CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY] = discriminator
|
||||
|
||||
|
||||
def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
|
||||
definitions: dict[str, CoreSchema] | None = None
|
||||
|
||||
def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
|
||||
nonlocal definitions
|
||||
if 'metadata' in s:
|
||||
if s['metadata'].get(NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, True) is False:
|
||||
return s
|
||||
|
||||
s = recurse(s, inner)
|
||||
if s['type'] == 'tagged-union':
|
||||
return s
|
||||
|
||||
metadata = s.get('metadata', {})
|
||||
discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
|
||||
if discriminator is not None:
|
||||
if definitions is None:
|
||||
definitions = collect_definitions(schema)
|
||||
s = apply_discriminator(s, discriminator, definitions)
|
||||
return s
|
||||
|
||||
return simplify_schema_references(_core_utils.walk_core_schema(schema, inner))
|
||||
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
|
||||
metadata = cast('CoreMetadata', schema.setdefault('metadata', {}))
|
||||
metadata['pydantic_internal_union_discriminator'] = discriminator
|
||||
|
||||
|
||||
def apply_discriminator(
|
||||
schema: core_schema.CoreSchema, discriminator: str, definitions: dict[str, core_schema.CoreSchema] | None = None
|
||||
schema: core_schema.CoreSchema,
|
||||
discriminator: str | Discriminator,
|
||||
definitions: dict[str, core_schema.CoreSchema] | None = None,
|
||||
) -> core_schema.CoreSchema:
|
||||
"""Applies the discriminator and returns a new core schema.
|
||||
|
||||
@@ -83,6 +59,14 @@ def apply_discriminator(
|
||||
- If discriminator fields have different aliases.
|
||||
- If discriminator field not of type `Literal`.
|
||||
"""
|
||||
from ..types import Discriminator
|
||||
|
||||
if isinstance(discriminator, Discriminator):
|
||||
if isinstance(discriminator.discriminator, str):
|
||||
discriminator = discriminator.discriminator
|
||||
else:
|
||||
return discriminator._convert_schema(schema)
|
||||
|
||||
return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)
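
A hedged usage sketch of the public API that ultimately routes through `apply_discriminator`: a string discriminator placed on a union of models.

```python
from typing import Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated

class Cat(BaseModel):
    pet_type: Literal['cat']

class Dog(BaseModel):
    pet_type: Literal['dog']

class Owner(BaseModel):
    pet: Annotated[Union[Cat, Dog], Field(discriminator='pet_type')]

print(Owner(pet={'pet_type': 'dog'}).pet)  # Dog(pet_type='dog')
```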
|
||||
|
||||
|
||||
@@ -150,7 +134,7 @@ class _ApplyInferredDiscriminator:
|
||||
# in the output TaggedUnionSchema that will replace the union from the input schema
|
||||
self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}
|
||||
|
||||
# `_used` is changed to True after applying the discriminator to prevent accidental re-use
|
||||
# `_used` is changed to True after applying the discriminator to prevent accidental reuse
|
||||
self._used = False
|
||||
|
||||
def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
|
||||
@@ -176,16 +160,11 @@ class _ApplyInferredDiscriminator:
|
||||
- If discriminator fields have different aliases.
|
||||
- If discriminator field not of type `Literal`.
|
||||
"""
|
||||
self.definitions.update(collect_definitions(schema))
|
||||
assert not self._used
|
||||
schema = self._apply_to_root(schema)
|
||||
if self._should_be_nullable and not self._is_nullable:
|
||||
schema = core_schema.nullable_schema(schema)
|
||||
self._used = True
|
||||
new_defs = collect_definitions(schema)
|
||||
missing_defs = self.definitions.keys() - new_defs.keys()
|
||||
if missing_defs:
|
||||
schema = core_schema.definitions_schema(schema, [self.definitions[ref] for ref in missing_defs])
|
||||
return schema
|
||||
|
||||
def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
|
||||
@@ -255,6 +234,10 @@ class _ApplyInferredDiscriminator:
|
||||
* Validating that each allowed discriminator value maps to a unique choice
|
||||
* Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
|
||||
"""
|
||||
if choice['type'] == 'definition-ref':
|
||||
if choice['schema_ref'] not in self.definitions:
|
||||
raise MissingDefinitionForUnionRef(choice['schema_ref'])
|
||||
|
||||
if choice['type'] == 'none':
|
||||
self._should_be_nullable = True
|
||||
elif choice['type'] == 'definitions':
|
||||
@@ -266,10 +249,6 @@ class _ApplyInferredDiscriminator:
|
||||
# Reverse the choices list before extending the stack so that they get handled in the order they occur
|
||||
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
|
||||
self._choices_to_handle.extend(choices_schemas)
|
||||
elif choice['type'] == 'definition-ref':
|
||||
if choice['schema_ref'] not in self.definitions:
|
||||
raise MissingDefinitionForUnionRef(choice['schema_ref'])
|
||||
self._handle_choice(self.definitions[choice['schema_ref']])
|
||||
elif choice['type'] not in {
|
||||
'model',
|
||||
'typed-dict',
|
||||
@@ -277,12 +256,16 @@ class _ApplyInferredDiscriminator:
|
||||
'lax-or-strict',
|
||||
'dataclass',
|
||||
'dataclass-args',
|
||||
'definition-ref',
|
||||
} and not _core_utils.is_function_with_inner_schema(choice):
|
||||
# We should eventually handle 'definition-ref' as well
|
||||
raise TypeError(
|
||||
f'{choice["type"]!r} is not a valid discriminated union variant;'
|
||||
' should be a `BaseModel` or `dataclass`'
|
||||
)
|
||||
err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
|
||||
if choice['type'] == 'list':
|
||||
err_str += (
|
||||
' If you are making use of a list of union types, make sure the discriminator is applied to the '
|
||||
'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
|
||||
)
|
||||
raise TypeError(err_str)
|
||||
else:
|
||||
if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
|
||||
# In this case, this inner tagged-union is compatible with the outer tagged-union,
|
||||
@@ -316,13 +299,10 @@ class _ApplyInferredDiscriminator:
|
||||
"""
|
||||
if choice['type'] == 'definitions':
|
||||
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
|
||||
elif choice['type'] == 'function-plain':
|
||||
raise TypeError(
|
||||
f'{choice["type"]!r} is not a valid discriminated union variant;'
|
||||
' should be a `BaseModel` or `dataclass`'
|
||||
)
|
||||
|
||||
elif _core_utils.is_function_with_inner_schema(choice):
|
||||
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
|
||||
|
||||
elif choice['type'] == 'lax-or-strict':
|
||||
return sorted(
|
||||
set(
|
||||
@@ -373,10 +353,13 @@ class _ApplyInferredDiscriminator:
|
||||
raise MissingDefinitionForUnionRef(schema_ref)
|
||||
return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
|
||||
else:
|
||||
raise TypeError(
|
||||
f'{choice["type"]!r} is not a valid discriminated union variant;'
|
||||
' should be a `BaseModel` or `dataclass`'
|
||||
)
|
||||
err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
|
||||
if choice['type'] == 'list':
|
||||
err_str += (
|
||||
' If you are making use of a list of union types, make sure the discriminator is applied to the '
|
||||
'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
|
||||
)
|
||||
raise TypeError(err_str)
|
||||
|
||||
def _infer_discriminator_values_for_typed_dict_choice(
|
||||
self, choice: core_schema.TypedDictSchema, source_name: str | None = None
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Utilities related to attribute docstring extraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import Any
|
||||
|
||||
|
||||
class DocstringVisitor(ast.NodeVisitor):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.target: str | None = None
|
||||
self.attrs: dict[str, str] = {}
|
||||
self.previous_node_type: type[ast.AST] | None = None
|
||||
|
||||
def visit(self, node: ast.AST) -> Any:
|
||||
node_result = super().visit(node)
|
||||
self.previous_node_type = type(node)
|
||||
return node_result
|
||||
|
||||
def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
|
||||
if isinstance(node.target, ast.Name):
|
||||
self.target = node.target.id
|
||||
|
||||
def visit_Expr(self, node: ast.Expr) -> Any:
|
||||
if (
|
||||
isinstance(node.value, ast.Constant)
|
||||
and isinstance(node.value.value, str)
|
||||
and self.previous_node_type is ast.AnnAssign
|
||||
):
|
||||
docstring = inspect.cleandoc(node.value.value)
|
||||
if self.target:
|
||||
self.attrs[self.target] = docstring
|
||||
self.target = None
|
||||
|
||||
|
||||
def _dedent_source_lines(source: list[str]) -> str:
|
||||
# Required for nested class definitions, e.g. in a function block
|
||||
dedent_source = textwrap.dedent(''.join(source))
|
||||
if dedent_source.startswith((' ', '\t')):
|
||||
# We are in the case where there's a dedented (usually multiline) string
|
||||
# at a lower indentation level than the class itself. We wrap our class
|
||||
# in a function as a workaround.
|
||||
dedent_source = f'def dedent_workaround():\n{dedent_source}'
|
||||
return dedent_source
|
||||
|
||||
|
||||
def _extract_source_from_frame(cls: type[Any]) -> list[str] | None:
|
||||
frame = inspect.currentframe()
|
||||
|
||||
while frame:
|
||||
if inspect.getmodule(frame) is inspect.getmodule(cls):
|
||||
lnum = frame.f_lineno
|
||||
try:
|
||||
lines, _ = inspect.findsource(frame)
|
||||
except OSError: # pragma: no cover
|
||||
# Source can't be retrieved (maybe because running in an interactive terminal),
|
||||
# we don't want to error here.
|
||||
pass
|
||||
else:
|
||||
block_lines = inspect.getblock(lines[lnum - 1 :])
|
||||
dedent_source = _dedent_source_lines(block_lines)
|
||||
try:
|
||||
block_tree = ast.parse(dedent_source)
|
||||
except SyntaxError:
|
||||
pass
|
||||
else:
|
||||
stmt = block_tree.body[0]
|
||||
if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround':
|
||||
# `_dedent_source_lines` wrapped the class around the workaround function
|
||||
stmt = stmt.body[0]
|
||||
if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__:
|
||||
return block_lines
|
||||
|
||||
frame = frame.f_back
|
||||
|
||||
|
||||
def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]:
|
||||
"""Map model attributes and their corresponding docstring.
|
||||
|
||||
Args:
|
||||
cls: The class of the Pydantic model to inspect.
|
||||
use_inspect: Whether to skip usage of frames to find the object and use
|
||||
the `inspect` module instead.
|
||||
|
||||
Returns:
|
||||
A mapping containing attribute names and their corresponding docstring.
|
||||
"""
|
||||
if use_inspect:
|
||||
# Might not work as expected if two classes have the same name in the same source file.
|
||||
try:
|
||||
source, _ = inspect.getsourcelines(cls)
|
||||
except OSError: # pragma: no cover
|
||||
return {}
|
||||
else:
|
||||
source = _extract_source_from_frame(cls)
|
||||
|
||||
if not source:
|
||||
return {}
|
||||
|
||||
dedent_source = _dedent_source_lines(source)
|
||||
|
||||
visitor = DocstringVisitor()
|
||||
visitor.visit(ast.parse(dedent_source))
|
||||
return visitor.attrs
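
The extraction above is what powers the `use_attribute_docstrings` setting; a short usage sketch, assuming a pydantic version where that config option is available:

```python
from pydantic import BaseModel, ConfigDict

class User(BaseModel):
    model_config = ConfigDict(use_attribute_docstrings=True)

    name: str
    """The user's full name."""

print(User.model_fields['name'].description)  # The user's full name.
```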
|
||||
@@ -1,92 +1,104 @@
|
||||
"""Private logic related to fields (the `Field()` function and `FieldInfo` class), and arguments to `Annotated`."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from copy import copy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from functools import cache
|
||||
from inspect import Parameter, ismethoddescriptor, signature
|
||||
from re import Pattern
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar
|
||||
|
||||
from annotated_types import BaseMetadata
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import TypeIs, get_origin
|
||||
from typing_inspection import typing_objects
|
||||
from typing_inspection.introspection import AnnotationSource
|
||||
|
||||
from . import _typing_extra
|
||||
from pydantic import PydanticDeprecatedSince211
|
||||
from pydantic.errors import PydanticUserError
|
||||
|
||||
from . import _generics, _typing_extra
|
||||
from ._config import ConfigWrapper
|
||||
from ._docs_extraction import extract_docstrings_from_cls
|
||||
from ._import_utils import import_cached_base_model, import_cached_field_info
|
||||
from ._namespace_utils import NsResolver
|
||||
from ._repr import Representation
|
||||
from ._typing_extra import get_cls_type_hints_lenient, get_type_hints, is_classvar, is_finalvar
|
||||
from ._utils import can_be_positional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from annotated_types import BaseMetadata
|
||||
|
||||
from ..fields import FieldInfo
|
||||
from ..main import BaseModel
|
||||
from ._dataclasses import StandardDataclass
|
||||
from ._dataclasses import PydanticDataclass, StandardDataclass
|
||||
from ._decorators import DecoratorInfos
|
||||
|
||||
|
||||
def get_type_hints_infer_globalns(
|
||||
obj: Any,
|
||||
localns: dict[str, Any] | None = None,
|
||||
include_extras: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Gets type hints for an object by inferring the global namespace.
|
||||
|
||||
It uses `typing.get_type_hints`; the only thing we do here is fetch the
global namespace from `obj.__module__` if it is not `None`.
|
||||
|
||||
Args:
|
||||
obj: The object to get its type hints.
|
||||
localns: The local namespaces.
|
||||
include_extras: Whether to recursively include annotation metadata.
|
||||
|
||||
Returns:
|
||||
The object type hints.
|
||||
"""
|
||||
module_name = getattr(obj, '__module__', None)
|
||||
globalns: dict[str, Any] | None = None
|
||||
if module_name:
|
||||
try:
|
||||
globalns = sys.modules[module_name].__dict__
|
||||
except KeyError:
|
||||
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
|
||||
pass
|
||||
return get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
|
||||
|
||||
|
||||
class PydanticMetadata(Representation):
|
||||
"""Base class for annotation markers like `Strict`."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class PydanticGeneralMetadata(PydanticMetadata, BaseMetadata):
|
||||
"""Pydantic general metada like `max_digits`."""
|
||||
def pydantic_general_metadata(**metadata: Any) -> BaseMetadata:
|
||||
"""Create a new `_PydanticGeneralMetadata` class with the given metadata.
|
||||
|
||||
def __init__(self, **metadata: Any):
|
||||
self.__dict__ = metadata
|
||||
Args:
|
||||
**metadata: The metadata to add.
|
||||
|
||||
Returns:
|
||||
The new `_PydanticGeneralMetadata` class.
|
||||
"""
|
||||
return _general_metadata_cls()(metadata) # type: ignore
|
||||
|
||||
|
||||
@cache
|
||||
def _general_metadata_cls() -> type[BaseMetadata]:
|
||||
"""Do it this way to avoid importing `annotated_types` at import time."""
|
||||
from annotated_types import BaseMetadata
|
||||
|
||||
class _PydanticGeneralMetadata(PydanticMetadata, BaseMetadata):
|
||||
"""Pydantic general metadata like `max_digits`."""
|
||||
|
||||
def __init__(self, metadata: Any):
|
||||
self.__dict__ = metadata
|
||||
|
||||
return _PydanticGeneralMetadata # type: ignore
|
||||
|
||||
|
||||
def _update_fields_from_docstrings(cls: type[Any], fields: dict[str, FieldInfo], use_inspect: bool = False) -> None:
|
||||
fields_docs = extract_docstrings_from_cls(cls, use_inspect=use_inspect)
|
||||
for ann_name, field_info in fields.items():
|
||||
if field_info.description is None and ann_name in fields_docs:
|
||||
field_info.description = fields_docs[ann_name]
|
||||
|
||||
|
||||
def collect_model_fields( # noqa: C901
|
||||
cls: type[BaseModel],
|
||||
bases: tuple[type[Any], ...],
|
||||
config_wrapper: ConfigWrapper,
|
||||
types_namespace: dict[str, Any] | None,
|
||||
ns_resolver: NsResolver | None,
|
||||
*,
|
||||
typevars_map: dict[Any, Any] | None = None,
|
||||
typevars_map: Mapping[TypeVar, Any] | None = None,
|
||||
) -> tuple[dict[str, FieldInfo], set[str]]:
|
||||
"""Collect the fields of a nascent pydantic model.
|
||||
"""Collect the fields and class variables names of a nascent Pydantic model.
|
||||
|
||||
Also collect the names of any ClassVars present in the type hints.
|
||||
The fields collection process is *lenient*, meaning it won't error if string annotations
|
||||
fail to evaluate. If this happens, the original annotation (and assigned value, if any)
|
||||
is stored on the created `FieldInfo` instance.
|
||||
|
||||
The returned value is a tuple of two items: the fields dict, and the set of ClassVar names.
|
||||
The `rebuild_model_fields()` function should be called at a later point (e.g. when rebuilding the model),
|
||||
and will make use of these stored attributes.
|
||||
|
||||
Args:
|
||||
cls: BaseModel or dataclass.
|
||||
bases: Parents of the class, generally `cls.__bases__`.
|
||||
config_wrapper: The config wrapper instance.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
ns_resolver: Namespace resolver to use when getting model annotations.
|
||||
typevars_map: A dictionary mapping type variables to their concrete types.
|
||||
|
||||
Returns:
|
||||
A tuple contains fields and class variables.
|
||||
A two-tuple containing model fields and class variable names.
|
||||
|
||||
Raises:
|
||||
NameError:
|
||||
@@ -94,9 +106,16 @@ def collect_model_fields( # noqa: C901
|
||||
- If there is a field other than `root` in `RootModel`.
|
||||
- If a field shadows an attribute in the parent model.
|
||||
"""
|
||||
from ..fields import FieldInfo
|
||||
BaseModel = import_cached_base_model()
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
|
||||
type_hints = get_cls_type_hints_lenient(cls, types_namespace)
|
||||
bases = cls.__bases__
|
||||
parent_fields_lookup: dict[str, FieldInfo] = {}
|
||||
for base in reversed(bases):
|
||||
if model_fields := getattr(base, '__pydantic_fields__', None):
|
||||
parent_fields_lookup.update(model_fields)
|
||||
|
||||
type_hints = _typing_extra.get_model_type_hints(cls, ns_resolver=ns_resolver)
|
||||
|
||||
# https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
|
||||
# annotations is only used for finding fields in parent classes
|
||||
@@ -104,39 +123,50 @@ def collect_model_fields( # noqa: C901
|
||||
fields: dict[str, FieldInfo] = {}
|
||||
|
||||
class_vars: set[str] = set()
|
||||
for ann_name, ann_type in type_hints.items():
|
||||
for ann_name, (ann_type, evaluated) in type_hints.items():
|
||||
if ann_name == 'model_config':
|
||||
# We never want to treat `model_config` as a field
|
||||
# Note: we may need to change this logic if/when we introduce a `BareModel` class with no
|
||||
# protected namespaces (where `model_config` might be allowed as a field name)
|
||||
continue
|
||||
|
||||
for protected_namespace in config_wrapper.protected_namespaces:
|
||||
if ann_name.startswith(protected_namespace):
|
||||
ns_violation: bool = False
|
||||
if isinstance(protected_namespace, Pattern):
|
||||
ns_violation = protected_namespace.match(ann_name) is not None
|
||||
elif isinstance(protected_namespace, str):
|
||||
ns_violation = ann_name.startswith(protected_namespace)
|
||||
|
||||
if ns_violation:
|
||||
for b in bases:
|
||||
if hasattr(b, ann_name):
|
||||
from ..main import BaseModel
|
||||
|
||||
if not (issubclass(b, BaseModel) and ann_name in b.model_fields):
|
||||
if not (issubclass(b, BaseModel) and ann_name in getattr(b, '__pydantic_fields__', {})):
|
||||
raise NameError(
|
||||
f'Field "{ann_name}" conflicts with member {getattr(b, ann_name)}'
|
||||
f' of protected namespace "{protected_namespace}".'
|
||||
)
|
||||
else:
|
||||
valid_namespaces = tuple(
|
||||
x for x in config_wrapper.protected_namespaces if not ann_name.startswith(x)
|
||||
)
|
||||
valid_namespaces = ()
|
||||
for pn in config_wrapper.protected_namespaces:
|
||||
if isinstance(pn, Pattern):
|
||||
if not pn.match(ann_name):
|
||||
valid_namespaces += (f're.compile({pn.pattern})',)
|
||||
else:
|
||||
if not ann_name.startswith(pn):
|
||||
valid_namespaces += (pn,)
|
||||
|
||||
warnings.warn(
|
||||
f'Field "{ann_name}" has conflict with protected namespace "{protected_namespace}".'
|
||||
f'Field "{ann_name}" in {cls.__name__} has conflict with protected namespace "{protected_namespace}".'
|
||||
'\n\nYou may be able to resolve this warning by setting'
|
||||
f" `model_config['protected_namespaces'] = {valid_namespaces}`.",
|
||||
UserWarning,
|
||||
)
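
An illustration of the resolution suggested in the warning above: narrow (or clear) the protected namespaces so the field name no longer matches one of them.

```python
from pydantic import BaseModel, ConfigDict

class Job(BaseModel):
    model_config = ConfigDict(protected_namespaces=())  # no warning for `model_status`

    model_status: str
```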
|
||||
if is_classvar(ann_type):
|
||||
class_vars.add(ann_name)
|
||||
continue
|
||||
if _is_finalvar_with_default_val(ann_type, getattr(cls, ann_name, PydanticUndefined)):
|
||||
if _typing_extra.is_classvar_annotation(ann_type):
|
||||
class_vars.add(ann_name)
|
||||
continue
|
||||
|
||||
assigned_value = getattr(cls, ann_name, PydanticUndefined)
|
||||
|
||||
if not is_valid_field_name(ann_name):
|
||||
continue
|
||||
if cls.__pydantic_root_model__ and ann_name != 'root':
|
||||
@@ -145,7 +175,7 @@ def collect_model_fields( # noqa: C901
|
||||
)
|
||||
|
||||
# when building a generic model with `MyModel[int]`, the generic_origin check makes sure we don't get
|
||||
# "... shadows an attribute" errors
|
||||
# "... shadows an attribute" warnings
|
||||
generic_origin = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin')
|
||||
for base in bases:
|
||||
dataclass_fields = {
|
||||
@@ -153,42 +183,77 @@ def collect_model_fields( # noqa: C901
|
||||
}
|
||||
if hasattr(base, ann_name):
|
||||
if base is generic_origin:
|
||||
# Don't error when "shadowing" of attributes in parametrized generics
|
||||
# Don't warn when "shadowing" of attributes in parametrized generics
|
||||
continue
|
||||
|
||||
if ann_name in dataclass_fields:
|
||||
# Don't error when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
|
||||
# Don't warn when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
|
||||
# on the class instance.
|
||||
continue
|
||||
|
||||
if ann_name not in annotations:
|
||||
# Don't warn when a field exists in a parent class but has not been defined in the current class
|
||||
continue
|
||||
|
||||
warnings.warn(
|
||||
f'Field name "{ann_name}" shadows an attribute in parent "{base.__qualname__}"; ',
|
||||
f'Field name "{ann_name}" in "{cls.__qualname__}" shadows an attribute in parent '
|
||||
f'"{base.__qualname__}"',
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
try:
|
||||
default = getattr(cls, ann_name, PydanticUndefined)
|
||||
if default is PydanticUndefined:
|
||||
raise AttributeError
|
||||
except AttributeError:
|
||||
if ann_name in annotations:
|
||||
field_info = FieldInfo.from_annotation(ann_type)
|
||||
if assigned_value is PydanticUndefined: # no assignment, just a plain annotation
|
||||
if ann_name in annotations or ann_name not in parent_fields_lookup:
|
||||
# field is either:
|
||||
# - present in the current model's annotations (and *not* from parent classes)
|
||||
# - not found on any base classes; this seems to be caused by fields not getting
|
||||
# generated due to models not being fully defined while initializing recursive models.
|
||||
# Nothing stops us from just creating a `FieldInfo` for this type hint, so we do this.
|
||||
field_info = FieldInfo_.from_annotation(ann_type, _source=AnnotationSource.CLASS)
|
||||
if not evaluated:
|
||||
field_info._complete = False
|
||||
# Store the original annotation that should be used to rebuild
|
||||
# the field info later:
|
||||
field_info._original_annotation = ann_type
|
||||
else:
|
||||
# if field has no default value and is not in __annotations__ this means that it is
|
||||
# defined in a base class and we can take it from there
|
||||
model_fields_lookup: dict[str, FieldInfo] = {}
|
||||
for x in cls.__bases__[::-1]:
|
||||
model_fields_lookup.update(getattr(x, 'model_fields', {}))
|
||||
if ann_name in model_fields_lookup:
|
||||
# The field was present on one of the (possibly multiple) base classes
|
||||
# copy the field to make sure typevar substitutions don't cause issues with the base classes
|
||||
field_info = copy(model_fields_lookup[ann_name])
|
||||
else:
|
||||
# The field was not found on any base classes; this seems to be caused by fields not getting
|
||||
# generated thanks to models not being fully defined while initializing recursive models.
|
||||
# Nothing stops us from just creating a new FieldInfo for this type hint, so we do this.
|
||||
field_info = FieldInfo.from_annotation(ann_type)
|
||||
else:
|
||||
field_info = FieldInfo.from_annotated_attribute(ann_type, default)
|
||||
# The field was present on one of the (possibly multiple) base classes
|
||||
# copy the field to make sure typevar substitutions don't cause issues with the base classes
|
||||
field_info = copy(parent_fields_lookup[ann_name])
|
||||
|
||||
else: # An assigned value is present (either the default value, or a `Field()` function)
|
||||
_warn_on_nested_alias_in_annotation(ann_type, ann_name)
|
||||
if isinstance(assigned_value, FieldInfo_) and ismethoddescriptor(assigned_value.default):
|
||||
# `assigned_value` was fetched using `getattr`, which triggers a call to `__get__`
|
||||
# for descriptors, so we do the same if the `= field(default=...)` form is used.
|
||||
# Note that we only do this for method descriptors for now, we might want to
|
||||
# extend this to any descriptor in the future (by simply checking for
|
||||
# `hasattr(assigned_value.default, '__get__')`).
|
||||
assigned_value.default = assigned_value.default.__get__(None, cls)
|
||||
|
||||
# The `from_annotated_attribute()` call below mutates the assigned `Field()`, so make a copy:
|
||||
original_assignment = (
|
||||
assigned_value._copy() if not evaluated and isinstance(assigned_value, FieldInfo_) else assigned_value
|
||||
)
|
||||
|
||||
field_info = FieldInfo_.from_annotated_attribute(ann_type, assigned_value, _source=AnnotationSource.CLASS)
|
||||
# Store the original annotation and assignment value that should be used to rebuild the field info later.
|
||||
# Note that the assignment is always stored as the annotation might contain a type var that is later
|
||||
# parameterized with an unknown forward reference (and we'll need it to rebuild the field info):
|
||||
field_info._original_assignment = original_assignment
|
||||
if not evaluated:
|
||||
field_info._complete = False
|
||||
field_info._original_annotation = ann_type
|
||||
elif 'final' in field_info._qualifiers and not field_info.is_required():
|
||||
warnings.warn(
|
||||
f'Annotation {ann_name!r} is marked as final and has a default value. Pydantic treats {ann_name!r} as a '
|
||||
'class variable, but it will be considered as a normal field in V3 to be aligned with dataclasses. If you '
|
||||
f'still want {ann_name!r} to be considered as a class variable, annotate it as: `ClassVar[<type>] = <default>.`',
|
||||
category=PydanticDeprecatedSince211,
|
||||
# Incorrect when `create_model` is used, but the chance that final with a default is used is low in that case:
|
||||
stacklevel=4,
|
||||
)
|
||||
class_vars.add(ann_name)
|
||||
continue
|
||||
|
||||
# attributes which are fields are removed from the class namespace:
|
||||
# 1. To match the behaviour of annotation-only fields
|
||||
# 2. To avoid false positives in the NameError check above
|
||||
@@ -201,81 +266,250 @@ def collect_model_fields( # noqa: C901
|
||||
# to make sure the decorators have already been built for this exact class
|
||||
decorators: DecoratorInfos = cls.__dict__['__pydantic_decorators__']
|
||||
if ann_name in decorators.computed_fields:
|
||||
raise ValueError("you can't override a field with a computed field")
|
||||
raise TypeError(
|
||||
f'Field {ann_name!r} of class {cls.__name__!r} overrides symbol of same name in a parent class. '
|
||||
'This override with a computed_field is incompatible.'
|
||||
)
|
||||
fields[ann_name] = field_info
|
||||
|
||||
if typevars_map:
|
||||
for field in fields.values():
|
||||
field.apply_typevars_map(typevars_map, types_namespace)
|
||||
if field._complete:
|
||||
field.apply_typevars_map(typevars_map)
|
||||
|
||||
if config_wrapper.use_attribute_docstrings:
|
||||
_update_fields_from_docstrings(cls, fields)
|
||||
return fields, class_vars
|
||||
|
||||
|
||||
def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool:
|
||||
from ..fields import FieldInfo
|
||||
def _warn_on_nested_alias_in_annotation(ann_type: type[Any], ann_name: str) -> None:
|
||||
FieldInfo = import_cached_field_info()
|
||||
|
||||
if not is_finalvar(type_):
|
||||
return False
|
||||
elif val is PydanticUndefined:
|
||||
return False
|
||||
elif isinstance(val, FieldInfo) and (val.default is PydanticUndefined and val.default_factory is None):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
args = getattr(ann_type, '__args__', None)
|
||||
if args:
|
||||
for anno_arg in args:
|
||||
if typing_objects.is_annotated(get_origin(anno_arg)):
|
||||
for anno_type_arg in _typing_extra.get_args(anno_arg):
|
||||
if isinstance(anno_type_arg, FieldInfo) and anno_type_arg.alias is not None:
|
||||
warnings.warn(
|
||||
f'`alias` specification on field "{ann_name}" must be set on outermost annotation to take effect.',
|
||||
UserWarning,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def rebuild_model_fields(
|
||||
cls: type[BaseModel],
|
||||
*,
|
||||
ns_resolver: NsResolver,
|
||||
typevars_map: Mapping[TypeVar, Any],
|
||||
) -> dict[str, FieldInfo]:
|
||||
"""Rebuild the (already present) model fields by trying to reevaluate annotations.
|
||||
|
||||
This function should be called whenever a model with incomplete fields is encountered.
|
||||
|
||||
Raises:
|
||||
NameError: If one of the annotations failed to evaluate.
|
||||
|
||||
Note:
|
||||
This function *doesn't* mutate the model fields in place, as it can be called during
|
||||
schema generation, where you don't want to mutate other model's fields.
|
||||
"""
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
|
||||
rebuilt_fields: dict[str, FieldInfo] = {}
|
||||
with ns_resolver.push(cls):
|
||||
for f_name, field_info in cls.__pydantic_fields__.items():
|
||||
if field_info._complete:
|
||||
rebuilt_fields[f_name] = field_info
|
||||
else:
|
||||
existing_desc = field_info.description
|
||||
ann = _typing_extra.eval_type(
|
||||
field_info._original_annotation,
|
||||
*ns_resolver.types_namespace,
|
||||
)
|
||||
ann = _generics.replace_types(ann, typevars_map)
|
||||
|
||||
if (assign := field_info._original_assignment) is PydanticUndefined:
|
||||
new_field = FieldInfo_.from_annotation(ann, _source=AnnotationSource.CLASS)
|
||||
else:
|
||||
new_field = FieldInfo_.from_annotated_attribute(ann, assign, _source=AnnotationSource.CLASS)
|
||||
# The description might come from the docstring if `use_attribute_docstrings` was `True`:
|
||||
new_field.description = new_field.description if new_field.description is not None else existing_desc
|
||||
rebuilt_fields[f_name] = new_field
|
||||
|
||||
return rebuilt_fields
|
||||
|
||||
|
||||
def collect_dataclass_fields(
|
||||
cls: type[StandardDataclass], types_namespace: dict[str, Any] | None, *, typevars_map: dict[Any, Any] | None = None
|
||||
cls: type[StandardDataclass],
|
||||
*,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
typevars_map: dict[Any, Any] | None = None,
|
||||
config_wrapper: ConfigWrapper | None = None,
|
||||
) -> dict[str, FieldInfo]:
|
||||
"""Collect the fields of a dataclass.
|
||||
|
||||
Args:
|
||||
cls: dataclass.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
ns_resolver: Namespace resolver to use when getting dataclass annotations.
|
||||
Defaults to an empty instance.
|
||||
typevars_map: A dictionary mapping type variables to their concrete types.
|
||||
config_wrapper: The config wrapper instance.
|
||||
|
||||
Returns:
|
||||
The dataclass fields.
|
||||
"""
|
||||
from ..fields import FieldInfo
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
|
||||
fields: dict[str, FieldInfo] = {}
|
||||
dataclass_fields: dict[str, dataclasses.Field] = cls.__dataclass_fields__
|
||||
cls_localns = dict(vars(cls)) # this matches get_cls_type_hints_lenient, but all tests pass with `= None` instead
|
||||
ns_resolver = ns_resolver or NsResolver()
|
||||
dataclass_fields = cls.__dataclass_fields__
|
||||
|
||||
for ann_name, dataclass_field in dataclass_fields.items():
|
||||
ann_type = _typing_extra.eval_type_lenient(dataclass_field.type, types_namespace, cls_localns)
|
||||
if is_classvar(ann_type):
|
||||
# The logic here is similar to `_typing_extra.get_cls_type_hints`,
|
||||
# although we do it manually as stdlib dataclasses already have annotations
|
||||
# collected in each class:
|
||||
for base in reversed(cls.__mro__):
|
||||
if not dataclasses.is_dataclass(base):
|
||||
continue
|
||||
|
||||
if not dataclass_field.init and dataclass_field.default_factory == dataclasses.MISSING:
|
||||
# TODO: We should probably do something with this so that validate_assignment behaves properly
|
||||
# Issue: https://github.com/pydantic/pydantic/issues/5470
|
||||
continue
|
||||
with ns_resolver.push(base):
|
||||
for ann_name, dataclass_field in dataclass_fields.items():
|
||||
if ann_name not in base.__dict__.get('__annotations__', {}):
|
||||
# `__dataclass_fields__` contains every field, even the ones from base classes.
|
||||
# Only collect the ones defined on `base`.
|
||||
continue
|
||||
|
||||
if isinstance(dataclass_field.default, FieldInfo):
|
||||
if dataclass_field.default.init_var:
|
||||
# TODO: same note as above
|
||||
continue
|
||||
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field.default)
|
||||
else:
|
||||
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field)
|
||||
fields[ann_name] = field_info
|
||||
globalns, localns = ns_resolver.types_namespace
|
||||
ann_type, evaluated = _typing_extra.try_eval_type(dataclass_field.type, globalns, localns)
|
||||
|
||||
if field_info.default is not PydanticUndefined and isinstance(getattr(cls, ann_name, field_info), FieldInfo):
|
||||
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
|
||||
setattr(cls, ann_name, field_info.default)
|
||||
if _typing_extra.is_classvar_annotation(ann_type):
|
||||
continue
|
||||
|
||||
if (
|
||||
not dataclass_field.init
|
||||
and dataclass_field.default is dataclasses.MISSING
|
||||
and dataclass_field.default_factory is dataclasses.MISSING
|
||||
):
|
||||
# TODO: We should probably do something with this so that validate_assignment behaves properly
|
||||
# Issue: https://github.com/pydantic/pydantic/issues/5470
|
||||
continue
|
||||
|
||||
if isinstance(dataclass_field.default, FieldInfo_):
|
||||
if dataclass_field.default.init_var:
|
||||
if dataclass_field.default.init is False:
|
||||
raise PydanticUserError(
|
||||
f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
|
||||
code='clashing-init-and-init-var',
|
||||
)
|
||||
|
||||
# TODO: same note as above re validate_assignment
|
||||
continue
|
||||
field_info = FieldInfo_.from_annotated_attribute(
|
||||
ann_type, dataclass_field.default, _source=AnnotationSource.DATACLASS
|
||||
)
|
||||
field_info._original_assignment = dataclass_field.default
|
||||
else:
|
||||
field_info = FieldInfo_.from_annotated_attribute(
|
||||
ann_type, dataclass_field, _source=AnnotationSource.DATACLASS
|
||||
)
|
||||
field_info._original_assignment = dataclass_field
|
||||
|
||||
if not evaluated:
|
||||
field_info._complete = False
|
||||
field_info._original_annotation = ann_type
|
||||
|
||||
fields[ann_name] = field_info
|
||||
|
||||
if field_info.default is not PydanticUndefined and isinstance(
|
||||
getattr(cls, ann_name, field_info), FieldInfo_
|
||||
):
|
||||
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
|
||||
setattr(cls, ann_name, field_info.default)
|
||||
|
||||
if typevars_map:
|
||||
for field in fields.values():
|
||||
field.apply_typevars_map(typevars_map, types_namespace)
|
||||
# We don't pass any ns, as `field.annotation`
|
||||
# was already evaluated. TODO: is this method relevant?
|
||||
# Can't we just use `_generics.replace_types`?
|
||||
field.apply_typevars_map(typevars_map)
|
||||
|
||||
if config_wrapper is not None and config_wrapper.use_attribute_docstrings:
|
||||
_update_fields_from_docstrings(
|
||||
cls,
|
||||
fields,
|
||||
# We can't rely on the (more reliable) frame inspection method
|
||||
# for stdlib dataclasses:
|
||||
use_inspect=not hasattr(cls, '__is_pydantic_dataclass__'),
|
||||
)
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def rebuild_dataclass_fields(
|
||||
cls: type[PydanticDataclass],
|
||||
*,
|
||||
config_wrapper: ConfigWrapper,
|
||||
ns_resolver: NsResolver,
|
||||
typevars_map: Mapping[TypeVar, Any],
|
||||
) -> dict[str, FieldInfo]:
|
||||
"""Rebuild the (already present) dataclass fields by trying to reevaluate annotations.
|
||||
|
||||
This function should be called whenever a dataclass with incomplete fields is encountered.
|
||||
|
||||
Raises:
|
||||
NameError: If one of the annotations failed to evaluate.
|
||||
|
||||
Note:
|
||||
This function *doesn't* mutate the dataclass fields in place, as it can be called during
|
||||
schema generation, where you don't want to mutate other dataclass's fields.
|
||||
"""
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
|
||||
rebuilt_fields: dict[str, FieldInfo] = {}
|
||||
with ns_resolver.push(cls):
|
||||
for f_name, field_info in cls.__pydantic_fields__.items():
|
||||
if field_info._complete:
|
||||
rebuilt_fields[f_name] = field_info
|
||||
else:
|
||||
existing_desc = field_info.description
|
||||
ann = _typing_extra.eval_type(
|
||||
field_info._original_annotation,
|
||||
*ns_resolver.types_namespace,
|
||||
)
|
||||
ann = _generics.replace_types(ann, typevars_map)
|
||||
new_field = FieldInfo_.from_annotated_attribute(
|
||||
ann,
|
||||
field_info._original_assignment,
|
||||
_source=AnnotationSource.DATACLASS,
|
||||
)
|
||||
|
||||
# The description might come from the docstring if `use_attribute_docstrings` was `True`:
|
||||
new_field.description = new_field.description if new_field.description is not None else existing_desc
|
||||
rebuilt_fields[f_name] = new_field
|
||||
|
||||
return rebuilt_fields
|
||||
|
||||
|
||||
def is_valid_field_name(name: str) -> bool:
|
||||
return not name.startswith('_')
|
||||
|
||||
|
||||
def is_valid_privateattr_name(name: str) -> bool:
|
||||
return name.startswith('_') and not name.startswith('__')
|
||||
|
||||
|
||||
def takes_validated_data_argument(
|
||||
default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any],
|
||||
) -> TypeIs[Callable[[dict[str, Any]], Any]]:
|
||||
"""Whether the provided default factory callable has a validated data parameter."""
|
||||
try:
|
||||
sig = signature(default_factory)
|
||||
except (ValueError, TypeError):
|
||||
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
|
||||
# In this case, we assume no data argument is present:
|
||||
return False
|
||||
|
||||
parameters = list(sig.parameters.values())
|
||||
|
||||
return len(parameters) == 1 and can_be_positional(parameters[0]) and parameters[0].default is Parameter.empty
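The check above boils down to signature inspection: a default factory qualifies for the validated-data form only when it takes exactly one required positional parameter. A minimal standalone sketch of that rule (the helper name and the example calls are illustrative, not part of this diff):

```python
from inspect import Parameter, signature
from typing import Any, Callable


def _accepts_validated_data(factory: Callable[..., Any]) -> bool:
    """Return True if `factory` takes exactly one required positional parameter (sketch)."""
    try:
        params = list(signature(factory).parameters.values())
    except (ValueError, TypeError):
        # Some callables (e.g. builtins implemented in C) expose no signature;
        # in that case assume the factory takes no validated-data argument.
        return False
    return (
        len(params) == 1
        and params[0].kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
        and params[0].default is Parameter.empty
    )


print(_accepts_validated_data(lambda: []))                 # False: takes no arguments
print(_accepts_validated_data(lambda data: data['id']))    # True: one required positional argument
```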
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -14,3 +15,9 @@ class PydanticRecursiveRef:
|
||||
"""Defining __call__ is necessary for the `typing` module to let you use an instance of
|
||||
this class as the result of resolving a standard ForwardRef.
|
||||
"""
|
||||
|
||||
def __or__(self, other):
|
||||
return Union[self, other] # type: ignore
|
||||
|
||||
def __ror__(self, other):
|
||||
return Union[other, self] # type: ignore
|
||||
|
||||
File diff suppressed because it is too large
@@ -4,17 +4,21 @@ import sys
|
||||
import types
|
||||
import typing
|
||||
from collections import ChainMap
|
||||
from collections.abc import Iterator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from itertools import zip_longest
|
||||
from types import prepare_class
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Mapping, MutableMapping, Tuple, TypeVar
|
||||
from typing import TYPE_CHECKING, Annotated, Any, TypeVar
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import typing_extensions
|
||||
from typing_inspection import typing_objects
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from . import _typing_extra
|
||||
from ._core_utils import get_type_ref
|
||||
from ._forward_ref import PydanticRecursiveRef
|
||||
from ._typing_extra import TypeVarType, typing_base
|
||||
from ._utils import all_identical, is_model_class
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
@@ -23,7 +27,7 @@ if sys.version_info >= (3, 10):
|
||||
if TYPE_CHECKING:
|
||||
from ..main import BaseModel
|
||||
|
||||
GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]]
|
||||
GenericTypesCacheKey = tuple[Any, Any, tuple[Any, ...]]
|
||||
|
||||
# Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching.
|
||||
# Right now, to handle recursive generics, some types must remain cached for brief periods without references.
|
||||
@@ -34,43 +38,25 @@ GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]]
|
||||
KT = TypeVar('KT')
|
||||
VT = TypeVar('VT')
|
||||
_LIMITED_DICT_SIZE = 100
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class LimitedDict(dict, MutableMapping[KT, VT]):
|
||||
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
|
||||
...
|
||||
|
||||
else:
|
||||
class LimitedDict(dict[KT, VT]):
|
||||
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE) -> None:
|
||||
self.size_limit = size_limit
|
||||
super().__init__()
|
||||
|
||||
class LimitedDict(dict):
|
||||
"""Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage.
|
||||
|
||||
Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache.
|
||||
"""
|
||||
|
||||
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
|
||||
self.size_limit = size_limit
|
||||
super().__init__()
|
||||
|
||||
def __setitem__(self, __key: Any, __value: Any) -> None:
|
||||
super().__setitem__(__key, __value)
|
||||
if len(self) > self.size_limit:
|
||||
excess = len(self) - self.size_limit + self.size_limit // 10
|
||||
to_remove = list(self.keys())[:excess]
|
||||
for key in to_remove:
|
||||
del self[key]
|
||||
|
||||
def __class_getitem__(cls, *args: Any) -> Any:
|
||||
# to avoid errors with 3.7
|
||||
return cls
|
||||
def __setitem__(self, key: KT, value: VT, /) -> None:
|
||||
super().__setitem__(key, value)
|
||||
if len(self) > self.size_limit:
|
||||
excess = len(self) - self.size_limit + self.size_limit // 10
|
||||
to_remove = list(self.keys())[:excess]
|
||||
for k in to_remove:
|
||||
del self[k]
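The `__setitem__` above keeps the dict bounded by evicting the oldest ~10% of keys (plus the overflow) once the size limit is exceeded. A small self-contained sketch of that FIFO eviction, assuming the same semantics as described in the docstring:

```python
# Minimal sketch of the FIFO-eviction behaviour (assumed semantics, not the class from this diff).
class FIFOCache(dict):
    def __init__(self, size_limit: int = 100) -> None:
        self.size_limit = size_limit
        super().__init__()

    def __setitem__(self, key, value) -> None:
        super().__setitem__(key, value)
        if len(self) > self.size_limit:
            # Drop the oldest entries: the overflow plus ~10% of the limit.
            excess = len(self) - self.size_limit + self.size_limit // 10
            for k in list(self.keys())[:excess]:
                del self[k]


cache = FIFOCache(size_limit=5)
for i in range(7):
    cache[i] = str(i)
print(list(cache))  # the oldest keys were evicted first
```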
|
||||
|
||||
|
||||
# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
|
||||
# once they are no longer referenced by the caller.
|
||||
if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9
|
||||
GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
|
||||
else:
|
||||
GenericTypesCache = WeakValueDictionary
|
||||
GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -108,13 +94,13 @@ else:
|
||||
# and discover later on that we need to re-add all this infrastructure...
|
||||
# _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict())
|
||||
|
||||
_GENERIC_TYPES_CACHE = GenericTypesCache()
|
||||
_GENERIC_TYPES_CACHE: ContextVar[GenericTypesCache | None] = ContextVar('_GENERIC_TYPES_CACHE', default=None)
|
||||
|
||||
|
||||
class PydanticGenericMetadata(typing_extensions.TypedDict):
|
||||
origin: type[BaseModel] | None # analogous to typing._GenericAlias.__origin__
|
||||
args: tuple[Any, ...] # analogous to typing._GenericAlias.__args__
|
||||
parameters: tuple[type[Any], ...] # analogous to typing.Generic.__parameters__
|
||||
parameters: tuple[TypeVar, ...] # analogous to typing.Generic.__parameters__
|
||||
|
||||
|
||||
def create_generic_submodel(
|
||||
@@ -171,7 +157,7 @@ def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
|
||||
depth: The depth to get the frame.
|
||||
|
||||
Returns:
|
||||
A tuple contains `module_nam` and `called_globally`.
|
||||
A tuple contains `module_name` and `called_globally`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the function is not called inside a function.
|
||||
@@ -189,7 +175,7 @@ def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
|
||||
DictValues: type[Any] = {}.values().__class__
|
||||
|
||||
|
||||
def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
|
||||
def iter_contained_typevars(v: Any) -> Iterator[TypeVar]:
|
||||
"""Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.
|
||||
|
||||
This is inspired as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias,
|
||||
@@ -222,7 +208,7 @@ def get_origin(v: Any) -> Any:
|
||||
return typing_extensions.get_origin(v)
|
||||
|
||||
|
||||
def get_standard_typevars_map(cls: type[Any]) -> dict[TypeVarType, Any] | None:
|
||||
def get_standard_typevars_map(cls: Any) -> dict[TypeVar, Any] | None:
|
||||
"""Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the
|
||||
`replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias.
|
||||
"""
|
||||
@@ -235,11 +221,11 @@ def get_standard_typevars_map(cls: type[Any]) -> dict[TypeVarType, Any] | None:
|
||||
# In this case, we know that cls is a _GenericAlias, and origin is the generic type
|
||||
# So it is safe to access cls.__args__ and origin.__parameters__
|
||||
args: tuple[Any, ...] = cls.__args__ # type: ignore
|
||||
parameters: tuple[TypeVarType, ...] = origin.__parameters__
|
||||
parameters: tuple[TypeVar, ...] = origin.__parameters__
|
||||
return dict(zip(parameters, args))
|
||||
|
||||
|
||||
def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVarType, Any] | None:
|
||||
def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVar, Any]:
|
||||
"""Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible
|
||||
with the `replace_types` function.
|
||||
|
||||
@@ -251,10 +237,13 @@ def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVarType, Any] | Non
|
||||
generic_metadata = cls.__pydantic_generic_metadata__
|
||||
origin = generic_metadata['origin']
|
||||
args = generic_metadata['args']
|
||||
if not args:
|
||||
# No need to go into `iter_contained_typevars`:
|
||||
return {}
|
||||
return dict(zip(iter_contained_typevars(origin), args))
|
||||
|
||||
|
||||
def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
|
||||
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
|
||||
|
||||
Args:
|
||||
@@ -266,13 +255,13 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
`typevar_map` keys recursively replaced.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from typing import List, Tuple, Union
|
||||
```python
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic._internal._generics import replace_types
|
||||
|
||||
replace_types(Tuple[str, Union[List[str], float]], {str: int})
|
||||
#> Tuple[int, Union[List[int], float]]
|
||||
replace_types(tuple[str, Union[List[str], float]], {str: int})
|
||||
#> tuple[int, Union[List[int], float]]
|
||||
```
|
||||
"""
|
||||
if not type_map:
|
||||
@@ -281,25 +270,25 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
type_args = get_args(type_)
|
||||
origin_type = get_origin(type_)
|
||||
|
||||
if origin_type is typing_extensions.Annotated:
|
||||
if typing_objects.is_annotated(origin_type):
|
||||
annotated_type, *annotations = type_args
|
||||
annotated = replace_types(annotated_type, type_map)
|
||||
for annotation in annotations:
|
||||
annotated = typing_extensions.Annotated[annotated, annotation]
|
||||
return annotated
|
||||
annotated_type = replace_types(annotated_type, type_map)
|
||||
# TODO remove parentheses when we drop support for Python 3.10:
|
||||
return Annotated[(annotated_type, *annotations)]
|
||||
|
||||
# Having type args is a good indicator that this is a typing module
|
||||
# class instantiation or a generic alias of some sort.
|
||||
# Having type args is a good indicator that this is a typing special form
|
||||
# instance or a generic alias of some sort.
|
||||
if type_args:
|
||||
resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
|
||||
if all_identical(type_args, resolved_type_args):
|
||||
# If all arguments are the same, there is no need to modify the
|
||||
# type or create a new object at all
|
||||
return type_
|
||||
|
||||
if (
|
||||
origin_type is not None
|
||||
and isinstance(type_, typing_base)
|
||||
and not isinstance(origin_type, typing_base)
|
||||
and isinstance(type_, _typing_extra.typing_base)
|
||||
and not isinstance(origin_type, _typing_extra.typing_base)
|
||||
and getattr(type_, '_name', None) is not None
|
||||
):
|
||||
# In python < 3.9 generic aliases don't exist so any of these like `list`,
|
||||
@@ -307,11 +296,24 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
# See: https://www.python.org/dev/peps/pep-0585
|
||||
origin_type = getattr(typing, type_._name)
|
||||
assert origin_type is not None
|
||||
|
||||
if is_union_origin(origin_type):
|
||||
if any(typing_objects.is_any(arg) for arg in resolved_type_args):
|
||||
# `Any | T` ~ `Any`:
|
||||
resolved_type_args = (Any,)
|
||||
# `Never | T` ~ `T`:
|
||||
resolved_type_args = tuple(
|
||||
arg
|
||||
for arg in resolved_type_args
|
||||
if not (typing_objects.is_noreturn(arg) or typing_objects.is_never(arg))
|
||||
)
|
||||
|
||||
# PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
|
||||
# We also cannot use isinstance() since we have to compare types.
|
||||
if sys.version_info >= (3, 10) and origin_type is types.UnionType:
|
||||
return _UnionGenericAlias(origin_type, resolved_type_args)
|
||||
return origin_type[resolved_type_args]
|
||||
# NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below
|
||||
return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args]
|
||||
|
||||
# We handle pydantic generic models separately as they don't have the same
|
||||
# semantics as "typing" classes or generic aliases
|
||||
@@ -327,8 +329,8 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
|
||||
# Handle special case for typehints that can have lists as arguments.
|
||||
# `typing.Callable[[int, str], int]` is an example for this.
|
||||
if isinstance(type_, (List, list)):
|
||||
resolved_list = list(replace_types(element, type_map) for element in type_)
|
||||
if isinstance(type_, list):
|
||||
resolved_list = [replace_types(element, type_map) for element in type_]
|
||||
if all_identical(type_, resolved_list):
|
||||
return type_
|
||||
return resolved_list
|
||||
@@ -338,49 +340,57 @@ def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
return type_map.get(type_, type_)
|
||||
|
||||
|
||||
def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool:
|
||||
"""Checks if the type, or any of its arbitrary nested args, satisfy
|
||||
`isinstance(<type>, isinstance_target)`.
|
||||
"""
|
||||
if isinstance(type_, isinstance_target):
|
||||
return True
|
||||
|
||||
type_args = get_args(type_)
|
||||
origin_type = get_origin(type_)
|
||||
|
||||
if origin_type is typing_extensions.Annotated:
|
||||
annotated_type, *annotations = type_args
|
||||
return has_instance_in_type(annotated_type, isinstance_target)
|
||||
|
||||
# Having type args is a good indicator that this is a typing module
|
||||
# class instantiation or a generic alias of some sort.
|
||||
if any(has_instance_in_type(a, isinstance_target) for a in type_args):
|
||||
return True
|
||||
|
||||
# Handle special case for typehints that can have lists as arguments.
|
||||
# `typing.Callable[[int, str], int]` is an example for this.
|
||||
if isinstance(type_, (List, list)) and not isinstance(type_, typing_extensions.ParamSpec):
|
||||
if any(has_instance_in_type(element, isinstance_target) for element in type_):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None:
|
||||
"""Check the generic model parameters count is equal.
|
||||
|
||||
Args:
|
||||
cls: The generic model.
|
||||
parameters: A tuple of passed parameters to the generic model.
|
||||
def map_generic_model_arguments(cls: type[BaseModel], args: tuple[Any, ...]) -> dict[TypeVar, Any]:
|
||||
"""Return a mapping between the parameters of a generic model and the provided arguments during parameterization.
|
||||
|
||||
Raises:
|
||||
TypeError: If the passed parameters count is not equal to generic model parameters count.
|
||||
TypeError: If the number of arguments does not match the parameters (i.e. if providing too few or too many arguments).
|
||||
|
||||
Example:
|
||||
```python {test="skip" lint="skip"}
|
||||
class Model[T, U, V = int](BaseModel): ...
|
||||
|
||||
map_generic_model_arguments(Model, (str, bytes))
|
||||
#> {T: str, U: bytes, V: int}
|
||||
|
||||
map_generic_model_arguments(Model, (str,))
|
||||
#> TypeError: Too few arguments for <class '__main__.Model'>; actual 1, expected at least 2
|
||||
|
||||
map_generic_model_arguments(Model, (str, bytes, int, complex))
|
||||
#> TypeError: Too many arguments for <class '__main__.Model'>; actual 4, expected 3
|
||||
```
|
||||
|
||||
Note:
|
||||
This function is analogous to the private `typing._check_generic_specialization` function.
|
||||
"""
|
||||
actual = len(parameters)
|
||||
expected = len(cls.__pydantic_generic_metadata__['parameters'])
|
||||
if actual != expected:
|
||||
description = 'many' if actual > expected else 'few'
|
||||
raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}')
|
||||
parameters = cls.__pydantic_generic_metadata__['parameters']
|
||||
expected_len = len(parameters)
|
||||
typevars_map: dict[TypeVar, Any] = {}
|
||||
|
||||
_missing = object()
|
||||
for parameter, argument in zip_longest(parameters, args, fillvalue=_missing):
|
||||
if parameter is _missing:
|
||||
raise TypeError(f'Too many arguments for {cls}; actual {len(args)}, expected {expected_len}')
|
||||
|
||||
if argument is _missing:
|
||||
param = typing.cast(TypeVar, parameter)
|
||||
try:
|
||||
has_default = param.has_default()
|
||||
except AttributeError:
|
||||
# Happens if using `typing.TypeVar` (and not `typing_extensions`) on Python < 3.13.
|
||||
has_default = False
|
||||
if has_default:
|
||||
# The default might refer to other type parameters. For an example, see:
|
||||
# https://typing.readthedocs.io/en/latest/spec/generics.html#type-parameters-as-parameters-to-generics
|
||||
typevars_map[param] = replace_types(param.__default__, typevars_map)
|
||||
else:
|
||||
expected_len -= sum(hasattr(p, 'has_default') and p.has_default() for p in parameters)
|
||||
raise TypeError(f'Too few arguments for {cls}; actual {len(args)}, expected at least {expected_len}')
|
||||
else:
|
||||
param = typing.cast(TypeVar, parameter)
|
||||
typevars_map[param] = argument
|
||||
|
||||
return typevars_map
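The pairing loop relies on `zip_longest` with a private sentinel so that missing arguments can fall back to a TypeVar default while extra arguments are rejected. A hedged sketch of the same technique (it assumes `typing_extensions` so TypeVar defaults work on older Pythons; the function name is illustrative):

```python
from itertools import zip_longest
from typing import Any

from typing_extensions import TypeVar  # supports `default=` before Python 3.13

T = TypeVar('T')
U = TypeVar('U')
V = TypeVar('V', default=int)

_MISSING = object()


def pair_args(parameters: tuple[Any, ...], args: tuple[Any, ...]) -> dict[Any, Any]:
    """Pair type parameters with arguments, falling back to TypeVar defaults (sketch)."""
    mapping: dict[Any, Any] = {}
    for param, arg in zip_longest(parameters, args, fillvalue=_MISSING):
        if param is _MISSING:
            raise TypeError('too many arguments')
        if arg is _MISSING:
            if getattr(param, 'has_default', lambda: False)():
                mapping[param] = param.__default__
            else:
                raise TypeError('too few arguments')
        else:
            mapping[param] = arg
    return mapping


print(pair_args((T, U, V), (str, bytes)))  # maps T -> str, U -> bytes, V -> int (its default)
```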
|
||||
|
||||
|
||||
_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)
|
||||
@@ -411,7 +421,8 @@ def generic_recursion_self_type(
|
||||
yield self_type
|
||||
else:
|
||||
previously_seen_type_refs.add(type_ref)
|
||||
yield None
|
||||
yield
|
||||
previously_seen_type_refs.remove(type_ref)
|
||||
finally:
|
||||
if token:
|
||||
_generic_recursion_cache.reset(token)
|
||||
@@ -442,14 +453,24 @@ def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any)
|
||||
during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually
|
||||
equal.
|
||||
"""
|
||||
return _GENERIC_TYPES_CACHE.get(_early_cache_key(parent, typevar_values))
|
||||
generic_types_cache = _GENERIC_TYPES_CACHE.get()
|
||||
if generic_types_cache is None:
|
||||
generic_types_cache = GenericTypesCache()
|
||||
_GENERIC_TYPES_CACHE.set(generic_types_cache)
|
||||
return generic_types_cache.get(_early_cache_key(parent, typevar_values))
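The cache now lives in a `ContextVar` and is created lazily on first lookup. A minimal sketch of that lazy-initialization pattern, independent of the generic-model machinery:

```python
from contextvars import ContextVar

_CACHE: ContextVar[dict | None] = ContextVar('_CACHE', default=None)


def get_cache() -> dict:
    """Return the per-context cache, creating it on first use."""
    cache = _CACHE.get()
    if cache is None:
        cache = {}
        _CACHE.set(cache)
    return cache


get_cache()['answer'] = 42
print(get_cache()['answer'])  # 42 -- the same object is reused within this context
```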
|
||||
|
||||
|
||||
def get_cached_generic_type_late(
|
||||
parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...]
|
||||
) -> type[BaseModel] | None:
|
||||
"""See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup."""
|
||||
cached = _GENERIC_TYPES_CACHE.get(_late_cache_key(origin, args, typevar_values))
|
||||
generic_types_cache = _GENERIC_TYPES_CACHE.get()
|
||||
if (
|
||||
generic_types_cache is None
|
||||
): # pragma: no cover (early cache is guaranteed to run first and initialize the cache)
|
||||
generic_types_cache = GenericTypesCache()
|
||||
_GENERIC_TYPES_CACHE.set(generic_types_cache)
|
||||
cached = generic_types_cache.get(_late_cache_key(origin, args, typevar_values))
|
||||
if cached is not None:
|
||||
set_cached_generic_type(parent, typevar_values, cached, origin, args)
|
||||
return cached
|
||||
@@ -465,11 +486,17 @@ def set_cached_generic_type(
|
||||
"""See the docstring of `get_cached_generic_type_early` for more information about why items are cached with
|
||||
two different keys.
|
||||
"""
|
||||
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values)] = type_
|
||||
generic_types_cache = _GENERIC_TYPES_CACHE.get()
|
||||
if (
|
||||
generic_types_cache is None
|
||||
): # pragma: no cover (cache lookup is guaranteed to run first and initialize the cache)
|
||||
generic_types_cache = GenericTypesCache()
|
||||
_GENERIC_TYPES_CACHE.set(generic_types_cache)
|
||||
generic_types_cache[_early_cache_key(parent, typevar_values)] = type_
|
||||
if len(typevar_values) == 1:
|
||||
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values[0])] = type_
|
||||
generic_types_cache[_early_cache_key(parent, typevar_values[0])] = type_
|
||||
if origin and args:
|
||||
_GENERIC_TYPES_CACHE[_late_cache_key(origin, args, typevar_values)] = type_
|
||||
generic_types_cache[_late_cache_key(origin, args, typevar_values)] = type_
|
||||
|
||||
|
||||
def _union_orderings_key(typevar_values: Any) -> Any:
|
||||
@@ -490,7 +517,7 @@ def _union_orderings_key(typevar_values: Any) -> Any:
|
||||
for value in typevar_values:
|
||||
args_data.append(_union_orderings_key(value))
|
||||
return tuple(args_data)
|
||||
elif typing_extensions.get_origin(typevar_values) is typing.Union:
|
||||
elif typing_objects.is_union(typing_extensions.get_origin(typevar_values)):
|
||||
return get_args(typevar_values)
|
||||
else:
|
||||
return ()
|
||||
|
||||
27
venv/lib/python3.12/site-packages/pydantic/_internal/_git.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Git utilities, adopted from mypy's git utilities (https://github.com/python/mypy/blob/master/mypy/git.py)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def is_git_repo(dir: Path) -> bool:
|
||||
"""Is the given directory version-controlled with git?"""
|
||||
return dir.joinpath('.git').exists()
|
||||
|
||||
|
||||
def have_git() -> bool: # pragma: no cover
|
||||
"""Can we run the git executable?"""
|
||||
try:
|
||||
subprocess.check_output(['git', '--help'])
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def git_revision(dir: Path) -> str:
|
||||
"""Get the SHA-1 of the HEAD of a git repository."""
|
||||
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=dir).decode('utf-8').strip()
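A possible usage of these helpers, assuming they are importable from the new `pydantic._internal._git` module shown above:

```python
from pathlib import Path

# Assumes the helpers defined in the new module above.
from pydantic._internal._git import git_revision, have_git, is_git_repo

repo = Path('.')
if have_git() and is_git_repo(repo):
    print(git_revision(repo))  # e.g. 'a1b2c3d' (abbreviated SHA of HEAD)
```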
|
||||
@@ -0,0 +1,20 @@
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
|
||||
@cache
|
||||
def import_cached_base_model() -> type['BaseModel']:
|
||||
from pydantic import BaseModel
|
||||
|
||||
return BaseModel
|
||||
|
||||
|
||||
@cache
|
||||
def import_cached_field_info() -> type['FieldInfo']:
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
return FieldInfo
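Both helpers follow the same cached deferred-import pattern: the expensive (or cycle-prone) import runs once, on first call, and the result is memoized by `functools.cache`. A generic sketch of the pattern with a stand-in module:

```python
from functools import cache


@cache
def import_heavy_dependency():
    # The import only happens on the first call; later calls return the cached module.
    import json  # stand-in for an expensive or cycle-prone import

    return json


assert import_heavy_dependency() is import_heavy_dependency()
print(import_heavy_dependency().dumps({'ok': True}))
```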
|
||||
@@ -1,7 +1,4 @@
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
dataclass_kwargs: Dict[str, Any]
|
||||
|
||||
# `slots` is available on Python >= 3.10
|
||||
if sys.version_info >= (3, 10):
|
||||
|
||||
@@ -1,42 +1,57 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from copy import copy
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable
|
||||
from functools import lru_cache, partial
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import annotated_types as at
|
||||
from pydantic_core import CoreSchema, PydanticCustomError, to_jsonable_python
|
||||
from pydantic_core import CoreSchema, PydanticCustomError, ValidationError, to_jsonable_python
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
from . import _validators
|
||||
from ._fields import PydanticGeneralMetadata, PydanticMetadata
|
||||
from ._fields import PydanticMetadata
|
||||
from ._import_utils import import_cached_field_info
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..annotated_handlers import GetJsonSchemaHandler
|
||||
|
||||
pass
|
||||
|
||||
STRICT = {'strict'}
|
||||
SEQUENCE_CONSTRAINTS = {'min_length', 'max_length'}
|
||||
FAIL_FAST = {'fail_fast'}
|
||||
LENGTH_CONSTRAINTS = {'min_length', 'max_length'}
|
||||
INEQUALITY = {'le', 'ge', 'lt', 'gt'}
|
||||
NUMERIC_CONSTRAINTS = {'multiple_of', 'allow_inf_nan', *INEQUALITY}
|
||||
NUMERIC_CONSTRAINTS = {'multiple_of', *INEQUALITY}
|
||||
ALLOW_INF_NAN = {'allow_inf_nan'}
|
||||
|
||||
STR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT, 'strip_whitespace', 'to_lower', 'to_upper', 'pattern'}
|
||||
BYTES_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
STR_CONSTRAINTS = {
|
||||
*LENGTH_CONSTRAINTS,
|
||||
*STRICT,
|
||||
'strip_whitespace',
|
||||
'to_lower',
|
||||
'to_upper',
|
||||
'pattern',
|
||||
'coerce_numbers_to_str',
|
||||
}
|
||||
BYTES_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
|
||||
|
||||
LIST_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
TUPLE_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
SET_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
DICT_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
GENERATOR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
LIST_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
|
||||
TUPLE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
|
||||
SET_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
|
||||
DICT_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
|
||||
GENERATOR_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
|
||||
SEQUENCE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *FAIL_FAST}
|
||||
|
||||
FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
|
||||
DECIMAL_CONSTRAINTS = {'max_digits', 'decimal_places', *FLOAT_CONSTRAINTS}
|
||||
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
|
||||
BOOL_CONSTRAINTS = STRICT
|
||||
UUID_CONSTRAINTS = STRICT
|
||||
|
||||
DATE_TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
LAX_OR_STRICT_CONSTRAINTS = STRICT
|
||||
ENUM_CONSTRAINTS = STRICT
|
||||
COMPLEX_CONSTRAINTS = STRICT
|
||||
|
||||
UNION_CONSTRAINTS = {'union_mode'}
|
||||
URL_CONSTRAINTS = {
|
||||
@@ -53,54 +68,33 @@ SEQUENCE_SCHEMA_TYPES = ('list', 'tuple', 'set', 'frozenset', 'generator', *TEXT
|
||||
NUMERIC_SCHEMA_TYPES = ('float', 'int', 'date', 'time', 'timedelta', 'datetime')
|
||||
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = defaultdict(set)
|
||||
for constraint in STR_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(TEXT_SCHEMA_TYPES)
|
||||
for constraint in BYTES_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bytes',))
|
||||
for constraint in LIST_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('list',))
|
||||
for constraint in TUPLE_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('tuple',))
|
||||
for constraint in SET_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('set', 'frozenset'))
|
||||
for constraint in DICT_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('dict',))
|
||||
for constraint in GENERATOR_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('generator',))
|
||||
for constraint in FLOAT_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('float',))
|
||||
for constraint in INT_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('int',))
|
||||
for constraint in DATE_TIME_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('date', 'time', 'datetime'))
|
||||
for constraint in TIMEDELTA_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('timedelta',))
|
||||
for constraint in TIME_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('time',))
|
||||
for schema_type in (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model'):
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS['strict'].add(schema_type)
|
||||
for constraint in UNION_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('union',))
|
||||
for constraint in URL_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('url', 'multi-host-url'))
|
||||
for constraint in BOOL_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bool',))
|
||||
|
||||
constraint_schema_pairings: list[tuple[set[str], tuple[str, ...]]] = [
|
||||
(STR_CONSTRAINTS, TEXT_SCHEMA_TYPES),
|
||||
(BYTES_CONSTRAINTS, ('bytes',)),
|
||||
(LIST_CONSTRAINTS, ('list',)),
|
||||
(TUPLE_CONSTRAINTS, ('tuple',)),
|
||||
(SET_CONSTRAINTS, ('set', 'frozenset')),
|
||||
(DICT_CONSTRAINTS, ('dict',)),
|
||||
(GENERATOR_CONSTRAINTS, ('generator',)),
|
||||
(FLOAT_CONSTRAINTS, ('float',)),
|
||||
(INT_CONSTRAINTS, ('int',)),
|
||||
(DATE_TIME_CONSTRAINTS, ('date', 'time', 'datetime', 'timedelta')),
|
||||
# TODO: this is a bit redundant, we could probably avoid some of these
|
||||
(STRICT, (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model')),
|
||||
(UNION_CONSTRAINTS, ('union',)),
|
||||
(URL_CONSTRAINTS, ('url', 'multi-host-url')),
|
||||
(BOOL_CONSTRAINTS, ('bool',)),
|
||||
(UUID_CONSTRAINTS, ('uuid',)),
|
||||
(LAX_OR_STRICT_CONSTRAINTS, ('lax-or-strict',)),
|
||||
(ENUM_CONSTRAINTS, ('enum',)),
|
||||
(DECIMAL_CONSTRAINTS, ('decimal',)),
|
||||
(COMPLEX_CONSTRAINTS, ('complex',)),
|
||||
]
|
||||
|
||||
def add_js_update_schema(s: cs.CoreSchema, f: Callable[[], dict[str, Any]]) -> None:
|
||||
def update_js_schema(s: cs.CoreSchema, handler: GetJsonSchemaHandler) -> dict[str, Any]:
|
||||
js_schema = handler(s)
|
||||
js_schema.update(f())
|
||||
return js_schema
|
||||
|
||||
if 'metadata' in s:
|
||||
metadata = s['metadata']
|
||||
if 'pydantic_js_functions' in s:
|
||||
metadata['pydantic_js_functions'].append(update_js_schema)
|
||||
else:
|
||||
metadata['pydantic_js_functions'] = [update_js_schema]
|
||||
else:
|
||||
s['metadata'] = {'pydantic_js_functions': [update_js_schema]}
|
||||
for constraints, schemas in constraint_schema_pairings:
|
||||
for c in constraints:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[c].update(schemas)
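The pairing list replaces the long series of per-constraint loops with a single pass that fills the `defaultdict`. A tiny standalone sketch of that construction and lookup, using hypothetical pairings:

```python
from collections import defaultdict

# Hypothetical pairings, mirroring the (constraints, schema types) structure used above.
pairings = [
    ({'min_length', 'max_length'}, ('str', 'bytes', 'list')),
    ({'gt', 'ge', 'lt', 'le'}, ('int', 'float')),
]

allowed: dict[str, set[str]] = defaultdict(set)
for constraints, schemas in pairings:
    for c in constraints:
        allowed[c].update(schemas)

print(sorted(allowed['min_length']))  # ['bytes', 'list', 'str']
print('float' in allowed['gt'])       # True
```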
|
||||
|
||||
|
||||
def as_jsonable_value(v: Any) -> Any:
|
||||
@@ -119,7 +113,7 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
|
||||
An iterable of expanded annotations.
|
||||
|
||||
Example:
|
||||
```py
|
||||
```python
|
||||
from annotated_types import Ge, Len
|
||||
|
||||
from pydantic._internal._known_annotated_metadata import expand_grouped_metadata
|
||||
@@ -128,7 +122,9 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
|
||||
#> [Ge(ge=4), MinLen(min_length=5)]
|
||||
```
|
||||
"""
|
||||
from pydantic.fields import FieldInfo # circular import
|
||||
import annotated_types as at
|
||||
|
||||
FieldInfo = import_cached_field_info()
|
||||
|
||||
for annotation in annotations:
|
||||
if isinstance(annotation, at.GroupedMetadata):
|
||||
@@ -147,6 +143,28 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
|
||||
yield annotation
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _get_at_to_constraint_map() -> dict[type, str]:
|
||||
"""Return a mapping of annotated types to constraints.
|
||||
|
||||
Normally, we would define a mapping like this in the module scope, but we can't do that
|
||||
because we don't permit module level imports of `annotated_types`, in an attempt to speed up
|
||||
the import time of `pydantic`. We still only want to have this dictionary defined in one place,
|
||||
so we use this function to cache the result.
|
||||
"""
|
||||
import annotated_types as at
|
||||
|
||||
return {
|
||||
at.Gt: 'gt',
|
||||
at.Ge: 'ge',
|
||||
at.Lt: 'lt',
|
||||
at.Le: 'le',
|
||||
at.MultipleOf: 'multiple_of',
|
||||
at.MinLen: 'min_length',
|
||||
at.MaxLen: 'max_length',
|
||||
}
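Such a mapping lets `annotated_types` metadata objects be translated into plain constraint keyword arguments keyed by their type. A hedged sketch of that translation step (the helper name and the subset of entries are illustrative):

```python
import annotated_types as at

# Assumed mapping shape, mirroring the function above (subset of entries).
AT_TO_CONSTRAINT = {at.Gt: 'gt', at.Ge: 'ge', at.MinLen: 'min_length', at.MaxLen: 'max_length'}


def to_constraints(metadata: list) -> dict[str, object]:
    """Translate annotated_types instances into constraint kwargs (sketch)."""
    out: dict[str, object] = {}
    for m in metadata:
        name = AT_TO_CONSTRAINT.get(type(m))
        if name is not None:
            out[name] = getattr(m, name)
    return out


print(to_constraints([at.Gt(0), at.MaxLen(10)]))  # {'gt': 0, 'max_length': 10}
```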
|
||||
|
||||
|
||||
def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | None: # noqa: C901
|
||||
"""Apply `annotation` to `schema` if it is an annotation we know about (Gt, Le, etc.).
|
||||
Otherwise return `None`.
|
||||
@@ -166,14 +184,37 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No
|
||||
Raises:
|
||||
PydanticCustomError: If `Predicate` fails.
|
||||
"""
|
||||
import annotated_types as at
|
||||
|
||||
from ._validators import NUMERIC_VALIDATOR_LOOKUP, forbid_inf_nan_check
|
||||
|
||||
schema = schema.copy()
|
||||
schema_update, other_metadata = collect_known_metadata([annotation])
|
||||
schema_type = schema['type']
|
||||
|
||||
chain_schema_constraints: set[str] = {
|
||||
'pattern',
|
||||
'strip_whitespace',
|
||||
'to_lower',
|
||||
'to_upper',
|
||||
'coerce_numbers_to_str',
|
||||
}
|
||||
chain_schema_steps: list[CoreSchema] = []
|
||||
|
||||
for constraint, value in schema_update.items():
|
||||
if constraint not in CONSTRAINTS_TO_ALLOWED_SCHEMAS:
|
||||
raise ValueError(f'Unknown constraint {constraint}')
|
||||
allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]
|
||||
|
||||
# if it becomes necessary to handle more than one constraint
|
||||
# in this recursive case with function-after or function-wrap, we should refactor
|
||||
# this is a bit challenging because we sometimes want to apply constraints to the inner schema,
|
||||
# whereas other times we want to wrap the existing schema with a new one that enforces a new constraint.
|
||||
if schema_type in {'function-before', 'function-wrap', 'function-after'} and constraint == 'strict':
|
||||
schema['schema'] = apply_known_metadata(annotation, schema['schema']) # type: ignore # schema is function schema
|
||||
return schema
|
||||
|
||||
# if we're allowed to apply constraint directly to the schema, like le to int, do that
|
||||
if schema_type in allowed_schemas:
|
||||
if constraint == 'union_mode' and schema_type == 'union':
|
||||
schema['mode'] = value # type: ignore # schema is UnionSchema
|
||||
@@ -181,145 +222,109 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No
|
||||
schema[constraint] = value
|
||||
continue
|
||||
|
||||
if constraint == 'allow_inf_nan' and value is False:
|
||||
return cs.no_info_after_validator_function(
|
||||
_validators.forbid_inf_nan_check,
|
||||
schema,
|
||||
# else, apply a function after validator to the schema to enforce the corresponding constraint
|
||||
if constraint in chain_schema_constraints:
|
||||
|
||||
def _apply_constraint_with_incompatibility_info(
|
||||
value: Any, handler: cs.ValidatorFunctionWrapHandler
|
||||
) -> Any:
|
||||
try:
|
||||
x = handler(value)
|
||||
except ValidationError as ve:
|
||||
# if the error is about the type, it's likely that the constraint is incompatible with the type of the field
|
||||
# for example, the following invalid schema wouldn't be caught during schema build, but rather at this point
|
||||
# with a cryptic 'string_type' error coming from the string validator,
|
||||
# that we'd rather express as a constraint incompatibility error (TypeError)
|
||||
# Annotated[list[int], Field(pattern='abc')]
|
||||
if 'type' in ve.errors()[0]['type']:
|
||||
raise TypeError(
|
||||
f"Unable to apply constraint '{constraint}' to supplied value {value} for schema of type '{schema_type}'" # noqa: B023
|
||||
)
|
||||
raise ve
|
||||
return x
|
||||
|
||||
chain_schema_steps.append(
|
||||
cs.no_info_wrap_validator_function(
|
||||
_apply_constraint_with_incompatibility_info, cs.str_schema(**{constraint: value})
|
||||
)
|
||||
)
|
||||
elif constraint == 'pattern':
|
||||
# insert a str schema to make sure the regex engine matches
|
||||
return cs.chain_schema(
|
||||
[
|
||||
schema,
|
||||
cs.str_schema(pattern=value),
|
||||
]
|
||||
elif constraint in NUMERIC_VALIDATOR_LOOKUP:
|
||||
if constraint in LENGTH_CONSTRAINTS:
|
||||
inner_schema = schema
|
||||
while inner_schema['type'] in {'function-before', 'function-wrap', 'function-after'}:
|
||||
inner_schema = inner_schema['schema'] # type: ignore
|
||||
inner_schema_type = inner_schema['type']
|
||||
if inner_schema_type == 'list' or (
|
||||
inner_schema_type == 'json-or-python' and inner_schema['json_schema']['type'] == 'list' # type: ignore
|
||||
):
|
||||
js_constraint_key = 'minItems' if constraint == 'min_length' else 'maxItems'
|
||||
else:
|
||||
js_constraint_key = 'minLength' if constraint == 'min_length' else 'maxLength'
|
||||
else:
|
||||
js_constraint_key = constraint
|
||||
|
||||
schema = cs.no_info_after_validator_function(
|
||||
partial(NUMERIC_VALIDATOR_LOOKUP[constraint], **{constraint: value}), schema
|
||||
)
|
||||
elif constraint == 'gt':
|
||||
s = cs.no_info_after_validator_function(
|
||||
partial(_validators.greater_than_validator, gt=value),
|
||||
schema,
|
||||
)
|
||||
add_js_update_schema(s, lambda: {'gt': as_jsonable_value(value)})
|
||||
return s
|
||||
elif constraint == 'ge':
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.greater_than_or_equal_validator, ge=value),
|
||||
schema,
|
||||
)
|
||||
elif constraint == 'lt':
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.less_than_validator, lt=value),
|
||||
schema,
|
||||
)
|
||||
elif constraint == 'le':
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.less_than_or_equal_validator, le=value),
|
||||
schema,
|
||||
)
|
||||
elif constraint == 'multiple_of':
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.multiple_of_validator, multiple_of=value),
|
||||
schema,
|
||||
)
|
||||
elif constraint == 'min_length':
|
||||
s = cs.no_info_after_validator_function(
|
||||
partial(_validators.min_length_validator, min_length=value),
|
||||
schema,
|
||||
)
|
||||
add_js_update_schema(s, lambda: {'minLength': (as_jsonable_value(value))})
|
||||
return s
|
||||
elif constraint == 'max_length':
|
||||
s = cs.no_info_after_validator_function(
|
||||
partial(_validators.max_length_validator, max_length=value),
|
||||
schema,
|
||||
)
|
||||
add_js_update_schema(s, lambda: {'maxLength': (as_jsonable_value(value))})
|
||||
return s
|
||||
elif constraint == 'strip_whitespace':
|
||||
return cs.chain_schema(
|
||||
[
|
||||
schema,
|
||||
cs.str_schema(strip_whitespace=True),
|
||||
]
|
||||
)
|
||||
elif constraint == 'to_lower':
|
||||
return cs.chain_schema(
|
||||
[
|
||||
schema,
|
||||
cs.str_schema(to_lower=True),
|
||||
]
|
||||
)
|
||||
elif constraint == 'to_upper':
|
||||
return cs.chain_schema(
|
||||
[
|
||||
schema,
|
||||
cs.str_schema(to_upper=True),
|
||||
]
|
||||
)
|
||||
elif constraint == 'min_length':
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.min_length_validator, min_length=annotation.min_length),
|
||||
schema,
|
||||
)
|
||||
elif constraint == 'max_length':
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.max_length_validator, max_length=annotation.max_length),
|
||||
metadata = schema.get('metadata', {})
|
||||
if (existing_json_schema_updates := metadata.get('pydantic_js_updates')) is not None:
|
||||
metadata['pydantic_js_updates'] = {
|
||||
**existing_json_schema_updates,
|
||||
**{js_constraint_key: as_jsonable_value(value)},
|
||||
}
|
||||
else:
|
||||
metadata['pydantic_js_updates'] = {js_constraint_key: as_jsonable_value(value)}
|
||||
schema['metadata'] = metadata
|
||||
elif constraint == 'allow_inf_nan' and value is False:
|
||||
schema = cs.no_info_after_validator_function(
|
||||
forbid_inf_nan_check,
|
||||
schema,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f'Unable to apply constraint {constraint} to schema {schema_type}')
|
||||
# It's rare that we'd get here, but it's possible if we add a new constraint and forget to handle it
|
||||
# Most constraint errors are caught at runtime during attempted application
|
||||
raise RuntimeError(f"Unable to apply constraint '{constraint}' to schema of type '{schema_type}'")
|
||||
|
||||
for annotation in other_metadata:
|
||||
if isinstance(annotation, at.Gt):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.greater_than_validator, gt=annotation.gt),
|
||||
schema,
|
||||
if (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
|
||||
constraint = at_to_constraint_map[annotation_type]
|
||||
validator = NUMERIC_VALIDATOR_LOOKUP.get(constraint)
|
||||
if validator is None:
|
||||
raise ValueError(f'Unknown constraint {constraint}')
|
||||
schema = cs.no_info_after_validator_function(
|
||||
partial(validator, **{constraint: getattr(annotation, constraint)}), schema
|
||||
)
|
||||
elif isinstance(annotation, at.Ge):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.greater_than_or_equal_validator, ge=annotation.ge),
|
||||
schema,
|
||||
)
|
||||
elif isinstance(annotation, at.Lt):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.less_than_validator, lt=annotation.lt),
|
||||
schema,
|
||||
)
|
||||
elif isinstance(annotation, at.Le):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.less_than_or_equal_validator, le=annotation.le),
|
||||
schema,
|
||||
)
|
||||
elif isinstance(annotation, at.MultipleOf):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.multiple_of_validator, multiple_of=annotation.multiple_of),
|
||||
schema,
|
||||
)
|
||||
elif isinstance(annotation, at.MinLen):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.min_length_validator, min_length=annotation.min_length),
|
||||
schema,
|
||||
)
|
||||
elif isinstance(annotation, at.MaxLen):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.max_length_validator, max_length=annotation.max_length),
|
||||
schema,
|
||||
)
|
||||
elif isinstance(annotation, at.Predicate):
|
||||
predicate_name = f'{annotation.func.__qualname__} ' if hasattr(annotation.func, '__qualname__') else ''
|
||||
continue
|
||||
elif isinstance(annotation, (at.Predicate, at.Not)):
|
||||
predicate_name = f'{annotation.func.__qualname__}' if hasattr(annotation.func, '__qualname__') else ''
|
||||
|
||||
def val_func(v: Any) -> Any:
|
||||
predicate_satisfied = annotation.func(v) # noqa: B023
|
||||
|
||||
# annotation.func may also raise an exception, let it pass through
|
||||
if not annotation.func(v):
|
||||
raise PydanticCustomError(
|
||||
'predicate_failed',
|
||||
f'Predicate {predicate_name}failed', # type: ignore
|
||||
)
|
||||
if isinstance(annotation, at.Predicate): # noqa: B023
|
||||
if not predicate_satisfied:
|
||||
raise PydanticCustomError(
|
||||
'predicate_failed',
|
||||
f'Predicate {predicate_name} failed', # type: ignore # noqa: B023
|
||||
)
|
||||
else:
|
||||
if predicate_satisfied:
|
||||
raise PydanticCustomError(
|
||||
'not_operation_failed',
|
||||
f'Not of {predicate_name} failed', # type: ignore # noqa: B023
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
return cs.no_info_after_validator_function(val_func, schema)
|
||||
# ignore any other unknown metadata
|
||||
return None
|
||||
schema = cs.no_info_after_validator_function(val_func, schema)
|
||||
else:
|
||||
# ignore any other unknown metadata
|
||||
return None
|
||||
|
||||
if chain_schema_steps:
|
||||
chain_schema_steps = [schema] + chain_schema_steps
|
||||
return cs.chain_schema(chain_schema_steps)
|
||||
|
||||
return schema
|
||||
|
||||
@@ -334,7 +339,7 @@ def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any],
|
||||
A tuple contains a dict of known metadata and a list of unknown annotations.
|
||||
|
||||
Example:
|
||||
```py
|
||||
```python
|
||||
from annotated_types import Gt, Len
|
||||
|
||||
from pydantic._internal._known_annotated_metadata import collect_known_metadata
|
||||
@@ -347,29 +352,15 @@ def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any],
|
||||
|
||||
res: dict[str, Any] = {}
|
||||
remaining: list[Any] = []
|
||||
|
||||
for annotation in annotations:
|
||||
# Do we really want to consume any `BaseMetadata`?
|
||||
# It does let us give a better error when there is an annotation that doesn't apply
|
||||
# But it seems dangerous!
|
||||
if isinstance(annotation, PydanticGeneralMetadata):
|
||||
res.update(annotation.__dict__)
|
||||
elif isinstance(annotation, PydanticMetadata):
|
||||
# isinstance(annotation, PydanticMetadata) also covers ._fields:_PydanticGeneralMetadata
|
||||
if isinstance(annotation, PydanticMetadata):
|
||||
res.update(annotation.__dict__)
|
||||
# we don't use dataclasses.asdict because that recursively calls asdict on the field values
|
||||
elif isinstance(annotation, at.MinLen):
|
||||
res.update({'min_length': annotation.min_length})
|
||||
elif isinstance(annotation, at.MaxLen):
|
||||
res.update({'max_length': annotation.max_length})
|
||||
elif isinstance(annotation, at.Gt):
|
||||
res.update({'gt': annotation.gt})
|
||||
elif isinstance(annotation, at.Ge):
|
||||
res.update({'ge': annotation.ge})
|
||||
elif isinstance(annotation, at.Lt):
|
||||
res.update({'lt': annotation.lt})
|
||||
elif isinstance(annotation, at.Le):
|
||||
res.update({'le': annotation.le})
|
||||
elif isinstance(annotation, at.MultipleOf):
|
||||
res.update({'multiple_of': annotation.multiple_of})
|
||||
elif (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
|
||||
constraint = at_to_constraint_map[annotation_type]
|
||||
res[constraint] = getattr(annotation, constraint)
|
||||
elif isinstance(annotation, type) and issubclass(annotation, PydanticMetadata):
|
||||
# also support PydanticMetadata classes being used without initialisation,
|
||||
# e.g. `Annotated[int, Strict]` as well as `Annotated[int, Strict()]`
|
||||
|
||||
@@ -1,18 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Callable, Generic, TypeVar
|
||||
from collections.abc import Iterator, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union
|
||||
|
||||
from pydantic_core import SchemaSerializer, SchemaValidator
|
||||
from typing_extensions import Literal
|
||||
from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator
|
||||
|
||||
from ..errors import PydanticErrorCodes, PydanticUserError
|
||||
from ..plugin._schema_validator import PluggableSchemaValidator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..dataclasses import PydanticDataclass
|
||||
from ..main import BaseModel
|
||||
from ..type_adapter import TypeAdapter
|
||||
|
||||
|
||||
ValSer = TypeVar('ValSer', SchemaValidator, SchemaSerializer)
|
||||
ValSer = TypeVar('ValSer', bound=Union[SchemaValidator, PluggableSchemaValidator, SchemaSerializer])
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class MockCoreSchema(Mapping[str, Any]):
|
||||
"""Mocker for `pydantic_core.CoreSchema` which optionally attempts to
|
||||
rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails.
|
||||
"""
|
||||
|
||||
__slots__ = '_error_message', '_code', '_attempt_rebuild', '_built_memo'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_message: str,
|
||||
*,
|
||||
code: PydanticErrorCodes,
|
||||
attempt_rebuild: Callable[[], CoreSchema | None] | None = None,
|
||||
) -> None:
|
||||
self._error_message = error_message
|
||||
self._code: PydanticErrorCodes = code
|
||||
self._attempt_rebuild = attempt_rebuild
|
||||
self._built_memo: CoreSchema | None = None
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._get_built().__getitem__(key)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._get_built().__len__()
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return self._get_built().__iter__()
|
||||
|
||||
def _get_built(self) -> CoreSchema:
|
||||
if self._built_memo is not None:
|
||||
return self._built_memo
|
||||
|
||||
if self._attempt_rebuild:
|
||||
schema = self._attempt_rebuild()
|
||||
if schema is not None:
|
||||
self._built_memo = schema
|
||||
return schema
|
||||
raise PydanticUserError(self._error_message, code=self._code)
|
||||
|
||||
def rebuild(self) -> CoreSchema | None:
|
||||
self._built_memo = None
|
||||
if self._attempt_rebuild:
|
||||
schema = self._attempt_rebuild()
|
||||
if schema is not None:
|
||||
return schema
|
||||
else:
|
||||
raise PydanticUserError(self._error_message, code=self._code)
|
||||
return None
|
||||
|
||||
|
||||
class MockValSer(Generic[ValSer]):
|
||||
@@ -56,63 +109,120 @@ class MockValSer(Generic[ValSer]):
|
||||
return None
|
||||
|
||||
|
||||
def set_model_mocks(cls: type[BaseModel], cls_name: str, undefined_name: str = 'all referenced types') -> None:
|
||||
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a model.
|
||||
def set_type_adapter_mocks(adapter: TypeAdapter) -> None:
|
||||
"""Set `core_schema`, `validator` and `serializer` to mock core types on a type adapter instance.
|
||||
|
||||
Args:
|
||||
cls: The model class to set the mocks on
|
||||
cls_name: Name of the model class, used in error messages
|
||||
undefined_name: Name of the undefined thing, used in error messages
|
||||
adapter: The type adapter instance to set the mocks on
|
||||
"""
|
||||
type_repr = str(adapter._type)
|
||||
undefined_type_error_message = (
|
||||
f'`{cls_name}` is not fully defined; you should define {undefined_name},'
|
||||
f' then call `{cls_name}.model_rebuild()`.'
|
||||
f'`TypeAdapter[{type_repr}]` is not fully defined; you should define `{type_repr}` and all referenced types,'
|
||||
f' then call `.rebuild()` on the instance.'
|
||||
)
|
||||
|
||||
def attempt_rebuild_validator() -> SchemaValidator | None:
|
||||
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5):
|
||||
return cls.__pydantic_validator__
|
||||
else:
|
||||
def attempt_rebuild_fn(attr_fn: Callable[[TypeAdapter], T]) -> Callable[[], T | None]:
|
||||
def handler() -> T | None:
|
||||
if adapter.rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
|
||||
return attr_fn(adapter)
|
||||
return None
|
||||
|
||||
cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment]
|
||||
return handler
|
||||
|
||||
adapter.core_schema = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.core_schema),
|
||||
)
|
||||
adapter.validator = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='validator',
|
||||
attempt_rebuild=attempt_rebuild_validator,
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.validator),
|
||||
)
|
||||
|
||||
def attempt_rebuild_serializer() -> SchemaSerializer | None:
|
||||
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5):
|
||||
return cls.__pydantic_serializer__
|
||||
else:
|
||||
return None
|
||||
|
||||
cls.__pydantic_serializer__ = MockValSer( # type: ignore[assignment]
|
||||
adapter.serializer = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='serializer',
|
||||
attempt_rebuild=attempt_rebuild_serializer,
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.serializer),
|
||||
)
|
||||
|
||||
|
||||
def set_dataclass_mock_validator(cls: type[PydanticDataclass], cls_name: str, undefined_name: str) -> None:
|
||||
def set_model_mocks(cls: type[BaseModel], undefined_name: str = 'all referenced types') -> None:
|
||||
"""Set `__pydantic_core_schema__`, `__pydantic_validator__` and `__pydantic_serializer__` to mock core types on a model.
|
||||
|
||||
Args:
|
||||
cls: The model class to set the mocks on
|
||||
undefined_name: Name of the undefined thing, used in error messages
|
||||
"""
|
||||
undefined_type_error_message = (
|
||||
f'`{cls_name}` is not fully defined; you should define {undefined_name},'
|
||||
f' then call `pydantic.dataclasses.rebuild_dataclass({cls_name})`.'
|
||||
f'`{cls.__name__}` is not fully defined; you should define {undefined_name},'
|
||||
f' then call `{cls.__name__}.model_rebuild()`.'
|
||||
)
|
||||
|
||||
def attempt_rebuild() -> SchemaValidator | None:
|
||||
from ..dataclasses import rebuild_dataclass
|
||||
|
||||
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5):
|
||||
return cls.__pydantic_validator__
|
||||
else:
|
||||
def attempt_rebuild_fn(attr_fn: Callable[[type[BaseModel]], T]) -> Callable[[], T | None]:
|
||||
def handler() -> T | None:
|
||||
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
|
||||
return attr_fn(cls)
|
||||
return None
|
||||
|
||||
cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment]
|
||||
return handler
|
||||
|
||||
cls.__pydantic_core_schema__ = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
|
||||
)
|
||||
cls.__pydantic_validator__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='validator',
|
||||
attempt_rebuild=attempt_rebuild,
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
|
||||
)
|
||||
cls.__pydantic_serializer__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='serializer',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
|
||||
)
|
||||
|
||||
|
||||
def set_dataclass_mocks(cls: type[PydanticDataclass], undefined_name: str = 'all referenced types') -> None:
|
||||
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a dataclass.
|
||||
|
||||
Args:
|
||||
cls: The model class to set the mocks on
|
||||
undefined_name: Name of the undefined thing, used in error messages
|
||||
"""
|
||||
from ..dataclasses import rebuild_dataclass
|
||||
|
||||
undefined_type_error_message = (
|
||||
f'`{cls.__name__}` is not fully defined; you should define {undefined_name},'
|
||||
f' then call `pydantic.dataclasses.rebuild_dataclass({cls.__name__})`.'
|
||||
)
|
||||
|
||||
def attempt_rebuild_fn(attr_fn: Callable[[type[PydanticDataclass]], T]) -> Callable[[], T | None]:
|
||||
def handler() -> T | None:
|
||||
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
|
||||
return attr_fn(cls)
|
||||
return None
|
||||
|
||||
return handler
|
||||
|
||||
cls.__pydantic_core_schema__ = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
|
||||
)
|
||||
cls.__pydantic_validator__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='validator',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
|
||||
)
|
||||
cls.__pydantic_serializer__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='serializer',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
|
||||
)
|
||||
|
||||
@@ -1,58 +1,54 @@
|
||||
"""Private logic for creating models."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import builtins
|
||||
import operator
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
import weakref
|
||||
from abc import ABCMeta
|
||||
from functools import partial
|
||||
from functools import cache, partial, wraps
|
||||
from types import FunctionType
|
||||
from typing import Any, Callable, Generic, Mapping
|
||||
from typing import Any, Callable, Generic, Literal, NoReturn, cast
|
||||
|
||||
from pydantic_core import PydanticUndefined, SchemaSerializer
|
||||
from typing_extensions import dataclass_transform, deprecated
|
||||
from typing_extensions import TypeAliasType, dataclass_transform, deprecated, get_args, get_origin
|
||||
from typing_inspection import typing_objects
|
||||
|
||||
from ..errors import PydanticUndefinedAnnotation, PydanticUserError
|
||||
from ..fields import Field, FieldInfo, ModelPrivateAttr, PrivateAttr
|
||||
from ..plugin._schema_validator import create_schema_validator
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
|
||||
from ._config import ConfigWrapper
|
||||
from ._core_utils import collect_invalid_schemas, simplify_schema_references, validate_core_schema
|
||||
from ._decorators import (
|
||||
ComputedFieldInfo,
|
||||
DecoratorInfos,
|
||||
PydanticDescriptorProxy,
|
||||
get_attribute_from_bases,
|
||||
)
|
||||
from ._discriminated_union import apply_discriminators
|
||||
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases, unwrap_wrapped_function
|
||||
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
|
||||
from ._generate_schema import GenerateSchema
|
||||
from ._generate_schema import GenerateSchema, InvalidSchemaError
|
||||
from ._generics import PydanticGenericMetadata, get_model_typevars_map
|
||||
from ._mock_val_ser import MockValSer, set_model_mocks
|
||||
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
|
||||
from ._typing_extra import get_cls_types_namespace, is_classvar, parent_frame_namespace
|
||||
from ._utils import ClassAttribute, is_valid_identifier
|
||||
from ._validate_call import ValidateCallWrapper
|
||||
from ._import_utils import import_cached_base_model, import_cached_field_info
|
||||
from ._mock_val_ser import set_model_mocks
|
||||
from ._namespace_utils import NsResolver
|
||||
from ._signature import generate_pydantic_signature
|
||||
from ._typing_extra import (
|
||||
_make_forward_ref,
|
||||
eval_type_backport,
|
||||
is_classvar_annotation,
|
||||
parent_frame_namespace,
|
||||
)
|
||||
from ._utils import LazyClassAttribute, SafeGetItemProxy
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from inspect import Signature
|
||||
|
||||
from ..fields import Field as PydanticModelField
|
||||
from ..fields import FieldInfo, ModelPrivateAttr
|
||||
from ..fields import PrivateAttr as PydanticModelPrivateAttr
|
||||
from ..main import BaseModel
|
||||
else:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
# and https://youtrack.jetbrains.com/issue/PY-51428
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
PydanticModelField = object()
|
||||
PydanticModelPrivateAttr = object()
|
||||
|
||||
|
||||
IGNORED_TYPES: tuple[Any, ...] = (
|
||||
FunctionType,
|
||||
property,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
PydanticDescriptorProxy,
|
||||
ComputedFieldInfo,
|
||||
ValidateCallWrapper,
|
||||
)
|
||||
object_setattr = object.__setattr__
|
||||
|
||||
|
||||
@@ -69,7 +65,17 @@ class _ModelNamespaceDict(dict):
|
||||
return super().__setitem__(k, v)
|
||||
|
||||
|
||||
@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
|
||||
def NoInitField(
|
||||
*,
|
||||
init: Literal[False] = False,
|
||||
) -> Any:
|
||||
"""Only for typing purposes. Used as default value of `__pydantic_fields_set__`,
|
||||
`__pydantic_extra__`, `__pydantic_private__`, so they could be ignored when
|
||||
synthesizing the `__init__` signature.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr, NoInitField))
|
||||
class ModelMetaclass(ABCMeta):
|
||||
def __new__(
|
||||
mcs,
|
||||
@@ -78,6 +84,7 @@ class ModelMetaclass(ABCMeta):
|
||||
namespace: dict[str, Any],
|
||||
__pydantic_generic_metadata__: PydanticGenericMetadata | None = None,
|
||||
__pydantic_reset_parent_namespace__: bool = True,
|
||||
_create_model_module: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> type:
|
||||
"""Metaclass for creating Pydantic models.
|
||||
@@ -88,6 +95,7 @@ class ModelMetaclass(ABCMeta):
|
||||
namespace: The attribute dictionary of the class to be created.
|
||||
__pydantic_generic_metadata__: Metadata for generic models.
|
||||
__pydantic_reset_parent_namespace__: Reset parent namespace.
|
||||
_create_model_module: The module of the class to be created, if created by `create_model`.
|
||||
**kwargs: Catch-all for any other keyword arguments.
|
||||
|
||||
Returns:
|
||||
@@ -104,17 +112,18 @@ class ModelMetaclass(ABCMeta):
|
||||
private_attributes = inspect_namespace(
|
||||
namespace, config_wrapper.ignored_types, class_vars, base_field_names
|
||||
)
|
||||
if private_attributes:
|
||||
if private_attributes or base_private_attributes:
|
||||
original_model_post_init = get_model_post_init(namespace, bases)
|
||||
if original_model_post_init is not None:
|
||||
# if there are private_attributes and a model_post_init function, we handle both
|
||||
|
||||
def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
|
||||
@wraps(original_model_post_init)
|
||||
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
|
||||
"""We need to both initialize private attributes and call the user-defined model_post_init
|
||||
method.
|
||||
"""
|
||||
init_private_attributes(self, __context)
|
||||
original_model_post_init(self, __context)
|
||||
init_private_attributes(self, context)
|
||||
original_model_post_init(self, context)
|
||||
|
||||
namespace['model_post_init'] = wrapped_model_post_init
|
||||
else:
|
||||
@@ -123,15 +132,25 @@ class ModelMetaclass(ABCMeta):
|
||||
namespace['__class_vars__'] = class_vars
|
||||
namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}
|
||||
|
||||
if config_wrapper.frozen:
|
||||
set_default_hash_func(namespace, bases)
|
||||
cls = cast('type[BaseModel]', super().__new__(mcs, cls_name, bases, namespace, **kwargs))
|
||||
BaseModel_ = import_cached_base_model()
|
||||
|
||||
cls: type[BaseModel] = super().__new__(mcs, cls_name, bases, namespace, **kwargs) # type: ignore
|
||||
|
||||
from ..main import BaseModel
|
||||
mro = cls.__mro__
|
||||
if Generic in mro and mro.index(Generic) < mro.index(BaseModel_):
|
||||
warnings.warn(
|
||||
GenericBeforeBaseModelWarning(
|
||||
'Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) '
|
||||
'for pydantic generics to work properly.'
|
||||
),
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
cls.__pydantic_custom_init__ = not getattr(cls.__init__, '__pydantic_base_init__', False)
|
||||
cls.__pydantic_post_init__ = None if cls.model_post_init is BaseModel.model_post_init else 'model_post_init'
|
||||
cls.__pydantic_post_init__ = (
|
||||
None if cls.model_post_init is BaseModel_.model_post_init else 'model_post_init'
|
||||
)
|
||||
|
||||
cls.__pydantic_setattr_handlers__ = {}
|
||||
|
||||
cls.__pydantic_decorators__ = DecoratorInfos.build(cls)
|
||||
|
||||
@@ -142,22 +161,40 @@ class ModelMetaclass(ABCMeta):
|
||||
parent_parameters = getattr(cls, '__pydantic_generic_metadata__', {}).get('parameters', ())
|
||||
parameters = getattr(cls, '__parameters__', None) or parent_parameters
|
||||
if parameters and parent_parameters and not all(x in parameters for x in parent_parameters):
|
||||
combined_parameters = parent_parameters + tuple(x for x in parameters if x not in parent_parameters)
|
||||
parameters_str = ', '.join([str(x) for x in combined_parameters])
|
||||
generic_type_label = f'typing.Generic[{parameters_str}]'
|
||||
error_message = (
|
||||
f'All parameters must be present on typing.Generic;'
|
||||
f' you should inherit from {generic_type_label}.'
|
||||
)
|
||||
if Generic not in bases: # pragma: no cover
|
||||
# We raise an error here not because it is desirable, but because some cases are mishandled.
|
||||
# It would be nice to remove this error and still have things behave as expected, it's just
|
||||
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
|
||||
# and not returning a typing._GenericAlias from it.
|
||||
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
|
||||
error_message += (
|
||||
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
|
||||
from ..root_model import RootModelRootType
|
||||
|
||||
missing_parameters = tuple(x for x in parameters if x not in parent_parameters)
|
||||
if RootModelRootType in parent_parameters and RootModelRootType not in parameters:
|
||||
# This is a special case where the user has subclassed `RootModel`, but has not parametrized
|
||||
# RootModel with the generic type identifiers being used. Ex:
|
||||
# class MyModel(RootModel, Generic[T]):
|
||||
# root: T
|
||||
# Should instead just be:
|
||||
# class MyModel(RootModel[T]):
|
||||
# root: T
|
||||
parameters_str = ', '.join([x.__name__ for x in missing_parameters])
|
||||
error_message = (
|
||||
f'{cls.__name__} is a subclass of `RootModel`, but does not include the generic type identifier(s) '
|
||||
f'{parameters_str} in its parameters. '
|
||||
f'You should parametrize RootModel directly, e.g., `class {cls.__name__}(RootModel[{parameters_str}]): ...`.'
|
||||
)
|
||||
else:
|
||||
combined_parameters = parent_parameters + missing_parameters
|
||||
parameters_str = ', '.join([str(x) for x in combined_parameters])
|
||||
generic_type_label = f'typing.Generic[{parameters_str}]'
|
||||
error_message = (
|
||||
f'All parameters must be present on typing.Generic;'
|
||||
f' you should inherit from {generic_type_label}.'
|
||||
)
|
||||
if Generic not in bases: # pragma: no cover
|
||||
# We raise an error here not because it is desirable, but because some cases are mishandled.
|
||||
# It would be nice to remove this error and still have things behave as expected, it's just
|
||||
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
|
||||
# and not returning a typing._GenericAlias from it.
|
||||
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
|
||||
error_message += (
|
||||
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
|
||||
)
|
||||
raise TypeError(error_message)
|
||||
|
||||
cls.__pydantic_generic_metadata__ = {
|
||||
@@ -175,29 +212,55 @@ class ModelMetaclass(ABCMeta):
|
||||
|
||||
if __pydantic_reset_parent_namespace__:
|
||||
cls.__pydantic_parent_namespace__ = build_lenient_weakvaluedict(parent_frame_namespace())
|
||||
parent_namespace = getattr(cls, '__pydantic_parent_namespace__', None)
|
||||
parent_namespace: dict[str, Any] | None = getattr(cls, '__pydantic_parent_namespace__', None)
|
||||
if isinstance(parent_namespace, dict):
|
||||
parent_namespace = unpack_lenient_weakvaluedict(parent_namespace)
|
||||
|
||||
types_namespace = get_cls_types_namespace(cls, parent_namespace)
|
||||
set_model_fields(cls, bases, config_wrapper, types_namespace)
|
||||
complete_model_class(
|
||||
cls,
|
||||
cls_name,
|
||||
config_wrapper,
|
||||
raise_errors=False,
|
||||
types_namespace=types_namespace,
|
||||
)
|
||||
ns_resolver = NsResolver(parent_namespace=parent_namespace)
|
||||
|
||||
set_model_fields(cls, config_wrapper=config_wrapper, ns_resolver=ns_resolver)
|
||||
|
||||
# This is also set in `complete_model_class()`, after schema gen because they are recreated.
|
||||
# We set them here as well for backwards compatibility:
|
||||
cls.__pydantic_computed_fields__ = {
|
||||
k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()
|
||||
}
|
||||
|
||||
if config_wrapper.defer_build:
|
||||
# TODO we can also stop there if `__pydantic_fields_complete__` is False.
|
||||
# However, `set_model_fields()` is currently lenient and we don't have access to the `NameError`.
|
||||
# (which is useful as we can provide the name in the error message: `set_model_mock(cls, e.name)`)
|
||||
set_model_mocks(cls)
|
||||
else:
|
||||
# Any operation that requires accessing the field infos instances should be put inside
|
||||
# `complete_model_class()`:
|
||||
complete_model_class(
|
||||
cls,
|
||||
config_wrapper,
|
||||
raise_errors=False,
|
||||
ns_resolver=ns_resolver,
|
||||
create_model_module=_create_model_module,
|
||||
)
|
||||
|
||||
if config_wrapper.frozen and '__hash__' not in namespace:
|
||||
set_default_hash_func(cls, bases)
|
||||
|
||||
# using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__
|
||||
# I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is
|
||||
# only hit for _proper_ subclasses of BaseModel
|
||||
super(cls, cls).__pydantic_init_subclass__(**kwargs) # type: ignore[misc]
|
||||
return cls
|
||||
else:
|
||||
# this is the BaseModel class itself being created, no logic required
|
||||
# These are instance variables, but have been assigned to `NoInitField` to trick the type checker.
|
||||
for instance_slot in '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__':
|
||||
namespace.pop(
|
||||
instance_slot,
|
||||
None, # In case the metaclass is used with a class other than `BaseModel`.
|
||||
)
|
||||
namespace.get('__annotations__', {}).clear()
|
||||
return super().__new__(mcs, cls_name, bases, namespace, **kwargs)
|
||||
|
||||
if not typing.TYPE_CHECKING:
|
||||
if not typing.TYPE_CHECKING: # pragma: no branch
|
||||
# We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
@@ -205,30 +268,29 @@ class ModelMetaclass(ABCMeta):
|
||||
private_attributes = self.__dict__.get('__private_attributes__')
|
||||
if private_attributes and item in private_attributes:
|
||||
return private_attributes[item]
|
||||
if item == '__pydantic_core_schema__':
|
||||
# This means the class didn't get a schema generated for it, likely because there was an undefined reference
|
||||
maybe_mock_validator = getattr(self, '__pydantic_validator__', None)
|
||||
if isinstance(maybe_mock_validator, MockValSer):
|
||||
rebuilt_validator = maybe_mock_validator.rebuild()
|
||||
if rebuilt_validator is not None:
|
||||
# In this case, a validator was built, and so `__pydantic_core_schema__` should now be set
|
||||
return getattr(self, '__pydantic_core_schema__')
|
||||
raise AttributeError(item)
|
||||
|
||||
@classmethod
|
||||
def __prepare__(cls, *args: Any, **kwargs: Any) -> Mapping[str, object]:
|
||||
def __prepare__(cls, *args: Any, **kwargs: Any) -> dict[str, object]:
|
||||
return _ModelNamespaceDict()
|
||||
|
||||
def __instancecheck__(self, instance: Any) -> bool:
|
||||
"""Avoid calling ABC _abc_instancecheck unless we're pretty sure.
|
||||
|
||||
See #3829 and python/cpython#92810
|
||||
"""
|
||||
return hasattr(instance, '__pydantic_decorators__') and super().__instancecheck__(instance)
|
||||
|
||||
def __subclasscheck__(self, subclass: type[Any]) -> bool:
|
||||
"""Avoid calling ABC _abc_subclasscheck unless we're pretty sure.
|
||||
|
||||
See #3829 and python/cpython#92810
|
||||
"""
|
||||
return hasattr(instance, '__pydantic_validator__') and super().__instancecheck__(instance)
|
||||
return hasattr(subclass, '__pydantic_decorators__') and super().__subclasscheck__(subclass)
|
||||
|
||||
@staticmethod
|
||||
def _collect_bases_data(bases: tuple[type[Any], ...]) -> tuple[set[str], set[str], dict[str, ModelPrivateAttr]]:
|
||||
from ..main import BaseModel
|
||||
BaseModel = import_cached_base_model()
|
||||
|
||||
field_names: set[str] = set()
|
||||
class_vars: set[str] = set()
|
||||
@@ -236,35 +298,57 @@ class ModelMetaclass(ABCMeta):
|
||||
for base in bases:
|
||||
if issubclass(base, BaseModel) and base is not BaseModel:
|
||||
# model_fields might not be defined yet in the case of generics, so we use getattr here:
|
||||
field_names.update(getattr(base, 'model_fields', {}).keys())
|
||||
field_names.update(getattr(base, '__pydantic_fields__', {}).keys())
|
||||
class_vars.update(base.__class_vars__)
|
||||
private_attributes.update(base.__private_attributes__)
|
||||
return field_names, class_vars, private_attributes
|
||||
|
||||
@property
|
||||
@deprecated(
|
||||
'The `__fields__` attribute is deprecated, use `model_fields` instead.', category=PydanticDeprecatedSince20
|
||||
)
|
||||
@deprecated('The `__fields__` attribute is deprecated, use `model_fields` instead.', category=None)
|
||||
def __fields__(self) -> dict[str, FieldInfo]:
|
||||
warnings.warn('The `__fields__` attribute is deprecated, use `model_fields` instead.', DeprecationWarning)
|
||||
return self.model_fields # type: ignore
|
||||
warnings.warn(
|
||||
'The `__fields__` attribute is deprecated, use `model_fields` instead.',
|
||||
PydanticDeprecatedSince20,
|
||||
stacklevel=2,
|
||||
)
|
||||
return getattr(self, '__pydantic_fields__', {})
|
||||
|
||||
@property
def __pydantic_fields_complete__(self) -> bool:
"""Whether the fields were successfully collected (i.e. type hints were successfully resolved).

This is a private attribute, not meant to be used outside Pydantic.
"""
if not hasattr(self, '__pydantic_fields__'):
return False

field_infos = cast('dict[str, FieldInfo]', self.__pydantic_fields__) # pyright: ignore[reportAttributeAccessIssue]

return all(field_info._complete for field_info in field_infos.values())

def __dir__(self) -> list[str]:
|
||||
attributes = list(super().__dir__())
|
||||
if '__fields__' in attributes:
|
||||
attributes.remove('__fields__')
|
||||
return attributes
|
||||
|
||||
|
||||
def init_private_attributes(self: BaseModel, __context: Any) -> None:
|
||||
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
|
||||
"""This function is meant to behave like a BaseModel method to initialise private attributes.
|
||||
|
||||
It takes context as an argument since that's what pydantic-core passes when calling it.
|
||||
|
||||
Args:
|
||||
self: The BaseModel instance.
|
||||
__context: The context.
|
||||
context: The context.
|
||||
"""
|
||||
pydantic_private = {}
|
||||
for name, private_attr in self.__private_attributes__.items():
|
||||
default = private_attr.get_default()
|
||||
if default is not PydanticUndefined:
|
||||
pydantic_private[name] = default
|
||||
object_setattr(self, '__pydantic_private__', pydantic_private)
|
||||
if getattr(self, '__pydantic_private__', None) is None:
|
||||
pydantic_private = {}
|
||||
for name, private_attr in self.__private_attributes__.items():
|
||||
default = private_attr.get_default()
|
||||
if default is not PydanticUndefined:
|
||||
pydantic_private[name] = default
|
||||
object_setattr(self, '__pydantic_private__', pydantic_private)
|
||||
|
||||
|
||||
def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> Callable[..., Any] | None:
|
||||
@@ -272,7 +356,7 @@ def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...])
|
||||
if 'model_post_init' in namespace:
|
||||
return namespace['model_post_init']
|
||||
|
||||
from ..main import BaseModel
|
||||
BaseModel = import_cached_base_model()
|
||||
|
||||
model_post_init = get_attribute_from_bases(bases, 'model_post_init')
|
||||
if model_post_init is not BaseModel.model_post_init:
|
||||
@@ -305,7 +389,11 @@ def inspect_namespace( # noqa C901
|
||||
- If a field does not have a type annotation.
|
||||
- If a field on base class was overridden by a non-annotated attribute.
|
||||
"""
|
||||
all_ignored_types = ignored_types + IGNORED_TYPES
|
||||
from ..fields import ModelPrivateAttr, PrivateAttr
|
||||
|
||||
FieldInfo = import_cached_field_info()
|
||||
|
||||
all_ignored_types = ignored_types + default_ignored_types()
|
||||
|
||||
private_attributes: dict[str, ModelPrivateAttr] = {}
|
||||
raw_annotations = namespace.get('__annotations__', {})
|
||||
@@ -315,11 +403,12 @@ def inspect_namespace( # noqa C901
|
||||
|
||||
ignored_names: set[str] = set()
|
||||
for var_name, value in list(namespace.items()):
|
||||
if var_name == 'model_config':
|
||||
if var_name == 'model_config' or var_name == '__pydantic_extra__':
|
||||
continue
|
||||
elif (
|
||||
isinstance(value, type)
|
||||
and value.__module__ == namespace['__module__']
|
||||
and '__qualname__' in namespace
|
||||
and value.__qualname__.startswith(namespace['__qualname__'])
|
||||
):
|
||||
# `value` is a nested type defined in this namespace; don't error
|
||||
@@ -350,8 +439,8 @@ def inspect_namespace( # noqa C901
|
||||
elif var_name.startswith('__'):
|
||||
continue
|
||||
elif is_valid_privateattr_name(var_name):
|
||||
if var_name not in raw_annotations or not is_classvar(raw_annotations[var_name]):
|
||||
private_attributes[var_name] = PrivateAttr(default=value)
|
||||
if var_name not in raw_annotations or not is_classvar_annotation(raw_annotations[var_name]):
|
||||
private_attributes[var_name] = cast(ModelPrivateAttr, PrivateAttr(default=value))
|
||||
del namespace[var_name]
|
||||
elif var_name in base_class_vars:
|
||||
continue
|
||||
@@ -368,8 +457,8 @@ def inspect_namespace( # noqa C901
|
||||
)
|
||||
else:
|
||||
raise PydanticUserError(
|
||||
f"A non-annotated attribute was detected: `{var_name} = {value!r}`. All model fields require a "
|
||||
f"type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this "
|
||||
f'A non-annotated attribute was detected: `{var_name} = {value!r}`. All model fields require a '
|
||||
f'type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this '
|
||||
f"error by annotating it as a `ClassVar` or updating `model_config['ignored_types']`.",
|
||||
code='model-field-missing-annotation',
|
||||
)
|
||||
@@ -379,45 +468,82 @@ def inspect_namespace( # noqa C901
|
||||
is_valid_privateattr_name(ann_name)
|
||||
and ann_name not in private_attributes
|
||||
and ann_name not in ignored_names
|
||||
and not is_classvar(ann_type)
|
||||
# This condition can be a false negative when `ann_type` is stringified,
|
||||
# but it is handled in most cases in `set_model_fields`:
|
||||
and not is_classvar_annotation(ann_type)
|
||||
and ann_type not in all_ignored_types
|
||||
and getattr(ann_type, '__module__', None) != 'functools'
|
||||
):
|
||||
if isinstance(ann_type, str):
|
||||
# Walking up the frames to get the module namespace where the model is defined
|
||||
# (as the model class wasn't created yet, we unfortunately can't use `cls.__module__`):
|
||||
frame = sys._getframe(2)
|
||||
if frame is not None:
|
||||
try:
|
||||
ann_type = eval_type_backport(
|
||||
_make_forward_ref(ann_type, is_argument=False, is_class=True),
|
||||
globalns=frame.f_globals,
|
||||
localns=frame.f_locals,
|
||||
)
|
||||
except (NameError, TypeError):
|
||||
pass
|
||||
|
||||
if typing_objects.is_annotated(get_origin(ann_type)):
|
||||
_, *metadata = get_args(ann_type)
|
||||
private_attr = next((v for v in metadata if isinstance(v, ModelPrivateAttr)), None)
|
||||
if private_attr is not None:
|
||||
private_attributes[ann_name] = private_attr
|
||||
continue
|
||||
private_attributes[ann_name] = PrivateAttr()
|
||||
|
||||
return private_attributes
|
||||
|
||||
|
||||
def set_default_hash_func(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> None:
|
||||
if '__hash__' in namespace:
|
||||
return
|
||||
|
||||
def set_default_hash_func(cls: type[BaseModel], bases: tuple[type[Any], ...]) -> None:
|
||||
base_hash_func = get_attribute_from_bases(bases, '__hash__')
|
||||
if base_hash_func in {None, object.__hash__}:
|
||||
# If `__hash__` is None _or_ `object.__hash__`, we generate a hash function.
|
||||
# It will be `None` if not overridden from BaseModel, but may be `object.__hash__` if there is another
|
||||
new_hash_func = make_hash_func(cls)
|
||||
if base_hash_func in {None, object.__hash__} or getattr(base_hash_func, '__code__', None) == new_hash_func.__code__:
|
||||
# If `__hash__` is some default, we generate a hash function.
|
||||
# It will be `None` if not overridden from BaseModel.
|
||||
# It may be `object.__hash__` if there is another
|
||||
# parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`).
|
||||
def hash_func(self: Any) -> int:
|
||||
return hash(self.__class__) + hash(tuple(self.__dict__.values()))
|
||||
# It may be a value set by `set_default_hash_func` if `cls` is a subclass of another frozen model.
|
||||
# In the last case we still need a new hash function to account for new `model_fields`.
|
||||
cls.__hash__ = new_hash_func
|
||||
|
||||
namespace['__hash__'] = hash_func
|
||||
|
||||
def make_hash_func(cls: type[BaseModel]) -> Any:
|
||||
getter = operator.itemgetter(*cls.__pydantic_fields__.keys()) if cls.__pydantic_fields__ else lambda _: 0
|
||||
|
||||
def hash_func(self: Any) -> int:
|
||||
try:
|
||||
return hash(getter(self.__dict__))
|
||||
except KeyError:
|
||||
# In rare cases (such as when using the deprecated copy method), the __dict__ may not contain
|
||||
# all model fields, which is how we can get here.
|
||||
# getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys,
|
||||
# and wrapping it in a `try` doesn't slow things down much in the common case.
|
||||
return hash(getter(SafeGetItemProxy(self.__dict__)))
|
||||
|
||||
return hash_func
|
||||
|
||||
|
||||
def set_model_fields(
|
||||
cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any]
|
||||
cls: type[BaseModel],
|
||||
config_wrapper: ConfigWrapper,
|
||||
ns_resolver: NsResolver | None,
|
||||
) -> None:
|
||||
"""Collect and set `cls.model_fields` and `cls.__class_vars__`.
|
||||
"""Collect and set `cls.__pydantic_fields__` and `cls.__class_vars__`.
|
||||
|
||||
Args:
|
||||
cls: BaseModel or dataclass.
|
||||
bases: Parents of the class, generally `cls.__bases__`.
|
||||
config_wrapper: The config wrapper instance.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
ns_resolver: Namespace resolver to use when getting model annotations.
|
||||
"""
|
||||
typevars_map = get_model_typevars_map(cls)
|
||||
fields, class_vars = collect_model_fields(cls, bases, config_wrapper, types_namespace, typevars_map=typevars_map)
|
||||
fields, class_vars = collect_model_fields(cls, config_wrapper, ns_resolver, typevars_map=typevars_map)
|
||||
|
||||
cls.model_fields = fields
|
||||
cls.__pydantic_fields__ = fields
|
||||
cls.__class_vars__.update(class_vars)
|
||||
|
||||
for k in class_vars:
|
||||
@@ -435,11 +561,11 @@ def set_model_fields(
|
||||
|
||||
def complete_model_class(
|
||||
cls: type[BaseModel],
|
||||
cls_name: str,
|
||||
config_wrapper: ConfigWrapper,
|
||||
*,
|
||||
raise_errors: bool = True,
|
||||
types_namespace: dict[str, Any] | None,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
create_model_module: str | None = None,
|
||||
) -> bool:
|
||||
"""Finish building a model class.
|
||||
|
||||
@@ -448,10 +574,10 @@ def complete_model_class(
|
||||
|
||||
Args:
|
||||
cls: BaseModel or dataclass.
|
||||
cls_name: The model or dataclass name.
|
||||
config_wrapper: The config wrapper instance.
|
||||
raise_errors: Whether to raise errors.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
ns_resolver: The namespace resolver instance to use during schema building.
|
||||
create_model_module: The module of the class to be created, if created by `create_model`.
|
||||
|
||||
Returns:
|
||||
`True` if the model is successfully completed, else `False`.
|
||||
@@ -463,132 +589,151 @@ def complete_model_class(
|
||||
typevars_map = get_model_typevars_map(cls)
|
||||
gen_schema = GenerateSchema(
|
||||
config_wrapper,
|
||||
types_namespace,
|
||||
ns_resolver,
|
||||
typevars_map,
|
||||
)
|
||||
|
||||
handler = CallbackGetCoreSchemaHandler(
|
||||
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
|
||||
gen_schema,
|
||||
ref_mode='unpack',
|
||||
)
|
||||
|
||||
if config_wrapper.defer_build:
|
||||
set_model_mocks(cls, cls_name)
|
||||
return False
|
||||
|
||||
try:
|
||||
schema = cls.__get_pydantic_core_schema__(cls, handler)
|
||||
schema = gen_schema.generate_schema(cls)
|
||||
except PydanticUndefinedAnnotation as e:
|
||||
if raise_errors:
|
||||
raise
|
||||
set_model_mocks(cls, cls_name, f'`{e.name}`')
|
||||
set_model_mocks(cls, f'`{e.name}`')
|
||||
return False
|
||||
|
||||
core_config = config_wrapper.core_config(cls)
|
||||
core_config = config_wrapper.core_config(title=cls.__name__)
|
||||
|
||||
schema = gen_schema.collect_definitions(schema)
|
||||
|
||||
schema = apply_discriminators(simplify_schema_references(schema))
|
||||
if collect_invalid_schemas(schema):
|
||||
set_model_mocks(cls, cls_name)
|
||||
try:
|
||||
schema = gen_schema.clean_schema(schema)
|
||||
except InvalidSchemaError:
|
||||
set_model_mocks(cls)
|
||||
return False
|
||||
|
||||
# debug(schema)
|
||||
cls.__pydantic_core_schema__ = schema = validate_core_schema(schema)
|
||||
cls.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)
|
||||
# This needs to happen *after* model schema generation, as the return type
|
||||
# of the properties are evaluated and the `ComputedFieldInfo` are recreated:
|
||||
cls.__pydantic_computed_fields__ = {k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()}
|
||||
|
||||
set_deprecated_descriptors(cls)
|
||||
|
||||
cls.__pydantic_core_schema__ = schema
|
||||
|
||||
cls.__pydantic_validator__ = create_schema_validator(
|
||||
schema,
|
||||
cls,
|
||||
create_model_module or cls.__module__,
|
||||
cls.__qualname__,
|
||||
'create_model' if create_model_module else 'BaseModel',
|
||||
core_config,
|
||||
config_wrapper.plugin_settings,
|
||||
)
|
||||
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
|
||||
cls.__pydantic_complete__ = True
|
||||
|
||||
# set __signature__ attr only for model class, but not for its instances
|
||||
cls.__signature__ = ClassAttribute(
|
||||
'__signature__', generate_model_signature(cls.__init__, cls.model_fields, config_wrapper)
|
||||
# (because instances can define `__call__`, and `inspect.signature` shouldn't
|
||||
# use the `__signature__` attribute and instead generate from `__call__`).
|
||||
cls.__signature__ = LazyClassAttribute(
|
||||
'__signature__',
|
||||
partial(
|
||||
generate_pydantic_signature,
|
||||
init=cls.__init__,
|
||||
fields=cls.__pydantic_fields__,
|
||||
validate_by_name=config_wrapper.validate_by_name,
|
||||
extra=config_wrapper.extra,
|
||||
),
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def generate_model_signature(
|
||||
init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
|
||||
) -> Signature:
|
||||
"""Generate signature for model based on its fields.
|
||||
def set_deprecated_descriptors(cls: type[BaseModel]) -> None:
|
||||
"""Set data descriptors on the class for deprecated fields."""
|
||||
for field, field_info in cls.__pydantic_fields__.items():
|
||||
if (msg := field_info.deprecation_message) is not None:
|
||||
desc = _DeprecatedFieldDescriptor(msg)
|
||||
desc.__set_name__(cls, field)
|
||||
setattr(cls, field, desc)
|
||||
|
||||
Args:
|
||||
init: The class init.
|
||||
fields: The model fields.
|
||||
config_wrapper: The config wrapper instance.
|
||||
for field, computed_field_info in cls.__pydantic_computed_fields__.items():
|
||||
if (
|
||||
(msg := computed_field_info.deprecation_message) is not None
|
||||
# Avoid having two warnings emitted:
|
||||
and not hasattr(unwrap_wrapped_function(computed_field_info.wrapped_property), '__deprecated__')
|
||||
):
|
||||
desc = _DeprecatedFieldDescriptor(msg, computed_field_info.wrapped_property)
|
||||
desc.__set_name__(cls, field)
|
||||
setattr(cls, field, desc)
|
||||
|
||||
Returns:
|
||||
The model signature.
|
||||
|
||||
class _DeprecatedFieldDescriptor:
|
||||
"""Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.
|
||||
|
||||
Attributes:
|
||||
msg: The deprecation message to be emitted.
|
||||
wrapped_property: The property instance if the deprecated field is a computed field, or `None`.
|
||||
field_name: The name of the field being deprecated.
|
||||
"""
|
||||
from inspect import Parameter, Signature, signature
|
||||
from itertools import islice
|
||||
|
||||
present_params = signature(init).parameters.values()
|
||||
merged_params: dict[str, Parameter] = {}
|
||||
var_kw = None
|
||||
use_var_kw = False
|
||||
field_name: str
|
||||
|
||||
for param in islice(present_params, 1, None): # skip self arg
|
||||
# inspect does "clever" things to show annotations as strings because we have
|
||||
# `from __future__ import annotations` in main, we don't want that
|
||||
if param.annotation == 'Any':
|
||||
param = param.replace(annotation=Any)
|
||||
if param.kind is param.VAR_KEYWORD:
|
||||
var_kw = param
|
||||
continue
|
||||
merged_params[param.name] = param
|
||||
def __init__(self, msg: str, wrapped_property: property | None = None) -> None:
|
||||
self.msg = msg
|
||||
self.wrapped_property = wrapped_property
|
||||
|
||||
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
|
||||
allow_names = config_wrapper.populate_by_name
|
||||
for field_name, field in fields.items():
|
||||
# when alias is a str it should be used for signature generation
|
||||
if isinstance(field.alias, str):
|
||||
param_name = field.alias
|
||||
else:
|
||||
param_name = field_name
|
||||
def __set_name__(self, cls: type[BaseModel], name: str) -> None:
|
||||
self.field_name = name
|
||||
|
||||
if field_name in merged_params or param_name in merged_params:
|
||||
continue
|
||||
def __get__(self, obj: BaseModel | None, obj_type: type[BaseModel] | None = None) -> Any:
|
||||
if obj is None:
|
||||
if self.wrapped_property is not None:
|
||||
return self.wrapped_property.__get__(None, obj_type)
|
||||
raise AttributeError(self.field_name)
|
||||
|
||||
if not is_valid_identifier(param_name):
|
||||
if allow_names and is_valid_identifier(field_name):
|
||||
param_name = field_name
|
||||
else:
|
||||
use_var_kw = True
|
||||
continue
|
||||
warnings.warn(self.msg, builtins.DeprecationWarning, stacklevel=2)
|
||||
|
||||
kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
|
||||
merged_params[param_name] = Parameter(
|
||||
param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs
|
||||
)
|
||||
if self.wrapped_property is not None:
|
||||
return self.wrapped_property.__get__(obj, obj_type)
|
||||
return obj.__dict__[self.field_name]
|
||||
|
||||
if config_wrapper.extra == 'allow':
|
||||
use_var_kw = True
|
||||
# Defined to make it a data descriptor and take precedence over the instance's dictionary.
|
||||
# Note that it will not be called when setting a value on a model instance
|
||||
# as `BaseModel.__setattr__` is defined and takes priority.
|
||||
def __set__(self, obj: Any, value: Any) -> NoReturn:
|
||||
raise AttributeError(self.field_name)
|
||||
|
||||
if var_kw and use_var_kw:
|
||||
# Make sure the parameter for extra kwargs
|
||||
# does not have the same name as a field
|
||||
default_model_signature = [
|
||||
('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
|
||||
('data', Parameter.VAR_KEYWORD),
|
||||
]
|
||||
if [(p.name, p.kind) for p in present_params] == default_model_signature:
|
||||
# if this is the standard model signature, use extra_data as the extra args name
|
||||
var_kw_name = 'extra_data'
|
||||
|
||||
class _PydanticWeakRef:
|
||||
"""Wrapper for `weakref.ref` that enables `pickle` serialization.
|
||||
|
||||
Cloudpickle fails to serialize `weakref.ref` objects due to an arcane error related
|
||||
to abstract base classes (`abc.ABC`). This class works around the issue by wrapping
|
||||
`weakref.ref` instead of subclassing it.
|
||||
|
||||
See https://github.com/pydantic/pydantic/issues/6763 for context.
|
||||
|
||||
Semantics:
|
||||
- If not pickled, behaves the same as a `weakref.ref`.
|
||||
- If pickled along with the referenced object, the same `weakref.ref` behavior
|
||||
will be maintained between them after unpickling.
|
||||
- If pickled without the referenced object, after unpickling the underlying
|
||||
reference will be cleared (`__call__` will always return `None`).
|
||||
"""
|
||||
|
||||
def __init__(self, obj: Any):
|
||||
if obj is None:
|
||||
# The object will be `None` upon deserialization if the serialized weakref
|
||||
# had lost its underlying object.
|
||||
self._wr = None
|
||||
else:
|
||||
# else start from var_kw
|
||||
var_kw_name = var_kw.name
|
||||
self._wr = weakref.ref(obj)
|
||||
|
||||
# generate a name that's definitely unique
|
||||
while var_kw_name in fields:
|
||||
var_kw_name += '_'
|
||||
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
|
||||
def __call__(self) -> Any:
|
||||
if self._wr is None:
|
||||
return None
|
||||
else:
|
||||
return self._wr()
|
||||
|
||||
return Signature(parameters=list(merged_params.values()), return_annotation=None)
|
||||
|
||||
|
||||
class _PydanticWeakRef(weakref.ReferenceType):
|
||||
pass
|
||||
def __reduce__(self) -> tuple[Callable, tuple[weakref.ReferenceType | None]]:
|
||||
return _PydanticWeakRef, (self(),)
|
||||
|
||||
|
||||
def build_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
@@ -625,3 +770,23 @@ def unpack_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | N
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
|
||||
@cache
|
||||
def default_ignored_types() -> tuple[type[Any], ...]:
|
||||
from ..fields import ComputedFieldInfo
|
||||
|
||||
ignored_types = [
|
||||
FunctionType,
|
||||
property,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
PydanticDescriptorProxy,
|
||||
ComputedFieldInfo,
|
||||
TypeAliasType, # from `typing_extensions`
|
||||
]
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
ignored_types.append(typing.TypeAliasType)
|
||||
|
||||
return tuple(ignored_types)
|
||||
|
||||
@@ -0,0 +1,293 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Generator, Iterator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, NamedTuple, TypeVar
|
||||
|
||||
from typing_extensions import ParamSpec, TypeAlias, TypeAliasType, TypeVarTuple
|
||||
|
||||
GlobalsNamespace: TypeAlias = 'dict[str, Any]'
|
||||
"""A global namespace.
|
||||
|
||||
In most cases, this is a reference to the `__dict__` attribute of a module.
|
||||
This namespace type is expected as the `globals` argument during annotations evaluation.
|
||||
"""
|
||||
|
||||
MappingNamespace: TypeAlias = Mapping[str, Any]
|
||||
"""Any kind of namespace.
|
||||
|
||||
In most cases, this is a local namespace (e.g. the `__dict__` attribute of a class,
|
||||
the [`f_locals`][frame.f_locals] attribute of a frame object, when dealing with types
|
||||
defined inside functions).
|
||||
This namespace type is expected as the `locals` argument during annotations evaluation.
|
||||
"""
|
||||
|
||||
_TypeVarLike: TypeAlias = 'TypeVar | ParamSpec | TypeVarTuple'
|
||||
|
||||
|
||||
class NamespacesTuple(NamedTuple):
|
||||
"""A tuple of globals and locals to be used during annotations evaluation.
|
||||
|
||||
This data structure is defined as a named tuple so that it can easily be unpacked:
|
||||
|
||||
```python {lint="skip" test="skip"}
|
||||
def eval_type(typ: type[Any], ns: NamespacesTuple) -> Any:
|
||||
return eval(typ, *ns)
|
||||
```
|
||||
"""
|
||||
|
||||
globals: GlobalsNamespace
|
||||
"""The namespace to be used as the `globals` argument during annotations evaluation."""
|
||||
|
||||
locals: MappingNamespace
|
||||
"""The namespace to be used as the `locals` argument during annotations evaluation."""
|
||||
|
||||
|
||||
def get_module_ns_of(obj: Any) -> dict[str, Any]:
|
||||
"""Get the namespace of the module where the object is defined.
|
||||
|
||||
Caution: this function does not return a copy of the module namespace, so the result
|
||||
should not be mutated. The burden of enforcing this is on the caller.
|
||||
"""
|
||||
module_name = getattr(obj, '__module__', None)
|
||||
if module_name:
|
||||
try:
|
||||
return sys.modules[module_name].__dict__
|
||||
except KeyError:
|
||||
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
|
||||
return {}
|
||||
return {}
|
||||
|
||||
|
||||
# Note that this class is almost identical to `collections.ChainMap`, but we need to enforce
|
||||
# immutable mappings here:
|
||||
class LazyLocalNamespace(Mapping[str, Any]):
|
||||
"""A lazily evaluated mapping, to be used as the `locals` argument during annotations evaluation.
|
||||
|
||||
While the [`eval`][eval] function expects a mapping as the `locals` argument, it only
|
||||
performs `__getitem__` calls. The [`Mapping`][collections.abc.Mapping] abstract base class
|
||||
is fully implemented only for type checking purposes.
|
||||
|
||||
Args:
|
||||
*namespaces: The namespaces to consider, in ascending order of priority.
|
||||
|
||||
Example:
|
||||
```python {lint="skip" test="skip"}
|
||||
ns = LazyLocalNamespace({'a': 1, 'b': 2}, {'a': 3})
|
||||
ns['a']
|
||||
#> 3
|
||||
ns['b']
|
||||
#> 2
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, *namespaces: MappingNamespace) -> None:
|
||||
self._namespaces = namespaces
|
||||
|
||||
@cached_property
|
||||
def data(self) -> dict[str, Any]:
|
||||
return {k: v for ns in self._namespaces for k, v in ns.items()}
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self.data[key]
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
return key in self.data
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self.data)
|
||||
|
||||
|
||||
def ns_for_function(obj: Callable[..., Any], parent_namespace: MappingNamespace | None = None) -> NamespacesTuple:
|
||||
"""Return the global and local namespaces to be used when evaluating annotations for the provided function.
|
||||
|
||||
The global namespace will be the `__dict__` attribute of the module the function was defined in.
|
||||
The local namespace will contain the `__type_params__` introduced by PEP 695.
|
||||
|
||||
Args:
|
||||
obj: The object to use when building namespaces.
|
||||
parent_namespace: Optional namespace to be added with the lowest priority in the local namespace.
|
||||
If the passed function is a method, the `parent_namespace` will be the namespace of the class
|
||||
the method is defined in. Thus, we also fetch type `__type_params__` from there (i.e. the
|
||||
class-scoped type variables).
|
||||
"""
|
||||
locals_list: list[MappingNamespace] = []
|
||||
if parent_namespace is not None:
|
||||
locals_list.append(parent_namespace)
|
||||
|
||||
# Get the `__type_params__` attribute introduced by PEP 695.
|
||||
# Note that the `typing._eval_type` function expects type params to be
|
||||
# passed as a separate argument. However, internally, `_eval_type` calls
|
||||
# `ForwardRef._evaluate` which will merge type params with the localns,
|
||||
# essentially mimicking what we do here.
|
||||
type_params: tuple[_TypeVarLike, ...] = getattr(obj, '__type_params__', ())
|
||||
if parent_namespace is not None:
|
||||
# We also fetch type params from the parent namespace. If present, it probably
|
||||
# means the function was defined in a class. This is to support the following:
|
||||
# https://github.com/python/cpython/issues/124089.
|
||||
type_params += parent_namespace.get('__type_params__', ())
|
||||
|
||||
locals_list.append({t.__name__: t for t in type_params})
|
||||
|
||||
# What about short-cirtuiting to `obj.__globals__`?
|
||||
globalns = get_module_ns_of(obj)
|
||||
|
||||
return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list))
|
||||
|
||||
|
||||
class NsResolver:
|
||||
"""A class responsible for the namespaces resolving logic for annotations evaluation.
|
||||
|
||||
This class handles the namespace logic when evaluating annotations mainly for class objects.
|
||||
|
||||
It holds a stack of classes that are being inspected during the core schema building,
|
||||
and the `types_namespace` property exposes the globals and locals to be used for
|
||||
type annotation evaluation. Additionally -- if no class is present in the stack -- a
|
||||
fallback globals and locals can be provided using the `namespaces_tuple` argument
|
||||
(this is useful when generating a schema for a simple annotation, e.g. when using
|
||||
`TypeAdapter`).
|
||||
|
||||
The namespace creation logic is unfortunately flawed in some cases, for backwards
|
||||
compatibility reasons and to better support valid edge cases. See the description
|
||||
for the `parent_namespace` argument and the example for more details.
|
||||
|
||||
Args:
|
||||
namespaces_tuple: The default globals and locals to use if no class is present
|
||||
on the stack. This can be useful when using the `GenerateSchema` class
|
||||
with `TypeAdapter`, where the "type" being analyzed is a simple annotation.
|
||||
parent_namespace: An optional parent namespace that will be added to the locals
|
||||
with the lowest priority. For a given class defined in a function, the locals
|
||||
of this function are usually used as the parent namespace:
|
||||
|
||||
```python {lint="skip" test="skip"}
|
||||
from pydantic import BaseModel
|
||||
|
||||
def func() -> None:
|
||||
SomeType = int
|
||||
|
||||
class Model(BaseModel):
|
||||
f: 'SomeType'
|
||||
|
||||
# when collecting fields, a namespace resolver instance will be created
|
||||
# this way:
|
||||
# ns_resolver = NsResolver(parent_namespace={'SomeType': SomeType})
|
||||
```
|
||||
|
||||
For backwards compatibility reasons and to support valid edge cases, this parent
|
||||
namespace will be used for *every* type being pushed to the stack. In the future,
|
||||
we might want to be smarter by only doing so when the type being pushed is defined
|
||||
in the same module as the parent namespace.
|
||||
|
||||
Example:
|
||||
```python {lint="skip" test="skip"}
|
||||
ns_resolver = NsResolver(
|
||||
parent_namespace={'fallback': 1},
|
||||
)
|
||||
|
||||
class Sub:
|
||||
m: 'Model'
|
||||
|
||||
class Model:
|
||||
some_local = 1
|
||||
sub: Sub
|
||||
|
||||
ns_resolver = NsResolver()
|
||||
|
||||
# This is roughly what happens when we build a core schema for `Model`:
|
||||
with ns_resolver.push(Model):
|
||||
ns_resolver.types_namespace
|
||||
#> NamespacesTuple({'Sub': Sub}, {'Model': Model, 'some_local': 1})
|
||||
# First thing to notice here, the model being pushed is added to the locals.
|
||||
# Because `NsResolver` is being used during the model definition, it is not
|
||||
# yet added to the globals. This is useful when resolving self-referencing annotations.
|
||||
|
||||
with ns_resolver.push(Sub):
|
||||
ns_resolver.types_namespace
|
||||
#> NamespacesTuple({'Sub': Sub}, {'Sub': Sub, 'Model': Model})
|
||||
# Second thing to notice: `Sub` is present in both the globals and locals.
|
||||
# This is not an issue, just that as described above, the model being pushed
|
||||
# is added to the locals, but it happens to be present in the globals as well
|
||||
# because it is already defined.
|
||||
# Third thing to notice: `Model` is also added in locals. This is a backwards
|
||||
# compatibility workaround that allows for `Sub` to be able to resolve `'Model'`
|
||||
# correctly (as otherwise models would have to be rebuilt even though this
|
||||
# doesn't look necessary).
|
||||
```
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        namespaces_tuple: NamespacesTuple | None = None,
        parent_namespace: MappingNamespace | None = None,
    ) -> None:
        self._base_ns_tuple = namespaces_tuple or NamespacesTuple({}, {})
        self._parent_ns = parent_namespace
        self._types_stack: list[type[Any] | TypeAliasType] = []

    @cached_property
    def types_namespace(self) -> NamespacesTuple:
        """The current global and local namespaces to be used for annotations evaluation."""
        if not self._types_stack:
            # TODO: should we merge the parent namespace here?
            # This is relevant for TypeAdapter, where there are no types on the stack, and we might
            # need access to the parent_ns. Right now, we sidestep this in `type_adapter.py` by passing
            # locals to both parent_ns and the base_ns_tuple, but this is a bit hacky.
            # we might consider something like:
            # if self._parent_ns is not None:
            #     # Hacky workarounds, see class docstring:
            #     # An optional parent namespace that will be added to the locals with the lowest priority
            #     locals_list: list[MappingNamespace] = [self._parent_ns, self._base_ns_tuple.locals]
            #     return NamespacesTuple(self._base_ns_tuple.globals, LazyLocalNamespace(*locals_list))
            return self._base_ns_tuple

        typ = self._types_stack[-1]

        globalns = get_module_ns_of(typ)

        locals_list: list[MappingNamespace] = []
        # Hacky workarounds, see class docstring:
        # An optional parent namespace that will be added to the locals with the lowest priority
        if self._parent_ns is not None:
            locals_list.append(self._parent_ns)
        if len(self._types_stack) > 1:
            first_type = self._types_stack[0]
            locals_list.append({first_type.__name__: first_type})

        # Adding `__type_params__` *before* `vars(typ)`, as the latter takes priority
        # (see https://github.com/python/cpython/pull/120272).
        # TODO `typ.__type_params__` when we drop support for Python 3.11:
        type_params: tuple[_TypeVarLike, ...] = getattr(typ, '__type_params__', ())
        if type_params:
            # Adding `__type_params__` is mostly useful for generic classes defined using
            # PEP 695 syntax *and* using forward annotations (see the example in
            # https://github.com/python/cpython/issues/114053). For TypeAliasType instances,
            # it is way less common, but still required if using a string annotation in the alias
            # value, e.g. `type A[T] = 'T'` (which is not necessary in most cases).
            locals_list.append({t.__name__: t for t in type_params})

        # TypeAliasType instances don't have a `__dict__` attribute, so the check
        # is necessary:
        if hasattr(typ, '__dict__'):
            locals_list.append(vars(typ))

        # The `len(self._types_stack) > 1` check above prevents this from being added twice:
        locals_list.append({typ.__name__: typ})

        return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list))

    @contextmanager
    def push(self, typ: type[Any] | TypeAliasType, /) -> Generator[None]:
        """Push a type to the stack."""
        self._types_stack.append(typ)
        # Reset the cached property:
        self.__dict__.pop('types_namespace', None)
        try:
            yield
        finally:
            self._types_stack.pop()
            self.__dict__.pop('types_namespace', None)
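For orientation only (this block is not part of the commit), here is a rough sketch of how the resolved namespaces could be consumed; the `SomeType` name and the assumption that `LazyLocalNamespace` exposes the mapping interface `eval()` needs are illustrative, not taken from the diff.

```python
from pydantic._internal._namespace_utils import NsResolver


class Model:
    f: 'SomeType'  # hypothetical forward reference


SomeType = int

ns_resolver = NsResolver(parent_namespace={'SomeType': SomeType})
with ns_resolver.push(Model):
    globalns, localns = ns_resolver.types_namespace
    # assuming the locals object can be used as an eval() mapping:
    print(eval('SomeType', dict(globalns), localns))
    #> <class 'int'>
```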
@@ -1,4 +1,5 @@
"""Tools to provide pretty/human-readable display of objects."""

from __future__ import annotations as _annotations

import types
@@ -6,6 +7,8 @@ import typing
from typing import Any

import typing_extensions
from typing_inspection import typing_objects
from typing_inspection.introspection import is_union_origin

from . import _typing_extra

@@ -32,7 +35,7 @@ class Representation:
    # (this is not a docstring to avoid adding a docstring to classes which inherit from Representation)

    # we don't want to use a type annotation here as it can break get_type_hints
    __slots__ = tuple()  # type: typing.Collection[str]
    __slots__ = ()  # type: typing.Collection[str]

    def __repr_args__(self) -> ReprArgs:
        """Returns the attributes to show in __str__, __repr__, and __pretty__; this is generally overridden.
@@ -45,12 +48,17 @@ class Representation:
        if not attrs_names and hasattr(self, '__dict__'):
            attrs_names = self.__dict__.keys()
        attrs = ((s, getattr(self, s)) for s in attrs_names)
        return [(a, v) for a, v in attrs if v is not None]
        return [(a, v if v is not self else self.__repr_recursion__(v)) for a, v in attrs if v is not None]

    def __repr_name__(self) -> str:
        """Name of the instance's class, used in __repr__."""
        return self.__class__.__name__

    def __repr_recursion__(self, object: Any) -> str:
        """Returns the string representation of a recursive object."""
        # This is copied over from the stdlib `pprint` module:
        return f'<Recursion on {type(object).__name__} with id={id(object)}>'

    def __repr_str__(self, join_str: str) -> str:
        return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())

@@ -87,25 +95,30 @@ def display_as_type(obj: Any) -> str:

    Takes some logic from `typing._type_repr`.
    """
    if isinstance(obj, types.FunctionType):
    if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
        return obj.__name__
    elif obj is ...:
        return '...'
    elif isinstance(obj, Representation):
        return repr(obj)
    elif isinstance(obj, typing.ForwardRef) or typing_objects.is_typealiastype(obj):
        return str(obj)

    if not isinstance(obj, (_typing_extra.typing_base, _typing_extra.WithArgsTypes, type)):
        obj = obj.__class__

    if _typing_extra.origin_is_union(typing_extensions.get_origin(obj)):
    if is_union_origin(typing_extensions.get_origin(obj)):
        args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
        return f'Union[{args}]'
    elif isinstance(obj, _typing_extra.WithArgsTypes):
        if typing_extensions.get_origin(obj) == typing_extensions.Literal:
        if typing_objects.is_literal(typing_extensions.get_origin(obj)):
            args = ', '.join(map(repr, typing_extensions.get_args(obj)))
        else:
            args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
        return f'{obj.__qualname__}[{args}]'
        try:
            return f'{obj.__qualname__}[{args}]'
        except AttributeError:
            return str(obj).replace('typing.', '').replace('typing_extensions.', '')  # handles TypeAliasType in 3.12
    elif isinstance(obj, type):
        return obj.__qualname__
    else:

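Not part of the commit, but for reference, a quick sketch of the kind of output the updated `display_as_type` is expected to produce (internal helper; import path and exact output are assumptions):

```python
from typing import Literal, Optional

from pydantic._internal._repr import display_as_type

print(display_as_type(list[int]))
#> list[int]
print(display_as_type(Optional[str]))
#> Union[str, NoneType]
print(display_as_type(Literal['a', 'b']))
#> Literal['a', 'b']
```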
@@ -0,0 +1,209 @@
# pyright: reportTypedDictNotRequiredAccess=false, reportGeneralTypeIssues=false, reportArgumentType=false, reportAttributeAccessIssue=false
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TypedDict

from pydantic_core.core_schema import ComputedField, CoreSchema, DefinitionReferenceSchema, SerSchema
from typing_extensions import TypeAlias

AllSchemas: TypeAlias = 'CoreSchema | SerSchema | ComputedField'


class GatherResult(TypedDict):
    """Schema traversing result."""

    collected_references: dict[str, DefinitionReferenceSchema | None]
    """The collected definition references.

    If a definition reference schema can be inlined, it means that there is
    only one in the whole core schema. As such, it is stored as the value.
    Otherwise, the value is set to `None`.
    """

    deferred_discriminator_schemas: list[CoreSchema]
    """The list of core schemas having the discriminator application deferred."""


class MissingDefinitionError(LookupError):
    """A reference was pointing to a non-existing core schema."""

    def __init__(self, schema_reference: str, /) -> None:
        self.schema_reference = schema_reference


@dataclass
class GatherContext:
    """The current context used during core schema traversing.

    Context instances should only be used during schema traversing.
    """

    definitions: dict[str, CoreSchema]
    """The available definitions."""

    deferred_discriminator_schemas: list[CoreSchema] = field(init=False, default_factory=list)
    """The list of core schemas having the discriminator application deferred.

    Internally, these core schemas have a specific key set in the core metadata dict.
    """

    collected_references: dict[str, DefinitionReferenceSchema | None] = field(init=False, default_factory=dict)
    """The collected definition references.

    If a definition reference schema can be inlined, it means that there is
    only one in the whole core schema. As such, it is stored as the value.
    Otherwise, the value is set to `None`.

    During schema traversing, definition reference schemas can be added as candidates, or removed
    (by setting the value to `None`).
    """


def traverse_metadata(schema: AllSchemas, ctx: GatherContext) -> None:
    meta = schema.get('metadata')
    if meta is not None and 'pydantic_internal_union_discriminator' in meta:
        ctx.deferred_discriminator_schemas.append(schema)  # pyright: ignore[reportArgumentType]


def traverse_definition_ref(def_ref_schema: DefinitionReferenceSchema, ctx: GatherContext) -> None:
    schema_ref = def_ref_schema['schema_ref']

    if schema_ref not in ctx.collected_references:
        definition = ctx.definitions.get(schema_ref)
        if definition is None:
            raise MissingDefinitionError(schema_ref)

        # The `'definition-ref'` schema was only encountered once, make it
        # a candidate to be inlined:
        ctx.collected_references[schema_ref] = def_ref_schema
        traverse_schema(definition, ctx)
        if 'serialization' in def_ref_schema:
            traverse_schema(def_ref_schema['serialization'], ctx)
        traverse_metadata(def_ref_schema, ctx)
    else:
        # The `'definition-ref'` schema was already encountered, meaning
        # the previously encountered schema (and this one) can't be inlined:
        ctx.collected_references[schema_ref] = None


def traverse_schema(schema: AllSchemas, context: GatherContext) -> None:
    # TODO When we drop 3.9, use a match statement to get better type checking and remove
    # file-level type ignore.
    # (the `'type'` could also be fetched in every `if/elif` statement, but this alters performance).
    schema_type = schema['type']

    if schema_type == 'definition-ref':
        traverse_definition_ref(schema, context)
        # `traverse_definition_ref` handles the possible serialization and metadata schemas:
        return
    elif schema_type == 'definitions':
        traverse_schema(schema['schema'], context)
        for definition in schema['definitions']:
            traverse_schema(definition, context)
    elif schema_type in {'list', 'set', 'frozenset', 'generator'}:
        if 'items_schema' in schema:
            traverse_schema(schema['items_schema'], context)
    elif schema_type == 'tuple':
        if 'items_schema' in schema:
            for s in schema['items_schema']:
                traverse_schema(s, context)
    elif schema_type == 'dict':
        if 'keys_schema' in schema:
            traverse_schema(schema['keys_schema'], context)
        if 'values_schema' in schema:
            traverse_schema(schema['values_schema'], context)
    elif schema_type == 'union':
        for choice in schema['choices']:
            if isinstance(choice, tuple):
                traverse_schema(choice[0], context)
            else:
                traverse_schema(choice, context)
    elif schema_type == 'tagged-union':
        for v in schema['choices'].values():
            traverse_schema(v, context)
    elif schema_type == 'chain':
        for step in schema['steps']:
            traverse_schema(step, context)
    elif schema_type == 'lax-or-strict':
        traverse_schema(schema['lax_schema'], context)
        traverse_schema(schema['strict_schema'], context)
    elif schema_type == 'json-or-python':
        traverse_schema(schema['json_schema'], context)
        traverse_schema(schema['python_schema'], context)
    elif schema_type in {'model-fields', 'typed-dict'}:
        if 'extras_schema' in schema:
            traverse_schema(schema['extras_schema'], context)
        if 'computed_fields' in schema:
            for s in schema['computed_fields']:
                traverse_schema(s, context)
        for s in schema['fields'].values():
            traverse_schema(s, context)
    elif schema_type == 'dataclass-args':
        if 'computed_fields' in schema:
            for s in schema['computed_fields']:
                traverse_schema(s, context)
        for s in schema['fields']:
            traverse_schema(s, context)
    elif schema_type == 'arguments':
        for s in schema['arguments_schema']:
            traverse_schema(s['schema'], context)
        if 'var_args_schema' in schema:
            traverse_schema(schema['var_args_schema'], context)
        if 'var_kwargs_schema' in schema:
            traverse_schema(schema['var_kwargs_schema'], context)
    elif schema_type == 'arguments-v3':
        for s in schema['arguments_schema']:
            traverse_schema(s['schema'], context)
    elif schema_type == 'call':
        traverse_schema(schema['arguments_schema'], context)
        if 'return_schema' in schema:
            traverse_schema(schema['return_schema'], context)
    elif schema_type == 'computed-field':
        traverse_schema(schema['return_schema'], context)
    elif schema_type == 'function-before':
        if 'schema' in schema:
            traverse_schema(schema['schema'], context)
        if 'json_schema_input_schema' in schema:
            traverse_schema(schema['json_schema_input_schema'], context)
    elif schema_type == 'function-plain':
        # TODO duplicate schema types for serializers and validators, needs to be deduplicated.
        if 'return_schema' in schema:
            traverse_schema(schema['return_schema'], context)
        if 'json_schema_input_schema' in schema:
            traverse_schema(schema['json_schema_input_schema'], context)
    elif schema_type == 'function-wrap':
        # TODO duplicate schema types for serializers and validators, needs to be deduplicated.
        if 'return_schema' in schema:
            traverse_schema(schema['return_schema'], context)
        if 'schema' in schema:
            traverse_schema(schema['schema'], context)
        if 'json_schema_input_schema' in schema:
            traverse_schema(schema['json_schema_input_schema'], context)
    else:
        if 'schema' in schema:
            traverse_schema(schema['schema'], context)

    if 'serialization' in schema:
        traverse_schema(schema['serialization'], context)
    traverse_metadata(schema, context)


def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult:
    """Traverse the core schema and definitions and return the necessary information for schema cleaning.

    During the core schema traversing, any `'definition-ref'` schema is:

    - Validated: the reference must point to an existing definition. If this is not the case, a
      `MissingDefinitionError` exception is raised.
    - Stored in the context: the actual reference is stored in the context. Depending on whether
      the `'definition-ref'` schema is encountered more than once, the schema itself is also
      saved in the context to be inlined (i.e. replaced by the definition it points to).
    """
    context = GatherContext(definitions)
    traverse_schema(schema, context)

    return {
        'collected_references': context.collected_references,
        'deferred_discriminator_schemas': context.deferred_discriminator_schemas,
    }
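As a rough illustration (not part of the commit, and the module path `pydantic._internal._schema_gather` is an assumption), the gathering entry point can be exercised with a hand-built core schema; the single `'definition-ref'` is recorded as an inlining candidate:

```python
from pydantic_core import core_schema

from pydantic._internal._schema_gather import gather_schemas_for_cleaning

definitions = {'int-def': core_schema.int_schema(ref='int-def')}
schema = core_schema.list_schema(core_schema.definition_reference_schema('int-def'))

result = gather_schemas_for_cleaning(schema, definitions)
print(list(result['collected_references']))
#> ['int-def']
print(result['deferred_discriminator_schemas'])
#> []
```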
@@ -1,10 +1,10 @@
"""Types and utility functions used by various other internal tools."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Literal

from pydantic_core import core_schema
from typing_extensions import Literal

from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler

@@ -12,6 +12,7 @@ if TYPE_CHECKING:
    from ..json_schema import GenerateJsonSchema, JsonSchemaValue
    from ._core_utils import CoreSchemaOrField
    from ._generate_schema import GenerateSchema
    from ._namespace_utils import NamespacesTuple

    GetJsonSchemaFunction = Callable[[CoreSchemaOrField, GetJsonSchemaHandler], JsonSchemaValue]
    HandlerOverride = Callable[[CoreSchemaOrField], JsonSchemaValue]
@@ -32,8 +33,8 @@ class GenerateJsonSchemaHandler(GetJsonSchemaHandler):
        self.handler = handler_override or generate_json_schema.generate_inner
        self.mode = generate_json_schema.mode

    def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue:
        return self.handler(__core_schema)
    def __call__(self, core_schema: CoreSchemaOrField, /) -> JsonSchemaValue:
        return self.handler(core_schema)

    def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue:
        """Resolves `$ref` in the json schema.
@@ -78,22 +79,21 @@ class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler):
        self._generate_schema = generate_schema
        self._ref_mode = ref_mode

    def __call__(self, __source_type: Any) -> core_schema.CoreSchema:
        schema = self._handler(__source_type)
        ref = schema.get('ref')
    def __call__(self, source_type: Any, /) -> core_schema.CoreSchema:
        schema = self._handler(source_type)
        if self._ref_mode == 'to-def':
            ref = schema.get('ref')
            if ref is not None:
                self._generate_schema.defs.definitions[ref] = schema
                return core_schema.definition_reference_schema(ref)
            return self._generate_schema.defs.create_definition_reference_schema(schema)
            return schema
        else:  # ref_mode = 'unpack
        else:  # ref_mode = 'unpack'
            return self.resolve_ref_schema(schema)

    def _get_types_namespace(self) -> dict[str, Any] | None:
    def _get_types_namespace(self) -> NamespacesTuple:
        return self._generate_schema._types_namespace

    def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema:
        return self._generate_schema.generate_schema(__source_type)
    def generate_schema(self, source_type: Any, /) -> core_schema.CoreSchema:
        return self._generate_schema.generate_schema(source_type)

    @property
    def field_name(self) -> str | None:
@@ -113,12 +113,13 @@ class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler):
        """
        if maybe_ref_schema['type'] == 'definition-ref':
            ref = maybe_ref_schema['schema_ref']
            if ref not in self._generate_schema.defs.definitions:
            definition = self._generate_schema.defs.get_schema_from_ref(ref)
            if definition is None:
                raise LookupError(
                    f'Could not find a ref for {ref}.'
                    ' Maybe you tried to call resolve_ref_schema from within a recursive model?'
                )
            return self._generate_schema.defs.definitions[ref]
            return definition
        elif maybe_ref_schema['type'] == 'definitions':
            return self.resolve_ref_schema(maybe_ref_schema['schema'])
        return maybe_ref_schema

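The handler plumbing above is what user code reaches through `GetCoreSchemaHandler`; a minimal, hedged sketch of that public-facing path (not part of the diff):

```python
from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema


class LowercaseStr(str):
    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        # `handler(...)` ends up in `CallbackGetCoreSchemaHandler.__call__` shown above:
        return core_schema.no_info_after_validator_function(str.lower, handler(str))


class Model(BaseModel):
    name: LowercaseStr


print(Model(name='ABC').name)
#> abc
```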
@@ -0,0 +1,53 @@
from __future__ import annotations

import collections
import collections.abc
import typing
from typing import Any

from pydantic_core import PydanticOmit, core_schema

SEQUENCE_ORIGIN_MAP: dict[Any, Any] = {
    typing.Deque: collections.deque,  # noqa: UP006
    collections.deque: collections.deque,
    list: list,
    typing.List: list,  # noqa: UP006
    tuple: tuple,
    typing.Tuple: tuple,  # noqa: UP006
    set: set,
    typing.AbstractSet: set,
    typing.Set: set,  # noqa: UP006
    frozenset: frozenset,
    typing.FrozenSet: frozenset,  # noqa: UP006
    typing.Sequence: list,
    typing.MutableSequence: list,
    typing.MutableSet: set,
    # this doesn't handle subclasses of these
    # parametrized typing.Set creates one of these
    collections.abc.MutableSet: set,
    collections.abc.Set: frozenset,
}


def serialize_sequence_via_list(
    v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
) -> Any:
    items: list[Any] = []

    mapped_origin = SEQUENCE_ORIGIN_MAP.get(type(v), None)
    if mapped_origin is None:
        # we shouldn't hit this branch, should probably add a serialization error or something
        return v

    for index, item in enumerate(v):
        try:
            v = handler(item, index)
        except PydanticOmit:
            pass
        else:
            items.append(v)

    if info.mode_is_json():
        return items
    else:
        return mapped_origin(items)
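For context (not part of the diff), the behaviour this wrap serializer enables is visible through the public `TypeAdapter` API: JSON mode yields a plain list, while Python mode rebuilds the mapped container type. The expected output below is approximate:

```python
from collections import deque

from pydantic import TypeAdapter

ta = TypeAdapter(deque[int])

print(ta.dump_python(deque([1, 2, 3])))
#> deque([1, 2, 3])
print(ta.dump_json(deque([1, 2, 3])))
#> b'[1,2,3]'
```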
@@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from inspect import Parameter, Signature, signature
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from ._utils import is_valid_identifier
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..config import ExtraValues
|
||||
from ..fields import FieldInfo
|
||||
|
||||
|
||||
# Copied over from stdlib dataclasses
|
||||
class _HAS_DEFAULT_FACTORY_CLASS:
|
||||
def __repr__(self):
|
||||
return '<factory>'
|
||||
|
||||
|
||||
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
|
||||
|
||||
|
||||
def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str:
|
||||
"""Extract the correct name to use for the field when generating a signature.
|
||||
|
||||
Assuming the field has a valid alias, this will return the alias. Otherwise, it will return the field name.
|
||||
First priority is given to the alias, then the validation_alias, then the field name.
|
||||
|
||||
Args:
|
||||
field_name: The name of the field
|
||||
field_info: The corresponding FieldInfo object.
|
||||
|
||||
Returns:
|
||||
The correct name to use when generating a signature.
|
||||
"""
|
||||
if isinstance(field_info.alias, str) and is_valid_identifier(field_info.alias):
|
||||
return field_info.alias
|
||||
if isinstance(field_info.validation_alias, str) and is_valid_identifier(field_info.validation_alias):
|
||||
return field_info.validation_alias
|
||||
|
||||
return field_name
|
||||
|
||||
|
||||
def _process_param_defaults(param: Parameter) -> Parameter:
|
||||
"""Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.
|
||||
|
||||
Args:
|
||||
param (Parameter): The parameter
|
||||
|
||||
Returns:
|
||||
Parameter: The custom processed parameter
|
||||
"""
|
||||
from ..fields import FieldInfo
|
||||
|
||||
param_default = param.default
|
||||
if isinstance(param_default, FieldInfo):
|
||||
annotation = param.annotation
|
||||
# Replace the annotation if appropriate
|
||||
# inspect does "clever" things to show annotations as strings because we have
|
||||
# `from __future__ import annotations` in main, we don't want that
|
||||
if annotation == 'Any':
|
||||
annotation = Any
|
||||
|
||||
# Replace the field default
|
||||
default = param_default.default
|
||||
if default is PydanticUndefined:
|
||||
if param_default.default_factory is PydanticUndefined:
|
||||
default = Signature.empty
|
||||
else:
|
||||
# this is used by dataclasses to indicate a factory exists:
|
||||
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
|
||||
return param.replace(
|
||||
annotation=annotation, name=_field_name_for_signature(param.name, param_default), default=default
|
||||
)
|
||||
return param
|
||||
|
||||
|
||||
def _generate_signature_parameters( # noqa: C901 (ignore complexity, could use a refactor)
|
||||
init: Callable[..., None],
|
||||
fields: dict[str, FieldInfo],
|
||||
validate_by_name: bool,
|
||||
extra: ExtraValues | None,
|
||||
) -> dict[str, Parameter]:
|
||||
"""Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass."""
|
||||
from itertools import islice
|
||||
|
||||
present_params = signature(init).parameters.values()
|
||||
merged_params: dict[str, Parameter] = {}
|
||||
var_kw = None
|
||||
use_var_kw = False
|
||||
|
||||
for param in islice(present_params, 1, None): # skip self arg
|
||||
# inspect does "clever" things to show annotations as strings because we have
|
||||
# `from __future__ import annotations` in main, we don't want that
|
||||
if fields.get(param.name):
|
||||
# exclude params with init=False
|
||||
if getattr(fields[param.name], 'init', True) is False:
|
||||
continue
|
||||
param = param.replace(name=_field_name_for_signature(param.name, fields[param.name]))
|
||||
if param.annotation == 'Any':
|
||||
param = param.replace(annotation=Any)
|
||||
if param.kind is param.VAR_KEYWORD:
|
||||
var_kw = param
|
||||
continue
|
||||
merged_params[param.name] = param
|
||||
|
||||
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
|
||||
allow_names = validate_by_name
|
||||
for field_name, field in fields.items():
|
||||
# when alias is a str it should be used for signature generation
|
||||
param_name = _field_name_for_signature(field_name, field)
|
||||
|
||||
if field_name in merged_params or param_name in merged_params:
|
||||
continue
|
||||
|
||||
if not is_valid_identifier(param_name):
|
||||
if allow_names:
|
||||
param_name = field_name
|
||||
else:
|
||||
use_var_kw = True
|
||||
continue
|
||||
|
||||
if field.is_required():
|
||||
default = Parameter.empty
|
||||
elif field.default_factory is not None:
|
||||
# Mimics stdlib dataclasses:
|
||||
default = _HAS_DEFAULT_FACTORY
|
||||
else:
|
||||
default = field.default
|
||||
merged_params[param_name] = Parameter(
|
||||
param_name,
|
||||
Parameter.KEYWORD_ONLY,
|
||||
annotation=field.rebuild_annotation(),
|
||||
default=default,
|
||||
)
|
||||
|
||||
if extra == 'allow':
|
||||
use_var_kw = True
|
||||
|
||||
if var_kw and use_var_kw:
|
||||
# Make sure the parameter for extra kwargs
|
||||
# does not have the same name as a field
|
||||
default_model_signature = [
|
||||
('self', Parameter.POSITIONAL_ONLY),
|
||||
('data', Parameter.VAR_KEYWORD),
|
||||
]
|
||||
if [(p.name, p.kind) for p in present_params] == default_model_signature:
|
||||
# if this is the standard model signature, use extra_data as the extra args name
|
||||
var_kw_name = 'extra_data'
|
||||
else:
|
||||
# else start from var_kw
|
||||
var_kw_name = var_kw.name
|
||||
|
||||
# generate a name that's definitely unique
|
||||
while var_kw_name in fields:
|
||||
var_kw_name += '_'
|
||||
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
|
||||
|
||||
return merged_params
|
||||
|
||||
|
||||
def generate_pydantic_signature(
|
||||
init: Callable[..., None],
|
||||
fields: dict[str, FieldInfo],
|
||||
validate_by_name: bool,
|
||||
extra: ExtraValues | None,
|
||||
is_dataclass: bool = False,
|
||||
) -> Signature:
|
||||
"""Generate signature for a pydantic BaseModel or dataclass.
|
||||
|
||||
Args:
|
||||
init: The class init.
|
||||
fields: The model fields.
|
||||
validate_by_name: The `validate_by_name` value of the config.
|
||||
extra: The `extra` value of the config.
|
||||
is_dataclass: Whether the model is a dataclass.
|
||||
|
||||
Returns:
|
||||
The dataclass/BaseModel subclass signature.
|
||||
"""
|
||||
merged_params = _generate_signature_parameters(init, fields, validate_by_name, extra)
|
||||
|
||||
if is_dataclass:
|
||||
merged_params = {k: _process_param_defaults(v) for k, v in merged_params.items()}
|
||||
|
||||
return Signature(parameters=list(merged_params.values()), return_annotation=None)
|
||||
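For context (not part of the diff), the alias-priority rules described in `_field_name_for_signature` show up directly in the generated model signature; a small sketch, with the exact rendering assumed:

```python
import inspect

from pydantic import BaseModel, Field


class User(BaseModel):
    name: str = Field(alias='user_name')
    age: int = 0


print(inspect.signature(User))
#> (*, user_name: str, age: int = 0) -> None
```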
@@ -1,713 +0,0 @@
|
||||
"""Logic for generating pydantic-core schemas for standard library types.
|
||||
|
||||
Import of this module is deferred since it contains imports of many standard library modules.
|
||||
"""
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import collections
|
||||
import collections.abc
|
||||
import dataclasses
|
||||
import decimal
|
||||
import inspect
|
||||
import os
|
||||
import typing
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from typing import Any, Callable, Iterable, TypeVar
|
||||
|
||||
import typing_extensions
|
||||
from pydantic_core import (
|
||||
CoreSchema,
|
||||
MultiHostUrl,
|
||||
PydanticCustomError,
|
||||
PydanticOmit,
|
||||
Url,
|
||||
core_schema,
|
||||
)
|
||||
from typing_extensions import get_args, get_origin
|
||||
|
||||
from pydantic.errors import PydanticSchemaGenerationError
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic.types import Strict
|
||||
|
||||
from ..config import ConfigDict
|
||||
from ..json_schema import JsonSchemaValue, update_json_schema
|
||||
from . import _known_annotated_metadata, _typing_extra, _validators
|
||||
from ._core_utils import get_type_ref
|
||||
from ._internal_dataclass import slots_true
|
||||
from ._schema_generation_shared import GetCoreSchemaHandler, GetJsonSchemaHandler
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ._generate_schema import GenerateSchema
|
||||
|
||||
StdSchemaFunction = Callable[[GenerateSchema, type[Any]], core_schema.CoreSchema]
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class SchemaTransformer:
|
||||
get_core_schema: Callable[[Any, GetCoreSchemaHandler], CoreSchema]
|
||||
get_json_schema: Callable[[CoreSchema, GetJsonSchemaHandler], JsonSchemaValue]
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
return self.get_core_schema(source_type, handler)
|
||||
|
||||
def __get_pydantic_json_schema__(self, schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
return self.get_json_schema(schema, handler)
|
||||
|
||||
|
||||
def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema:
|
||||
cases: list[Any] = list(enum_type.__members__.values())
|
||||
|
||||
if not cases:
|
||||
# Use an isinstance check for enums with no cases.
|
||||
# This won't work with serialization or JSON schema, but that's okay -- the most important
|
||||
# use case for this is creating typevar bounds for generics that should be restricted to enums.
|
||||
# This is more consistent than it might seem at first, since you can only subclass enum.Enum
|
||||
# (or subclasses of enum.Enum) if all parent classes have no cases.
|
||||
return core_schema.is_instance_schema(enum_type)
|
||||
|
||||
use_enum_values = config.get('use_enum_values', False)
|
||||
|
||||
if len(cases) == 1:
|
||||
expected = repr(cases[0].value)
|
||||
else:
|
||||
expected = ', '.join([repr(case.value) for case in cases[:-1]]) + f' or {cases[-1].value!r}'
|
||||
|
||||
def to_enum(__input_value: Any) -> Enum:
|
||||
try:
|
||||
enum_field = enum_type(__input_value)
|
||||
if use_enum_values:
|
||||
return enum_field.value
|
||||
return enum_field
|
||||
except ValueError:
|
||||
# The type: ignore on the next line is to ignore the requirement of LiteralString
|
||||
raise PydanticCustomError('enum', f'Input should be {expected}', {'expected': expected}) # type: ignore
|
||||
|
||||
enum_ref = get_type_ref(enum_type)
|
||||
description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
|
||||
if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it
|
||||
description = None
|
||||
updates = {'title': enum_type.__name__, 'description': description}
|
||||
updates = {k: v for k, v in updates.items() if v is not None}
|
||||
|
||||
def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
json_schema = handler(core_schema.literal_schema([x.value for x in cases], ref=enum_ref))
|
||||
original_schema = handler.resolve_ref_schema(json_schema)
|
||||
update_json_schema(original_schema, updates)
|
||||
return json_schema
|
||||
|
||||
strict_python_schema = core_schema.is_instance_schema(enum_type)
|
||||
if use_enum_values:
|
||||
strict_python_schema = core_schema.chain_schema(
|
||||
[strict_python_schema, core_schema.no_info_plain_validator_function(lambda x: x.value)]
|
||||
)
|
||||
|
||||
to_enum_validator = core_schema.no_info_plain_validator_function(to_enum)
|
||||
if issubclass(enum_type, int):
|
||||
# this handles `IntEnum`, and also `Foobar(int, Enum)`
|
||||
updates['type'] = 'integer'
|
||||
lax = core_schema.chain_schema([core_schema.int_schema(), to_enum_validator])
|
||||
# Disallow float from JSON due to strict mode
|
||||
strict = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.int_schema()),
|
||||
python_schema=strict_python_schema,
|
||||
)
|
||||
elif issubclass(enum_type, str):
|
||||
# this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
|
||||
updates['type'] = 'string'
|
||||
lax = core_schema.chain_schema([core_schema.str_schema(), to_enum_validator])
|
||||
strict = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.str_schema()),
|
||||
python_schema=strict_python_schema,
|
||||
)
|
||||
elif issubclass(enum_type, float):
|
||||
updates['type'] = 'numeric'
|
||||
lax = core_schema.chain_schema([core_schema.float_schema(), to_enum_validator])
|
||||
strict = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.float_schema()),
|
||||
python_schema=strict_python_schema,
|
||||
)
|
||||
else:
|
||||
lax = to_enum_validator
|
||||
strict = core_schema.json_or_python_schema(json_schema=to_enum_validator, python_schema=strict_python_schema)
|
||||
return core_schema.lax_or_strict_schema(
|
||||
lax_schema=lax, strict_schema=strict, ref=enum_ref, metadata={'pydantic_js_functions': [get_json_schema]}
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class InnerSchemaValidator:
|
||||
"""Use a fixed CoreSchema, avoiding interference from outward annotations."""
|
||||
|
||||
core_schema: CoreSchema
|
||||
js_schema: JsonSchemaValue | None = None
|
||||
js_core_schema: CoreSchema | None = None
|
||||
js_schema_update: JsonSchemaValue | None = None
|
||||
|
||||
def __get_pydantic_json_schema__(self, _schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
if self.js_schema is not None:
|
||||
return self.js_schema
|
||||
js_schema = handler(self.js_core_schema or self.core_schema)
|
||||
if self.js_schema_update is not None:
|
||||
js_schema.update(self.js_schema_update)
|
||||
return js_schema
|
||||
|
||||
def __get_pydantic_core_schema__(self, _source_type: Any, _handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
return self.core_schema
|
||||
|
||||
|
||||
def decimal_prepare_pydantic_annotations(
|
||||
source: Any, annotations: Iterable[Any], config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
if source is not decimal.Decimal:
|
||||
return None
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
|
||||
config_allow_inf_nan = config.get('allow_inf_nan')
|
||||
if config_allow_inf_nan is not None:
|
||||
metadata.setdefault('allow_inf_nan', config_allow_inf_nan)
|
||||
|
||||
_known_annotated_metadata.check_metadata(
|
||||
metadata, {*_known_annotated_metadata.FLOAT_CONSTRAINTS, 'max_digits', 'decimal_places'}, decimal.Decimal
|
||||
)
|
||||
return source, [InnerSchemaValidator(core_schema.decimal_schema(**metadata)), *remaining_annotations]
|
||||
|
||||
|
||||
def datetime_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
import datetime
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
if source_type is datetime.date:
|
||||
sv = InnerSchemaValidator(core_schema.date_schema(**metadata))
|
||||
elif source_type is datetime.datetime:
|
||||
sv = InnerSchemaValidator(core_schema.datetime_schema(**metadata))
|
||||
elif source_type is datetime.time:
|
||||
sv = InnerSchemaValidator(core_schema.time_schema(**metadata))
|
||||
elif source_type is datetime.timedelta:
|
||||
sv = InnerSchemaValidator(core_schema.timedelta_schema(**metadata))
|
||||
else:
|
||||
return None
|
||||
# check now that we know the source type is correct
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.DATE_TIME_CONSTRAINTS, source_type)
|
||||
return (source_type, [sv, *remaining_annotations])
|
||||
|
||||
|
||||
def uuid_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
# UUIDs have no constraints - they are fixed length, constructing a UUID instance checks the length
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
if source_type is not UUID:
|
||||
return None
|
||||
|
||||
return (source_type, [InnerSchemaValidator(core_schema.uuid_schema()), *annotations])
|
||||
|
||||
|
||||
def path_schema_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
import pathlib
|
||||
|
||||
if source_type not in {
|
||||
os.PathLike,
|
||||
pathlib.Path,
|
||||
pathlib.PurePath,
|
||||
pathlib.PosixPath,
|
||||
pathlib.PurePosixPath,
|
||||
pathlib.PureWindowsPath,
|
||||
}:
|
||||
return None
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.STR_CONSTRAINTS, source_type)
|
||||
|
||||
construct_path = pathlib.PurePath if source_type is os.PathLike else source_type
|
||||
|
||||
def path_validator(input_value: str) -> os.PathLike[Any]:
|
||||
try:
|
||||
return construct_path(input_value)
|
||||
except TypeError as e:
|
||||
raise PydanticCustomError('path_type', 'Input is not a valid path') from e
|
||||
|
||||
constrained_str_schema = core_schema.str_schema(**metadata)
|
||||
|
||||
instance_schema = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(path_validator, constrained_str_schema),
|
||||
python_schema=core_schema.is_instance_schema(source_type),
|
||||
)
|
||||
|
||||
strict: bool | None = None
|
||||
for annotation in annotations:
|
||||
if isinstance(annotation, Strict):
|
||||
strict = annotation.strict
|
||||
|
||||
schema = core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.union_schema(
|
||||
[
|
||||
instance_schema,
|
||||
core_schema.no_info_after_validator_function(path_validator, constrained_str_schema),
|
||||
],
|
||||
custom_error_type='path_type',
|
||||
custom_error_message='Input is not a valid path',
|
||||
strict=True,
|
||||
),
|
||||
strict_schema=instance_schema,
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
strict=strict,
|
||||
)
|
||||
|
||||
return (
|
||||
source_type,
|
||||
[
|
||||
InnerSchemaValidator(schema, js_core_schema=constrained_str_schema, js_schema_update={'format': 'path'}),
|
||||
*remaining_annotations,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def dequeue_validator(
|
||||
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, maxlen: None | int
|
||||
) -> collections.deque[Any]:
|
||||
if isinstance(input_value, collections.deque):
|
||||
maxlens = [v for v in (input_value.maxlen, maxlen) if v is not None]
|
||||
if maxlens:
|
||||
maxlen = min(maxlens)
|
||||
return collections.deque(handler(input_value), maxlen=maxlen)
|
||||
else:
|
||||
return collections.deque(handler(input_value), maxlen=maxlen)
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class SequenceValidator:
|
||||
mapped_origin: type[Any]
|
||||
item_source_type: type[Any]
|
||||
min_length: int | None = None
|
||||
max_length: int | None = None
|
||||
strict: bool = False
|
||||
|
||||
def serialize_sequence_via_list(
|
||||
self, v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
|
||||
) -> Any:
|
||||
items: list[Any] = []
|
||||
for index, item in enumerate(v):
|
||||
try:
|
||||
v = handler(item, index)
|
||||
except PydanticOmit:
|
||||
pass
|
||||
else:
|
||||
items.append(v)
|
||||
|
||||
if info.mode_is_json():
|
||||
return items
|
||||
else:
|
||||
return self.mapped_origin(items)
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
if self.item_source_type is Any:
|
||||
items_schema = None
|
||||
else:
|
||||
items_schema = handler.generate_schema(self.item_source_type)
|
||||
|
||||
metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}
|
||||
|
||||
if self.mapped_origin in (list, set, frozenset):
|
||||
if self.mapped_origin is list:
|
||||
constrained_schema = core_schema.list_schema(items_schema, **metadata)
|
||||
elif self.mapped_origin is set:
|
||||
constrained_schema = core_schema.set_schema(items_schema, **metadata)
|
||||
else:
|
||||
assert self.mapped_origin is frozenset # safety check in case we forget to add a case
|
||||
constrained_schema = core_schema.frozenset_schema(items_schema, **metadata)
|
||||
|
||||
schema = constrained_schema
|
||||
else:
|
||||
# safety check in case we forget to add a case
|
||||
assert self.mapped_origin in (collections.deque, collections.Counter)
|
||||
|
||||
if self.mapped_origin is collections.deque:
|
||||
# if we have a MaxLen annotation might as well set that as the default maxlen on the deque
|
||||
# this lets us re-use existing metadata annotations to let users set the maxlen on a dequeue
|
||||
# that e.g. comes from JSON
|
||||
coerce_instance_wrap = partial(
|
||||
core_schema.no_info_wrap_validator_function,
|
||||
partial(dequeue_validator, maxlen=metadata.get('max_length', None)),
|
||||
)
|
||||
else:
|
||||
coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)
|
||||
|
||||
constrained_schema = core_schema.list_schema(items_schema, **metadata)
|
||||
|
||||
check_instance = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.list_schema(),
|
||||
python_schema=core_schema.is_instance_schema(self.mapped_origin),
|
||||
)
|
||||
|
||||
serialization = core_schema.wrap_serializer_function_ser_schema(
|
||||
self.serialize_sequence_via_list, schema=items_schema or core_schema.any_schema(), info_arg=True
|
||||
)
|
||||
|
||||
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
|
||||
|
||||
if metadata.get('strict', False):
|
||||
schema = strict
|
||||
else:
|
||||
lax = coerce_instance_wrap(constrained_schema)
|
||||
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
|
||||
schema['serialization'] = serialization
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
SEQUENCE_ORIGIN_MAP: dict[Any, Any] = {
|
||||
typing.Deque: collections.deque,
|
||||
collections.deque: collections.deque,
|
||||
list: list,
|
||||
typing.List: list,
|
||||
set: set,
|
||||
typing.AbstractSet: set,
|
||||
typing.Set: set,
|
||||
frozenset: frozenset,
|
||||
typing.FrozenSet: frozenset,
|
||||
typing.Sequence: list,
|
||||
typing.MutableSequence: list,
|
||||
typing.MutableSet: set,
|
||||
# this doesn't handle subclasses of these
|
||||
# parametrized typing.Set creates one of these
|
||||
collections.abc.MutableSet: set,
|
||||
collections.abc.Set: frozenset,
|
||||
}
|
||||
|
||||
|
||||
def identity(s: CoreSchema) -> CoreSchema:
|
||||
return s
|
||||
|
||||
|
||||
def sequence_like_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
origin: Any = get_origin(source_type)
|
||||
|
||||
mapped_origin = SEQUENCE_ORIGIN_MAP.get(origin, None) if origin else SEQUENCE_ORIGIN_MAP.get(source_type, None)
|
||||
if mapped_origin is None:
|
||||
return None
|
||||
|
||||
args = get_args(source_type)
|
||||
|
||||
if not args:
|
||||
args = (Any,)
|
||||
elif len(args) != 1:
|
||||
raise ValueError('Expected sequence to have exactly 1 generic parameter')
|
||||
|
||||
item_source_type = args[0]
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
|
||||
|
||||
return (source_type, [SequenceValidator(mapped_origin, item_source_type, **metadata), *remaining_annotations])
|
||||
|
||||
|
||||
MAPPING_ORIGIN_MAP: dict[Any, Any] = {
|
||||
typing.DefaultDict: collections.defaultdict,
|
||||
collections.defaultdict: collections.defaultdict,
|
||||
collections.OrderedDict: collections.OrderedDict,
|
||||
typing_extensions.OrderedDict: collections.OrderedDict,
|
||||
dict: dict,
|
||||
typing.Dict: dict,
|
||||
collections.Counter: collections.Counter,
|
||||
typing.Counter: collections.Counter,
|
||||
# this doesn't handle subclasses of these
|
||||
typing.Mapping: dict,
|
||||
typing.MutableMapping: dict,
|
||||
# parametrized typing.{Mutable}Mapping creates one of these
|
||||
collections.abc.MutableMapping: dict,
|
||||
collections.abc.Mapping: dict,
|
||||
}
|
||||
|
||||
|
||||
def defaultdict_validator(
|
||||
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
|
||||
) -> collections.defaultdict[Any, Any]:
|
||||
if isinstance(input_value, collections.defaultdict):
|
||||
default_factory = input_value.default_factory
|
||||
return collections.defaultdict(default_factory, handler(input_value))
|
||||
else:
|
||||
return collections.defaultdict(default_default_factory, handler(input_value))
|
||||
|
||||
|
||||
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
|
||||
def infer_default() -> Callable[[], Any]:
|
||||
allowed_default_types: dict[Any, Any] = {
|
||||
typing.Tuple: tuple,
|
||||
tuple: tuple,
|
||||
collections.abc.Sequence: tuple,
|
||||
collections.abc.MutableSequence: list,
|
||||
typing.List: list,
|
||||
list: list,
|
||||
typing.Sequence: list,
|
||||
typing.Set: set,
|
||||
set: set,
|
||||
typing.MutableSet: set,
|
||||
collections.abc.MutableSet: set,
|
||||
collections.abc.Set: frozenset,
|
||||
typing.MutableMapping: dict,
|
||||
typing.Mapping: dict,
|
||||
collections.abc.Mapping: dict,
|
||||
collections.abc.MutableMapping: dict,
|
||||
float: float,
|
||||
int: int,
|
||||
str: str,
|
||||
bool: bool,
|
||||
}
|
||||
values_type_origin = get_origin(values_source_type) or values_source_type
|
||||
instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
|
||||
if isinstance(values_type_origin, TypeVar):
|
||||
|
||||
def type_var_default_factory() -> None:
|
||||
raise RuntimeError(
|
||||
'Generic defaultdict cannot be used without a concrete value type or an'
|
||||
' explicit default factory, ' + instructions
|
||||
)
|
||||
|
||||
return type_var_default_factory
|
||||
elif values_type_origin not in allowed_default_types:
|
||||
# a somewhat subjective set of types that have reasonable default values
|
||||
allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
|
||||
raise PydanticSchemaGenerationError(
|
||||
f'Unable to infer a default factory for keys of type {values_source_type}.'
|
||||
f' Only {allowed_msg} are supported, other types require an explicit default factory'
|
||||
' ' + instructions
|
||||
)
|
||||
return allowed_default_types[values_type_origin]
|
||||
|
||||
# Assume Annotated[..., Field(...)]
|
||||
if _typing_extra.is_annotated(values_source_type):
|
||||
field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None)
|
||||
else:
|
||||
field_info = None
|
||||
if field_info and field_info.default_factory:
|
||||
default_default_factory = field_info.default_factory
|
||||
else:
|
||||
default_default_factory = infer_default()
|
||||
return default_default_factory
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class MappingValidator:
|
||||
mapped_origin: type[Any]
|
||||
keys_source_type: type[Any]
|
||||
values_source_type: type[Any]
|
||||
min_length: int | None = None
|
||||
max_length: int | None = None
|
||||
strict: bool = False
|
||||
|
||||
def serialize_mapping_via_dict(self, v: Any, handler: core_schema.SerializerFunctionWrapHandler) -> Any:
|
||||
return handler(v)
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
if self.keys_source_type is Any:
|
||||
keys_schema = None
|
||||
else:
|
||||
keys_schema = handler.generate_schema(self.keys_source_type)
|
||||
if self.values_source_type is Any:
|
||||
values_schema = None
|
||||
else:
|
||||
values_schema = handler.generate_schema(self.values_source_type)
|
||||
|
||||
metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}
|
||||
|
||||
if self.mapped_origin is dict:
|
||||
schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
|
||||
else:
|
||||
constrained_schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
|
||||
check_instance = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.dict_schema(),
|
||||
python_schema=core_schema.is_instance_schema(self.mapped_origin),
|
||||
)
|
||||
|
||||
if self.mapped_origin is collections.defaultdict:
|
||||
default_default_factory = get_defaultdict_default_default_factory(self.values_source_type)
|
||||
coerce_instance_wrap = partial(
|
||||
core_schema.no_info_wrap_validator_function,
|
||||
partial(defaultdict_validator, default_default_factory=default_default_factory),
|
||||
)
|
||||
else:
|
||||
coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)
|
||||
|
||||
serialization = core_schema.wrap_serializer_function_ser_schema(
|
||||
self.serialize_mapping_via_dict,
|
||||
schema=core_schema.dict_schema(
|
||||
keys_schema or core_schema.any_schema(), values_schema or core_schema.any_schema()
|
||||
),
|
||||
info_arg=False,
|
||||
)
|
||||
|
||||
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
|
||||
|
||||
if metadata.get('strict', False):
|
||||
schema = strict
|
||||
else:
|
||||
lax = coerce_instance_wrap(constrained_schema)
|
||||
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
|
||||
schema['serialization'] = serialization
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def mapping_like_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
origin: Any = get_origin(source_type)
|
||||
|
||||
mapped_origin = MAPPING_ORIGIN_MAP.get(origin, None) if origin else MAPPING_ORIGIN_MAP.get(source_type, None)
|
||||
if mapped_origin is None:
|
||||
return None
|
||||
|
||||
args = get_args(source_type)
|
||||
|
||||
if not args:
|
||||
args = (Any, Any)
|
||||
elif mapped_origin is collections.Counter:
|
||||
# a single generic
|
||||
if len(args) != 1:
|
||||
raise ValueError('Expected Counter to have exactly 1 generic parameter')
|
||||
args = (args[0], int) # keys are always an int
|
||||
elif len(args) != 2:
|
||||
raise ValueError('Expected mapping to have exactly 2 generic parameters')
|
||||
|
||||
keys_source_type, values_source_type = args
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
|
||||
|
||||
return (
|
||||
source_type,
|
||||
[
|
||||
MappingValidator(mapped_origin, keys_source_type, values_source_type, **metadata),
|
||||
*remaining_annotations,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def ip_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
    def make_strict_ip_schema(tp: type[Any]) -> CoreSchema:
        return core_schema.json_or_python_schema(
            json_schema=core_schema.no_info_after_validator_function(tp, core_schema.str_schema()),
            python_schema=core_schema.is_instance_schema(tp),
        )

    if source_type is IPv4Address:
        return source_type, [
            SchemaTransformer(
                lambda _1, _2: core_schema.lax_or_strict_schema(
                    lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_address_validator),
                    strict_schema=make_strict_ip_schema(IPv4Address),
                    serialization=core_schema.to_string_ser_schema(),
                ),
                lambda _1, _2: {'type': 'string', 'format': 'ipv4'},
            ),
            *annotations,
        ]
    if source_type is IPv4Network:
        return source_type, [
            SchemaTransformer(
                lambda _1, _2: core_schema.lax_or_strict_schema(
                    lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_network_validator),
                    strict_schema=make_strict_ip_schema(IPv4Network),
                    serialization=core_schema.to_string_ser_schema(),
                ),
                lambda _1, _2: {'type': 'string', 'format': 'ipv4network'},
            ),
            *annotations,
        ]
    if source_type is IPv4Interface:
        return source_type, [
            SchemaTransformer(
                lambda _1, _2: core_schema.lax_or_strict_schema(
                    lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_interface_validator),
                    strict_schema=make_strict_ip_schema(IPv4Interface),
                    serialization=core_schema.to_string_ser_schema(),
                ),
                lambda _1, _2: {'type': 'string', 'format': 'ipv4interface'},
            ),
            *annotations,
        ]

    if source_type is IPv6Address:
        return source_type, [
            SchemaTransformer(
                lambda _1, _2: core_schema.lax_or_strict_schema(
                    lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_address_validator),
                    strict_schema=make_strict_ip_schema(IPv6Address),
                    serialization=core_schema.to_string_ser_schema(),
                ),
                lambda _1, _2: {'type': 'string', 'format': 'ipv6'},
            ),
            *annotations,
        ]
    if source_type is IPv6Network:
        return source_type, [
            SchemaTransformer(
                lambda _1, _2: core_schema.lax_or_strict_schema(
                    lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_network_validator),
                    strict_schema=make_strict_ip_schema(IPv6Network),
                    serialization=core_schema.to_string_ser_schema(),
                ),
                lambda _1, _2: {'type': 'string', 'format': 'ipv6network'},
            ),
            *annotations,
        ]
    if source_type is IPv6Interface:
        return source_type, [
            SchemaTransformer(
                lambda _1, _2: core_schema.lax_or_strict_schema(
                    lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_interface_validator),
                    strict_schema=make_strict_ip_schema(IPv6Interface),
                    serialization=core_schema.to_string_ser_schema(),
                ),
                lambda _1, _2: {'type': 'string', 'format': 'ipv6interface'},
            ),
            *annotations,
        ]

    return None


def url_prepare_pydantic_annotations(
    source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
    if source_type is Url:
        return source_type, [
            SchemaTransformer(
                lambda _1, _2: core_schema.url_schema(),
                lambda cs, handler: handler(cs),
            ),
            *annotations,
        ]
    if source_type is MultiHostUrl:
        return source_type, [
            SchemaTransformer(
                lambda _1, _2: core_schema.multi_host_url_schema(),
                lambda cs, handler: handler(cs),
            ),
            *annotations,
        ]


PREPARE_METHODS: tuple[Callable[[Any, Iterable[Any], ConfigDict], tuple[Any, list[Any]] | None], ...] = (
    decimal_prepare_pydantic_annotations,
    sequence_like_prepare_pydantic_annotations,
    datetime_prepare_pydantic_annotations,
    uuid_prepare_pydantic_annotations,
    path_schema_prepare_pydantic_annotations,
    mapping_like_prepare_pydantic_annotations,
    ip_prepare_pydantic_annotations,
    url_prepare_pydantic_annotations,
)
@@ -1,244 +1,544 @@
|
||||
"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap python's typing module."""
|
||||
from __future__ import annotations as _annotations
|
||||
"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap Python's typing module."""
|
||||
|
||||
import dataclasses
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from types import GetSetDescriptorType
|
||||
from typing import TYPE_CHECKING, Any, ForwardRef
|
||||
from typing import TYPE_CHECKING, Any, Callable, cast
|
||||
|
||||
from typing_extensions import Annotated, Final, Literal, TypeAliasType, TypeGuard, get_args, get_origin
|
||||
import typing_extensions
|
||||
from typing_extensions import deprecated, get_args, get_origin
|
||||
from typing_inspection import typing_objects
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._dataclasses import StandardDataclass
|
||||
|
||||
try:
|
||||
from typing import _TypingBase # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
from typing import _Final as _TypingBase # type: ignore[attr-defined]
|
||||
|
||||
typing_base = _TypingBase
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
TypingGenericAlias = ()
|
||||
else:
|
||||
from typing import GenericAlias as TypingGenericAlias # type: ignore
|
||||
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from typing_extensions import NotRequired, Required
|
||||
else:
|
||||
from typing import NotRequired, Required # noqa: F401
|
||||
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
|
||||
def origin_is_union(tp: type[Any] | None) -> bool:
|
||||
return tp is typing.Union
|
||||
|
||||
WithArgsTypes = (TypingGenericAlias,)
|
||||
|
||||
else:
|
||||
|
||||
def origin_is_union(tp: type[Any] | None) -> bool:
|
||||
return tp is typing.Union or tp is types.UnionType
|
||||
|
||||
WithArgsTypes = typing._GenericAlias, types.GenericAlias, types.UnionType # type: ignore[attr-defined]
|
||||
from pydantic.version import version_short
|
||||
|
||||
from ._namespace_utils import GlobalsNamespace, MappingNamespace, NsResolver, get_module_ns_of
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
NoneType = type(None)
|
||||
EllipsisType = type(Ellipsis)
|
||||
else:
|
||||
from types import EllipsisType as EllipsisType
|
||||
from types import NoneType as NoneType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
LITERAL_TYPES: set[Any] = {Literal}
|
||||
if hasattr(typing, 'Literal'):
|
||||
LITERAL_TYPES.add(typing.Literal) # type: ignore
|
||||
|
||||
NONE_TYPES: tuple[Any, ...] = (None, NoneType, *(tp[None] for tp in LITERAL_TYPES))
|
||||
# As per https://typing-extensions.readthedocs.io/en/latest/#runtime-use-of-types,
|
||||
# always check for both `typing` and `typing_extensions` variants of a typing construct.
|
||||
# (this is implemented differently than the suggested approach in the `typing_extensions`
|
||||
# docs for performance).
|
||||
|
||||
|
||||
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type
|
||||
_t_annotated = typing.Annotated
|
||||
_te_annotated = typing_extensions.Annotated
|
||||
|
||||
|
||||
def is_none_type(type_: Any) -> bool:
|
||||
return type_ in NONE_TYPES
|
||||
def is_annotated(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is a `Annotated` special form.
|
||||
|
||||
|
||||
def is_callable_type(type_: type[Any]) -> bool:
|
||||
return type_ is Callable or get_origin(type_) is Callable
|
||||
|
||||
|
||||
def is_literal_type(type_: type[Any]) -> bool:
|
||||
return Literal is not None and get_origin(type_) in LITERAL_TYPES
|
||||
|
||||
|
||||
def literal_values(type_: type[Any]) -> tuple[Any, ...]:
|
||||
return get_args(type_)
|
||||
|
||||
|
||||
def all_literal_values(type_: type[Any]) -> list[Any]:
|
||||
"""This method is used to retrieve all Literal values as
|
||||
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
|
||||
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`.
|
||||
```python {test="skip" lint="skip"}
|
||||
is_annotated(Annotated[int, ...])
|
||||
#> True
|
||||
```
|
||||
"""
|
||||
if not is_literal_type(type_):
|
||||
return [type_]
|
||||
|
||||
values = literal_values(type_)
|
||||
return list(x for value in values for x in all_literal_values(value))
|
||||
origin = get_origin(tp)
|
||||
return origin is _t_annotated or origin is _te_annotated
|
||||
|
||||
|
||||
def is_annotated(ann_type: Any) -> bool:
|
||||
from ._utils import lenient_issubclass
|
||||
|
||||
origin = get_origin(ann_type)
|
||||
return origin is not None and lenient_issubclass(origin, Annotated)
|
||||
def annotated_type(tp: Any, /) -> Any | None:
|
||||
"""Return the type of the `Annotated` special form, or `None`."""
|
||||
return tp.__origin__ if typing_objects.is_annotated(get_origin(tp)) else None
|
||||
|
||||
|
||||
def is_namedtuple(type_: type[Any]) -> bool:
|
||||
"""Check if a given class is a named tuple.
|
||||
It can be either a `typing.NamedTuple` or `collections.namedtuple`.
|
||||
def unpack_type(tp: Any, /) -> Any | None:
|
||||
"""Return the type wrapped by the `Unpack` special form, or `None`."""
|
||||
return get_args(tp)[0] if typing_objects.is_unpack(get_origin(tp)) else None
|
||||
|
||||
|
||||
def is_hashable(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is the `Hashable` class.
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
is_hashable(Hashable)
|
||||
#> True
|
||||
```
|
||||
"""
|
||||
from ._utils import lenient_issubclass
|
||||
|
||||
return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')
|
||||
# `get_origin` is documented as normalizing any typing-module aliases to `collections` classes,
|
||||
# hence the second check:
|
||||
return tp is collections.abc.Hashable or get_origin(tp) is collections.abc.Hashable
|
||||
|
||||
|
||||
test_new_type = typing.NewType('test_new_type', str)
|
||||
def is_callable(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is a `Callable`, parametrized or not.
|
||||
|
||||
|
||||
def is_new_type(type_: type[Any]) -> bool:
|
||||
"""Check whether type_ was created using typing.NewType.
|
||||
|
||||
Can't use isinstance because it fails <3.10.
|
||||
```python {test="skip" lint="skip"}
|
||||
is_callable(Callable[[int], str])
|
||||
#> True
|
||||
is_callable(typing.Callable)
|
||||
#> True
|
||||
is_callable(collections.abc.Callable)
|
||||
#> True
|
||||
```
|
||||
"""
|
||||
return isinstance(type_, test_new_type.__class__) and hasattr(type_, '__supertype__') # type: ignore[arg-type]
|
||||
# `get_origin` is documented as normalizing any typing-module aliases to `collections` classes,
|
||||
# hence the second check:
|
||||
return tp is collections.abc.Callable or get_origin(tp) is collections.abc.Callable
|
||||
|
||||
|
||||
def _check_classvar(v: type[Any] | None) -> bool:
|
||||
if v is None:
|
||||
return False
|
||||
|
||||
return v.__class__ == typing.ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar'
|
||||
_classvar_re = re.compile(r'((\w+\.)?Annotated\[)?(\w+\.)?ClassVar\[')
|
||||
|
||||
|
||||
def is_classvar(ann_type: type[Any]) -> bool:
|
||||
if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)):
|
||||
def is_classvar_annotation(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument represents a class variable annotation.
|
||||
|
||||
Although not explicitly stated by the typing specification, `ClassVar` can be used
|
||||
inside `Annotated` and as such, this function checks for this specific scenario.
|
||||
|
||||
Because this function is used to detect class variables before evaluating forward references
|
||||
(or because evaluation failed), we also implement a naive regex match implementation. This is
|
||||
required because class variables are inspected before fields are collected, so we try to be
|
||||
as accurate as possible.
|
||||
"""
|
||||
if typing_objects.is_classvar(tp):
|
||||
return True
|
||||
|
||||
# this is an ugly workaround for class vars that contain forward references and are therefore themselves
|
||||
# forward references, see #3679
|
||||
if ann_type.__class__ == typing.ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['): # type: ignore
|
||||
origin = get_origin(tp)
|
||||
|
||||
if typing_objects.is_classvar(origin):
|
||||
return True
|
||||
|
||||
if typing_objects.is_annotated(origin):
|
||||
annotated_type = tp.__origin__
|
||||
if typing_objects.is_classvar(annotated_type) or typing_objects.is_classvar(get_origin(annotated_type)):
|
||||
return True
|
||||
|
||||
str_ann: str | None = None
|
||||
if isinstance(tp, typing.ForwardRef):
|
||||
str_ann = tp.__forward_arg__
|
||||
if isinstance(tp, str):
|
||||
str_ann = tp
|
||||
|
||||
if str_ann is not None and _classvar_re.match(str_ann):
|
||||
# stdlib dataclasses do something similar, although a bit more advanced
|
||||
# (see `dataclasses._is_type`).
|
||||
return True
|
||||
|
||||
return False
|
||||
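A hedged illustration of the cases the check above is meant to accept (a sketch only, calling the `is_classvar_annotation` function defined above; the string forward reference is a made-up name):

```python
from typing import Annotated, ClassVar

is_classvar_annotation(ClassVar[int])                  # True: plain ClassVar
is_classvar_annotation(Annotated[ClassVar[int], ...])  # True: ClassVar wrapped in Annotated
is_classvar_annotation('ClassVar[SomeModel]')          # True: unevaluated string, caught by the regex fallback
is_classvar_annotation(int)                            # False
```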
|
||||
|
||||
def _check_finalvar(v: type[Any] | None) -> bool:
|
||||
"""Check if a given type is a `typing.Final` type."""
|
||||
if v is None:
|
||||
return False
|
||||
|
||||
return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final')
|
||||
_t_final = typing.Final
|
||||
_te_final = typing_extensions.Final
|
||||
|
||||
|
||||
def is_finalvar(ann_type: Any) -> bool:
|
||||
return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type))
|
||||
# TODO implement `is_finalvar_annotation` as Final can be wrapped with other special forms:
|
||||
def is_finalvar(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is a `Final` special form, parametrized or not.
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
is_finalvar(Final[int])
|
||||
#> True
|
||||
is_finalvar(Final)
|
||||
#> True
|
||||
"""
|
||||
# Final is not necessarily parametrized:
|
||||
if tp is _t_final or tp is _te_final:
|
||||
return True
|
||||
origin = get_origin(tp)
|
||||
return origin is _t_final or origin is _te_final
|
||||
|
||||
|
||||
def parent_frame_namespace(*, parent_depth: int = 2) -> dict[str, Any] | None:
|
||||
"""We allow use of items in parent namespace to get around the issue with `get_type_hints` only looking in the
|
||||
global module namespace. See https://github.com/pydantic/pydantic/issues/2678#issuecomment-1008139014 -> Scope
|
||||
and suggestion at the end of the next comment by @gvanrossum.
|
||||
_NONE_TYPES: tuple[Any, ...] = (None, NoneType, typing.Literal[None], typing_extensions.Literal[None])
|
||||
|
||||
WARNING 1: it matters exactly where this is called. By default, this function will build a namespace from the
|
||||
parent of where it is called.
|
||||
|
||||
WARNING 2: this only looks in the parent namespace, not other parents since (AFAIK) there's no way to collect a
|
||||
dict of exactly what's in scope. Using `f_back` would work sometimes but would be very wrong and confusing in many
|
||||
other cases. See https://discuss.python.org/t/is-there-a-way-to-access-parent-nested-namespaces/20659.
|
||||
def is_none_type(tp: Any, /) -> bool:
|
||||
"""Return whether the argument represents the `None` type as part of an annotation.
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
is_none_type(None)
|
||||
#> True
|
||||
is_none_type(NoneType)
|
||||
#> True
|
||||
is_none_type(Literal[None])
|
||||
#> True
|
||||
is_none_type(type[None])
|
||||
#> False
|
||||
"""
|
||||
return tp in _NONE_TYPES
|
||||
|
||||
|
||||
def is_namedtuple(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is a named tuple class.
|
||||
|
||||
The class can be created using `typing.NamedTuple` or `collections.namedtuple`.
|
||||
Parametrized generic classes are *not* assumed to be named tuples.
|
||||
"""
|
||||
from ._utils import lenient_issubclass # circ. import
|
||||
|
||||
return lenient_issubclass(tp, tuple) and hasattr(tp, '_fields')
|
||||
|
||||
|
||||
# TODO In 2.12, delete this export. It is currently defined only to not break
|
||||
# pydantic-settings which relies on it:
|
||||
origin_is_union = is_union_origin
|
||||
|
||||
|
||||
def is_generic_alias(tp: Any, /) -> bool:
|
||||
return isinstance(tp, (types.GenericAlias, typing._GenericAlias)) # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
|
||||
# TODO: Ideally, we should avoid relying on the private `typing` constructs:
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias) # pyright: ignore[reportAttributeAccessIssue]
|
||||
else:
|
||||
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias, types.UnionType) # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
|
||||
# Similarly, we shouldn't rely on this `_Final` class, which is even more private than `_GenericAlias`:
|
||||
typing_base: Any = typing._Final # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
|
||||
### Annotation evaluations functions:
|
||||
|
||||
|
||||
def parent_frame_namespace(*, parent_depth: int = 2, force: bool = False) -> dict[str, Any] | None:
|
||||
"""Fetch the local namespace of the parent frame where this function is called.
|
||||
|
||||
Using this function is mostly useful to resolve forward annotations pointing to members defined in a local namespace,
|
||||
such as assignments inside a function. Using the standard library tools, it is currently not possible to resolve
|
||||
such annotations:
|
||||
|
||||
```python {lint="skip" test="skip"}
|
||||
from typing import get_type_hints
|
||||
|
||||
def func() -> None:
|
||||
Alias = int
|
||||
|
||||
class C:
|
||||
a: 'Alias'
|
||||
|
||||
# Raises a `NameError: 'Alias' is not defined`
|
||||
get_type_hints(C)
|
||||
```
|
||||
|
||||
Pydantic uses this function when a Pydantic model is being defined to fetch the parent frame locals. However,
|
||||
this only allows us to fetch the parent frame namespace and not other parents (e.g. a model defined in a function,
|
||||
itself defined in another function). Inspecting the next outer frames (using `f_back`) is not reliable enough
|
||||
(see https://discuss.python.org/t/20659).
|
||||
|
||||
Because this function is mostly used to better resolve forward annotations, nothing is returned if the parent frame's
|
||||
code object is defined at the module level. In this case, the locals of the frame will be the same as the module
|
||||
globals where the class is defined (see `_namespace_utils.get_module_ns_of`). However, if you still want to fetch
|
||||
the module globals (e.g. when rebuilding a model, where the frame where the rebuild call is performed might contain
|
||||
members that you want to use for forward annotations evaluation), you can use the `force` parameter.
|
||||
|
||||
Args:
|
||||
parent_depth: The depth at which to get the frame. Defaults to 2, meaning the parent frame where this function
|
||||
is called will be used.
|
||||
force: Whether to always return the frame locals, even if the frame's code object is defined at the module level.
|
||||
|
||||
Returns:
|
||||
The locals of the namespace, or `None` if it was skipped as per the described logic.
|
||||
"""
|
||||
frame = sys._getframe(parent_depth)
|
||||
# if f_back is None, it's the global module namespace and we don't need to include it here
|
||||
if frame.f_back is None:
|
||||
return None
|
||||
else:
|
||||
|
||||
if frame.f_code.co_name.startswith('<generic parameters of'):
|
||||
# As `parent_frame_namespace` is mostly called in `ModelMetaclass.__new__`,
|
||||
# the parent frame can be the annotation scope if the PEP 695 generic syntax is used.
|
||||
# (see https://docs.python.org/3/reference/executionmodel.html#annotation-scopes,
|
||||
# https://docs.python.org/3/reference/compound_stmts.html#generic-classes).
|
||||
# In this case, the code name is set to `<generic parameters of MyClass>`,
|
||||
# and we need to skip this frame as it is irrelevant.
|
||||
frame = cast(types.FrameType, frame.f_back) # guaranteed to not be `None`
|
||||
|
||||
# note, we don't copy frame.f_locals here (or during the last return call), because we don't expect the namespace to be
|
||||
# modified down the line. If this becomes a problem, we could implement some sort of frozen mapping structure to enforce this.
|
||||
if force:
|
||||
return frame.f_locals
|
||||
|
||||
# If either of the following conditions are true, the class is defined at the top module level.
|
||||
# To better understand why we need both of these checks, see
|
||||
# https://github.com/pydantic/pydantic/pull/10113#discussion_r1714981531.
|
||||
if frame.f_back is None or frame.f_code.co_name == '<module>':
|
||||
return None
|
||||
|
||||
def add_module_globals(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
module_name = getattr(obj, '__module__', None)
|
||||
if module_name:
|
||||
try:
|
||||
module_globalns = sys.modules[module_name].__dict__
|
||||
except KeyError:
|
||||
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
|
||||
pass
|
||||
else:
|
||||
if globalns:
|
||||
return {**module_globalns, **globalns}
|
||||
else:
|
||||
# copy module globals to make sure it can't be updated later
|
||||
return module_globalns.copy()
|
||||
|
||||
return globalns or {}
|
||||
return frame.f_locals
|
||||
|
||||
|
||||
def get_cls_types_namespace(cls: type[Any], parent_namespace: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
ns = add_module_globals(cls, parent_namespace)
|
||||
ns[cls.__name__] = cls
|
||||
return ns
|
||||
def _type_convert(arg: Any) -> Any:
|
||||
"""Convert `None` to `NoneType` and strings to `ForwardRef` instances.
|
||||
|
||||
|
||||
def get_cls_type_hints_lenient(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""Collect annotations from a class, including those from parent classes.
|
||||
|
||||
Unlike `typing.get_type_hints`, this function will not error if a forward reference is not resolvable.
|
||||
This is a backport of the private `typing._type_convert` function. When
|
||||
evaluating a type, `ForwardRef._evaluate` ends up being called, and is
|
||||
responsible for making this conversion. However, we still have to apply
|
||||
it for the first argument passed to our type evaluation functions, similarly
|
||||
to the `typing.get_type_hints` function.
|
||||
"""
|
||||
hints = {}
|
||||
if arg is None:
|
||||
return NoneType
|
||||
if isinstance(arg, str):
|
||||
# Like `typing.get_type_hints`, assume the arg can be in any context,
|
||||
# hence the proper `is_argument` and `is_class` args:
|
||||
return _make_forward_ref(arg, is_argument=False, is_class=True)
|
||||
return arg
|
||||
|
||||
|
||||
def get_model_type_hints(
|
||||
obj: type[BaseModel],
|
||||
*,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
) -> dict[str, tuple[Any, bool]]:
|
||||
"""Collect annotations from a Pydantic model class, including those from parent classes.
|
||||
|
||||
Args:
|
||||
obj: The Pydantic model to inspect.
|
||||
ns_resolver: A namespace resolver instance to use. Defaults to an empty instance.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping annotation names to a two-tuple: the first element is the evaluated
|
||||
type or the original annotation if a `NameError` occurred, and the second element is a boolean
indicating whether the evaluation succeeded.
|
||||
"""
|
||||
hints: dict[str, Any] | dict[str, tuple[Any, bool]] = {}
|
||||
ns_resolver = ns_resolver or NsResolver()
|
||||
|
||||
for base in reversed(obj.__mro__):
|
||||
ann = base.__dict__.get('__annotations__')
|
||||
localns = dict(vars(base))
|
||||
if ann is not None and ann is not GetSetDescriptorType:
|
||||
ann: dict[str, Any] | None = base.__dict__.get('__annotations__')
|
||||
if not ann or isinstance(ann, types.GetSetDescriptorType):
|
||||
continue
|
||||
with ns_resolver.push(base):
|
||||
globalns, localns = ns_resolver.types_namespace
|
||||
for name, value in ann.items():
|
||||
hints[name] = eval_type_lenient(value, globalns, localns)
|
||||
if name.startswith('_'):
|
||||
# For private attributes, we only need the annotation to detect the `ClassVar` special form.
|
||||
# For this reason, we still try to evaluate it, but we also catch any possible exception (on
|
||||
# top of the `NameError`s caught in `try_eval_type`) that could happen so that users are free
|
||||
# to use any kind of forward annotation for private fields (e.g. circular imports, new typing
|
||||
# syntax, etc).
|
||||
try:
|
||||
hints[name] = try_eval_type(value, globalns, localns)
|
||||
except Exception:
|
||||
hints[name] = (value, False)
|
||||
else:
|
||||
hints[name] = try_eval_type(value, globalns, localns)
|
||||
return hints
|
||||
|
||||
|
||||
def eval_type_lenient(value: Any, globalns: dict[str, Any] | None, localns: dict[str, Any] | None) -> Any:
|
||||
"""Behaves like typing._eval_type, except it won't raise an error if a forward reference can't be resolved."""
|
||||
if value is None:
|
||||
value = NoneType
|
||||
elif isinstance(value, str):
|
||||
value = _make_forward_ref(value, is_argument=False, is_class=True)
|
||||
def get_cls_type_hints(
|
||||
obj: type[Any],
|
||||
*,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Collect annotations from a class, including those from parent classes.
|
||||
|
||||
Args:
|
||||
obj: The class to inspect.
|
||||
ns_resolver: A namespace resolver instance to use. Defaults to an empty instance.
|
||||
"""
|
||||
hints: dict[str, Any] | dict[str, tuple[Any, bool]] = {}
|
||||
ns_resolver = ns_resolver or NsResolver()
|
||||
|
||||
for base in reversed(obj.__mro__):
|
||||
ann: dict[str, Any] | None = base.__dict__.get('__annotations__')
|
||||
if not ann or isinstance(ann, types.GetSetDescriptorType):
|
||||
continue
|
||||
with ns_resolver.push(base):
|
||||
globalns, localns = ns_resolver.types_namespace
|
||||
for name, value in ann.items():
|
||||
hints[name] = eval_type(value, globalns, localns)
|
||||
return hints
|
||||
|
||||
|
||||
def try_eval_type(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
) -> tuple[Any, bool]:
|
||||
"""Try evaluating the annotation using the provided namespaces.
|
||||
|
||||
Args:
|
||||
value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance
|
||||
of `str`, it will be converted to a `ForwardRef`.
|
||||
globalns: The global namespace to use during annotation evaluation.
localns: The local namespace to use during annotation evaluation.
|
||||
|
||||
Returns:
|
||||
A two-tuple containing the possibly evaluated type and a boolean indicating
|
||||
whether the evaluation succeeded or not.
|
||||
"""
|
||||
value = _type_convert(value)
|
||||
|
||||
try:
|
||||
return typing._eval_type(value, globalns, localns) # type: ignore
|
||||
return eval_type_backport(value, globalns, localns), True
|
||||
except NameError:
|
||||
# the point of this function is to be tolerant to this case
|
||||
return value
|
||||
return value, False
|
||||
|
||||
|
||||
def eval_type(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
) -> Any:
|
||||
"""Evaluate the annotation using the provided namespaces.
|
||||
|
||||
Args:
|
||||
value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance
|
||||
of `str`, it will be converted to a `ForwardRef`.
|
||||
globalns: The global namespace to use during annotation evaluation.
localns: The local namespace to use during annotation evaluation.
|
||||
"""
|
||||
value = _type_convert(value)
|
||||
return eval_type_backport(value, globalns, localns)
|
||||
|
||||
|
||||
@deprecated(
|
||||
'`eval_type_lenient` is deprecated, use `try_eval_type` instead.',
|
||||
category=None,
|
||||
)
|
||||
def eval_type_lenient(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
) -> Any:
|
||||
ev, _ = try_eval_type(value, globalns, localns)
|
||||
return ev
|
||||
|
||||
|
||||
def eval_type_backport(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
type_params: tuple[Any, ...] | None = None,
|
||||
) -> Any:
|
||||
"""An enhanced version of `typing._eval_type` which will fall back to using the `eval_type_backport`
|
||||
package if it's installed to let older Python versions use newer typing constructs.
|
||||
|
||||
Specifically, this transforms `X | Y` into `typing.Union[X, Y]` and `list[X]` into `typing.List[X]`
|
||||
(as well as all the types made generic in PEP 585) if the original syntax is not supported in the
|
||||
current Python version.
|
||||
|
||||
This function will also display a helpful error if the value passed fails to evaluate.
|
||||
"""
|
||||
try:
|
||||
return _eval_type_backport(value, globalns, localns, type_params)
|
||||
except TypeError as e:
|
||||
if 'Unable to evaluate type annotation' in str(e):
|
||||
raise
|
||||
|
||||
# If it is a `TypeError` and value isn't a `ForwardRef`, it would have failed during annotation definition.
|
||||
# Thus we assert here for type checking purposes:
|
||||
assert isinstance(value, typing.ForwardRef)
|
||||
|
||||
message = f'Unable to evaluate type annotation {value.__forward_arg__!r}.'
|
||||
if sys.version_info >= (3, 11):
|
||||
e.add_note(message)
|
||||
raise
|
||||
else:
|
||||
raise TypeError(message) from e
|
||||
except RecursionError as e:
|
||||
# TODO ideally recursion errors should be checked in `eval_type` above, but `eval_type_backport`
|
||||
# is used directly in some places.
|
||||
message = (
|
||||
"If you made use of an implicit recursive type alias (e.g. `MyType = list['MyType']), "
|
||||
'consider using PEP 695 type aliases instead. For more details, refer to the documentation: '
|
||||
f'https://docs.pydantic.dev/{version_short()}/concepts/types/#named-recursive-types'
|
||||
)
|
||||
if sys.version_info >= (3, 11):
|
||||
e.add_note(message)
|
||||
raise
|
||||
else:
|
||||
raise RecursionError(f'{e.args[0]}\n{message}')
|
||||
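For context, a minimal sketch of what the fallback above enables on Python < 3.10, assuming the optional `eval_type_backport` package is installed (the forward reference is illustrative):

```python
from typing import ForwardRef

from eval_type_backport import eval_type_backport  # optional dependency

# `int | None` is PEP 604 syntax that `typing._eval_type` cannot evaluate before 3.10;
# the backport rewrites it into the equivalent `typing.Optional[int]` construct.
ref = ForwardRef('int | None')
print(eval_type_backport(ref, globalns={}, localns={}))
#> typing.Optional[int]
```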
|
||||
|
||||
def _eval_type_backport(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
type_params: tuple[Any, ...] | None = None,
|
||||
) -> Any:
|
||||
try:
|
||||
return _eval_type(value, globalns, localns, type_params)
|
||||
except TypeError as e:
|
||||
if not (isinstance(value, typing.ForwardRef) and is_backport_fixable_error(e)):
|
||||
raise
|
||||
|
||||
try:
|
||||
from eval_type_backport import eval_type_backport
|
||||
except ImportError:
|
||||
raise TypeError(
|
||||
f'Unable to evaluate type annotation {value.__forward_arg__!r}. If you are making use '
|
||||
'of the new typing syntax (unions using `|` since Python 3.10 or builtins subscripting '
|
||||
'since Python 3.9), you should either replace the use of new syntax with the existing '
|
||||
'`typing` constructs or install the `eval_type_backport` package.'
|
||||
) from e
|
||||
|
||||
return eval_type_backport(
|
||||
value,
|
||||
globalns,
|
||||
localns, # pyright: ignore[reportArgumentType], waiting on a new `eval_type_backport` release.
|
||||
try_default=False,
|
||||
)
|
||||
|
||||
|
||||
def _eval_type(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
type_params: tuple[Any, ...] | None = None,
|
||||
) -> Any:
|
||||
if sys.version_info >= (3, 13):
|
||||
return typing._eval_type( # type: ignore
|
||||
value, globalns, localns, type_params=type_params
|
||||
)
|
||||
else:
|
||||
return typing._eval_type( # type: ignore
|
||||
value, globalns, localns
|
||||
)
|
||||
|
||||
|
||||
def is_backport_fixable_error(e: TypeError) -> bool:
|
||||
msg = str(e)
|
||||
|
||||
return sys.version_info < (3, 10) and msg.startswith('unsupported operand type(s) for |: ')
|
||||
|
||||
|
||||
def get_function_type_hints(
|
||||
function: Callable[..., Any], *, include_keys: set[str] | None = None, types_namespace: dict[str, Any] | None = None
|
||||
function: Callable[..., Any],
|
||||
*,
|
||||
include_keys: set[str] | None = None,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Like `typing.get_type_hints`, but doesn't convert `X` to `Optional[X]` if the default value is `None`, also
|
||||
copes with `partial`.
|
||||
"""
|
||||
if isinstance(function, partial):
|
||||
annotations = function.func.__annotations__
|
||||
else:
|
||||
annotations = function.__annotations__
|
||||
"""Return type hints for a function.
|
||||
|
||||
This is similar to the `typing.get_type_hints` function, with a few differences:
|
||||
- Support `functools.partial` by using the underlying `func` attribute.
|
||||
- Do not wrap type annotation of a parameter with `Optional` if it has a default value of `None`
|
||||
(related bug: https://github.com/python/cpython/issues/90353, only fixed in 3.11+).
|
||||
"""
|
||||
try:
|
||||
if isinstance(function, partial):
|
||||
annotations = function.func.__annotations__
|
||||
else:
|
||||
annotations = function.__annotations__
|
||||
except AttributeError:
|
||||
# Some functions (e.g. builtins) don't have annotations:
|
||||
return {}
|
||||
|
||||
if globalns is None:
|
||||
globalns = get_module_ns_of(function)
|
||||
type_params: tuple[Any, ...] | None = None
|
||||
if localns is None:
|
||||
# If localns was specified, it is assumed to already contain type params. This is because
|
||||
# Pydantic has more advanced logic to do so (see `_namespace_utils.ns_for_function`).
|
||||
type_params = getattr(function, '__type_params__', ())
|
||||
|
||||
globalns = add_module_globals(function)
|
||||
type_hints = {}
|
||||
for name, value in annotations.items():
|
||||
if include_keys is not None and name not in include_keys:
|
||||
@@ -248,7 +548,7 @@ def get_function_type_hints(
|
||||
elif isinstance(value, str):
|
||||
value = _make_forward_ref(value)
|
||||
|
||||
type_hints[name] = typing._eval_type(value, globalns, types_namespace) # type: ignore
|
||||
type_hints[name] = eval_type_backport(value, globalns, localns, type_params)
|
||||
|
||||
return type_hints
|
||||
|
||||
@@ -363,11 +663,15 @@ else:
|
||||
if isinstance(value, str):
|
||||
value = _make_forward_ref(value, is_argument=False, is_class=True)
|
||||
|
||||
value = typing._eval_type(value, base_globals, base_locals) # type: ignore
|
||||
value = eval_type_backport(value, base_globals, base_locals)
|
||||
hints[name] = value
|
||||
return (
|
||||
hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore
|
||||
)
|
||||
if not include_extras and hasattr(typing, '_strip_annotations'):
|
||||
return {
|
||||
k: typing._strip_annotations(t) # type: ignore
|
||||
for k, t in hints.items()
|
||||
}
|
||||
else:
|
||||
return hints
|
||||
|
||||
if globalns is None:
|
||||
if isinstance(obj, types.ModuleType):
|
||||
@@ -388,7 +692,7 @@ else:
|
||||
if isinstance(obj, typing._allowed_types): # type: ignore
|
||||
return {}
|
||||
else:
|
||||
raise TypeError(f'{obj!r} is not a module, class, method, ' 'or function.')
|
||||
raise TypeError(f'{obj!r} is not a module, class, method, or function.')
|
||||
defaults = typing._get_defaults(obj) # type: ignore
|
||||
hints = dict(hints)
|
||||
for name, value in hints.items():
|
||||
@@ -403,33 +707,8 @@ else:
|
||||
is_argument=not isinstance(obj, types.ModuleType),
|
||||
is_class=False,
|
||||
)
|
||||
value = typing._eval_type(value, globalns, localns) # type: ignore
|
||||
value = eval_type_backport(value, globalns, localns)
|
||||
if name in defaults and defaults[name] is None:
|
||||
value = typing.Optional[value]
|
||||
hints[name] = value
|
||||
return hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
|
||||
def evaluate_fwd_ref(
|
||||
ref: ForwardRef, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
return ref._evaluate(globalns=globalns, localns=localns)
|
||||
|
||||
else:
|
||||
|
||||
def evaluate_fwd_ref(
|
||||
ref: ForwardRef, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
return ref._evaluate(globalns=globalns, localns=localns, recursive_guard=frozenset())
|
||||
|
||||
|
||||
def is_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
|
||||
# The dataclasses.is_dataclass function doesn't seem to provide TypeGuard functionality,
|
||||
# so I created this convenience function
|
||||
return dataclasses.is_dataclass(_cls)
|
||||
|
||||
|
||||
def origin_is_type_alias_type(origin: Any) -> TypeGuard[TypeAliasType]:
|
||||
return isinstance(origin, TypeAliasType)
|
||||
|
||||
@@ -2,20 +2,30 @@
|
||||
|
||||
This should be reduced as much as possible with functions only used in one place, moved to that place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import keyword
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
import weakref
|
||||
from collections import OrderedDict, defaultdict, deque
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from inspect import Parameter
|
||||
from itertools import zip_longest
|
||||
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any, Callable, Generic, TypeVar, overload
|
||||
|
||||
from typing_extensions import TypeAlias, TypeGuard
|
||||
from typing_extensions import TypeAlias, TypeGuard, deprecated
|
||||
|
||||
from pydantic import PydanticDeprecatedSince211
|
||||
|
||||
from . import _repr, _typing_extra
|
||||
from ._import_utils import import_cached_base_model
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
MappingIntStrAny: TypeAlias = 'typing.Mapping[int, Any] | typing.Mapping[str, Any]'
|
||||
@@ -59,6 +69,25 @@ BUILTIN_COLLECTIONS: set[type[Any]] = {
|
||||
}
|
||||
|
||||
|
||||
def can_be_positional(param: Parameter) -> bool:
|
||||
"""Return whether the parameter accepts a positional argument.
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
def func(a, /, b, *, c):
|
||||
pass
|
||||
|
||||
params = inspect.signature(func).parameters
|
||||
can_be_positional(params['a'])
|
||||
#> True
|
||||
can_be_positional(params['b'])
|
||||
#> True
|
||||
can_be_positional(params['c'])
|
||||
#> False
|
||||
```
|
||||
"""
|
||||
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
|
||||
|
||||
|
||||
def sequence_like(v: Any) -> bool:
|
||||
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
|
||||
|
||||
@@ -83,7 +112,7 @@ def is_model_class(cls: Any) -> TypeGuard[type[BaseModel]]:
|
||||
"""Returns true if cls is a _proper_ subclass of BaseModel, and provides proper type-checking,
|
||||
unlike raw calls to lenient_issubclass.
|
||||
"""
|
||||
from ..main import BaseModel
|
||||
BaseModel = import_cached_base_model()
|
||||
|
||||
return lenient_issubclass(cls, BaseModel) and cls is not BaseModel
|
||||
|
||||
@@ -275,19 +304,23 @@ class ValueItems(_repr.Representation):
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
|
||||
def ClassAttribute(name: str, value: T) -> T:
|
||||
...
|
||||
def LazyClassAttribute(name: str, get_value: Callable[[], T]) -> T: ...
|
||||
|
||||
else:
|
||||
|
||||
class ClassAttribute:
|
||||
"""Hide class attribute from its instances."""
|
||||
class LazyClassAttribute:
|
||||
"""A descriptor exposing an attribute only accessible on a class (hidden from instances).
|
||||
|
||||
__slots__ = 'name', 'value'
|
||||
The attribute is lazily computed and cached during the first access.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, value: Any) -> None:
|
||||
def __init__(self, name: str, get_value: Callable[[], Any]) -> None:
|
||||
self.name = name
|
||||
self.value = value
|
||||
self.get_value = get_value
|
||||
|
||||
@cached_property
|
||||
def value(self) -> Any:
|
||||
return self.get_value()
|
||||
|
||||
def __get__(self, instance: Any, owner: type[Any]) -> None:
|
||||
if instance is None:
|
||||
@@ -309,7 +342,7 @@ def smart_deepcopy(obj: Obj) -> Obj:
|
||||
try:
|
||||
if not obj and obj_type in BUILTIN_COLLECTIONS:
|
||||
# faster way for empty collections, no need to copy its members
|
||||
return obj if obj_type is tuple else obj.copy() # tuple doesn't have copy method
|
||||
return obj if obj_type is tuple else obj.copy() # tuple doesn't have copy method # type: ignore
|
||||
except (TypeError, ValueError, RuntimeError):
|
||||
# do we really dare to catch ALL errors? Seems a bit risky
|
||||
pass
|
||||
@@ -317,7 +350,7 @@ def smart_deepcopy(obj: Obj) -> Obj:
|
||||
return deepcopy(obj) # slowest way when we actually might need a deepcopy
|
||||
|
||||
|
||||
_EMPTY = object()
|
||||
_SENTINEL = object()
|
||||
|
||||
|
||||
def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool:
|
||||
@@ -329,7 +362,70 @@ def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bo
|
||||
>>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical"
|
||||
False
|
||||
"""
|
||||
for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY):
|
||||
for left_item, right_item in zip_longest(left, right, fillvalue=_SENTINEL):
|
||||
if left_item is not right_item:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SafeGetItemProxy:
|
||||
"""Wrapper redirecting `__getitem__` to `get` with a sentinel value as default
|
||||
|
||||
This makes it safe to use in `operator.itemgetter` when some keys may be missing.
|
||||
"""
|
||||
|
||||
# Define __slots__ manually for performance;
# @dataclasses.dataclass() only supports slots=True in Python >= 3.10
|
||||
__slots__ = ('wrapped',)
|
||||
|
||||
wrapped: Mapping[str, Any]
|
||||
|
||||
def __getitem__(self, key: str, /) -> Any:
|
||||
return self.wrapped.get(key, _SENTINEL)
|
||||
|
||||
# required to pass the object to operator.itemgetter() instances due to a quirk of typeshed
|
||||
# https://github.com/python/mypy/issues/13713
|
||||
# https://github.com/python/typeshed/pull/8785
|
||||
# Since this is typing-only, hide it in a typing.TYPE_CHECKING block
|
||||
if typing.TYPE_CHECKING:
|
||||
|
||||
def __contains__(self, key: str, /) -> bool:
|
||||
return self.wrapped.__contains__(key)
|
||||
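A small, self-contained sketch of the same pattern (illustrative names, not pydantic's internals): redirecting `__getitem__` to `.get()` with a sentinel default keeps `operator.itemgetter` from raising on missing keys:

```python
import dataclasses
import operator
from collections.abc import Mapping
from typing import Any

_MISSING = object()  # sentinel standing in for absent keys


@dataclasses.dataclass(frozen=True)
class GetItemProxy:
    wrapped: Mapping[str, Any]

    def __getitem__(self, key: str) -> Any:
        return self.wrapped.get(key, _MISSING)


getter = operator.itemgetter('present', 'absent')
present, absent = getter(GetItemProxy({'present': 1}))
print(present, absent is _MISSING)
#> 1 True
```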
|
||||
|
||||
_ModelT = TypeVar('_ModelT', bound='BaseModel')
|
||||
_RT = TypeVar('_RT')
|
||||
|
||||
|
||||
class deprecated_instance_property(Generic[_ModelT, _RT]):
|
||||
"""A decorator exposing the decorated class method as a property, with a warning on instance access.
|
||||
|
||||
This decorator takes a class method defined on the `BaseModel` class and transforms it into
|
||||
an attribute. The attribute can be accessed on both the class and instances of the class. If accessed
|
||||
via an instance, a deprecation warning is emitted stating that instance access will be removed in V3.
|
||||
"""
|
||||
|
||||
def __init__(self, fget: Callable[[type[_ModelT]], _RT], /) -> None:
|
||||
# Note: fget should be a classmethod:
|
||||
self.fget = fget
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: None, objtype: type[_ModelT]) -> _RT: ...
|
||||
@overload
|
||||
@deprecated(
|
||||
'Accessing this attribute on the instance is deprecated, and will be removed in Pydantic V3. '
|
||||
'Instead, you should access this attribute from the model class.',
|
||||
category=None,
|
||||
)
|
||||
def __get__(self, instance: _ModelT, objtype: type[_ModelT]) -> _RT: ...
|
||||
def __get__(self, instance: _ModelT | None, objtype: type[_ModelT]) -> _RT:
|
||||
if instance is not None:
|
||||
attr_name = self.fget.__name__ if sys.version_info >= (3, 10) else self.fget.__func__.__name__
|
||||
warnings.warn(
|
||||
f'Accessing the {attr_name!r} attribute on the instance is deprecated. '
|
||||
'Instead, you should access this attribute from the model class.',
|
||||
category=PydanticDeprecatedSince211,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.fget.__get__(instance, objtype)()
|
||||
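As a rough sketch of the descriptor pattern above (simplified, without the overloads or the Pydantic-specific warning category; class and attribute names are illustrative):

```python
import warnings


class deprecated_class_property:
    """Expose a classmethod as an attribute; warn when it is read from an instance."""

    def __init__(self, fget):
        self.fget = fget  # expected to be a classmethod object

    def __get__(self, instance, objtype=None):
        if instance is not None:
            warnings.warn(
                'Access this attribute on the class, not the instance.',
                DeprecationWarning,
                stacklevel=2,
            )
        return self.fget.__get__(instance, objtype)()


class Model:
    @deprecated_class_property
    @classmethod
    def fields(cls):
        return {'x': int}


print(Model.fields)    # no warning
print(Model().fields)  # same value, but emits a DeprecationWarning
```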
|
||||
@@ -1,88 +1,122 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from collections.abc import Awaitable
|
||||
from functools import partial
|
||||
from typing import Any, Awaitable, Callable
|
||||
from typing import Any, Callable
|
||||
|
||||
import pydantic_core
|
||||
|
||||
from ..config import ConfigDict
|
||||
from ..plugin._schema_validator import create_schema_validator
|
||||
from . import _discriminated_union, _generate_schema, _typing_extra
|
||||
from ._config import ConfigWrapper
|
||||
from ._core_utils import simplify_schema_references, validate_core_schema
|
||||
from ._generate_schema import GenerateSchema, ValidateCallSupportedTypes
|
||||
from ._namespace_utils import MappingNamespace, NsResolver, ns_for_function
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallMarker:
|
||||
function: Callable[..., Any]
|
||||
validate_return: bool
|
||||
def extract_function_name(func: ValidateCallSupportedTypes) -> str:
|
||||
"""Extract the name of a `ValidateCallSupportedTypes` object."""
|
||||
return f'partial({func.func.__name__})' if isinstance(func, functools.partial) else func.__name__
|
||||
|
||||
|
||||
def extract_function_qualname(func: ValidateCallSupportedTypes) -> str:
|
||||
"""Extract the qualname of a `ValidateCallSupportedTypes` object."""
|
||||
return f'partial({func.func.__qualname__})' if isinstance(func, functools.partial) else func.__qualname__
|
||||
|
||||
|
||||
def update_wrapper_attributes(wrapped: ValidateCallSupportedTypes, wrapper: Callable[..., Any]):
|
||||
"""Update the `wrapper` function with the attributes of the `wrapped` function. Return the updated function."""
|
||||
if inspect.iscoroutinefunction(wrapped):
|
||||
|
||||
@functools.wraps(wrapped)
|
||||
async def wrapper_function(*args, **kwargs): # type: ignore
|
||||
return await wrapper(*args, **kwargs)
|
||||
else:
|
||||
|
||||
@functools.wraps(wrapped)
|
||||
def wrapper_function(*args, **kwargs):
|
||||
return wrapper(*args, **kwargs)
|
||||
|
||||
# We need to manually update this because a `partial` object has no `__name__` or `__qualname__`.
|
||||
wrapper_function.__name__ = extract_function_name(wrapped)
|
||||
wrapper_function.__qualname__ = extract_function_qualname(wrapped)
|
||||
wrapper_function.raw_function = wrapped # type: ignore
|
||||
|
||||
return wrapper_function
|
||||
|
||||
|
||||
class ValidateCallWrapper:
|
||||
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value.
|
||||
|
||||
It's partially inspired by `wraps` which in turn uses `partial`, but extended to be a descriptor so
|
||||
these functions can be applied to instance methods, class methods, static methods, as well as normal functions.
|
||||
"""
|
||||
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""
|
||||
|
||||
__slots__ = (
|
||||
'raw_function',
|
||||
'_config',
|
||||
'_validate_return',
|
||||
'__pydantic_core_schema__',
|
||||
'function',
|
||||
'validate_return',
|
||||
'schema_type',
|
||||
'module',
|
||||
'qualname',
|
||||
'ns_resolver',
|
||||
'config_wrapper',
|
||||
'__pydantic_complete__',
|
||||
'__pydantic_validator__',
|
||||
'__signature__',
|
||||
'__name__',
|
||||
'__qualname__',
|
||||
'__annotations__',
|
||||
'__dict__', # required for __module__
|
||||
'__return_pydantic_validator__',
|
||||
)
|
||||
|
||||
def __init__(self, function: Callable[..., Any], config: ConfigDict | None, validate_return: bool):
|
||||
self.raw_function = function
|
||||
self._config = config
|
||||
self._validate_return = validate_return
|
||||
self.__signature__ = inspect.signature(function)
|
||||
def __init__(
|
||||
self,
|
||||
function: ValidateCallSupportedTypes,
|
||||
config: ConfigDict | None,
|
||||
validate_return: bool,
|
||||
parent_namespace: MappingNamespace | None,
|
||||
) -> None:
|
||||
self.function = function
|
||||
self.validate_return = validate_return
|
||||
if isinstance(function, partial):
|
||||
func = function.func
|
||||
self.__name__ = f'partial({func.__name__})'
|
||||
self.__qualname__ = f'partial({func.__qualname__})'
|
||||
self.__annotations__ = func.__annotations__
|
||||
self.__module__ = func.__module__
|
||||
self.__doc__ = func.__doc__
|
||||
self.schema_type = function.func
|
||||
self.module = function.func.__module__
|
||||
else:
|
||||
self.__name__ = function.__name__
|
||||
self.__qualname__ = function.__qualname__
|
||||
self.__annotations__ = function.__annotations__
|
||||
self.__module__ = function.__module__
|
||||
self.__doc__ = function.__doc__
|
||||
self.schema_type = function
|
||||
self.module = function.__module__
|
||||
self.qualname = extract_function_qualname(function)
|
||||
|
||||
namespace = _typing_extra.add_module_globals(function, None)
|
||||
config_wrapper = ConfigWrapper(config)
|
||||
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
|
||||
schema = gen_schema.collect_definitions(gen_schema.generate_schema(function))
|
||||
schema = simplify_schema_references(schema)
|
||||
self.__pydantic_core_schema__ = schema = schema
|
||||
core_config = config_wrapper.core_config(self)
|
||||
schema = _discriminated_union.apply_discriminators(schema)
|
||||
self.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)
|
||||
self.ns_resolver = NsResolver(
|
||||
namespaces_tuple=ns_for_function(self.schema_type, parent_namespace=parent_namespace)
|
||||
)
|
||||
self.config_wrapper = ConfigWrapper(config)
|
||||
if not self.config_wrapper.defer_build:
|
||||
self._create_validators()
|
||||
else:
|
||||
self.__pydantic_complete__ = False
|
||||
|
||||
if self._validate_return:
|
||||
return_type = (
|
||||
self.__signature__.return_annotation
|
||||
if self.__signature__.return_annotation is not self.__signature__.empty
|
||||
else Any
|
||||
def _create_validators(self) -> None:
|
||||
gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver)
|
||||
schema = gen_schema.clean_schema(gen_schema.generate_schema(self.function))
|
||||
core_config = self.config_wrapper.core_config(title=self.qualname)
|
||||
|
||||
self.__pydantic_validator__ = create_schema_validator(
|
||||
schema,
|
||||
self.schema_type,
|
||||
self.module,
|
||||
self.qualname,
|
||||
'validate_call',
|
||||
core_config,
|
||||
self.config_wrapper.plugin_settings,
|
||||
)
|
||||
if self.validate_return:
|
||||
signature = inspect.signature(self.function)
|
||||
return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
|
||||
gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver)
|
||||
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
|
||||
validator = create_schema_validator(
|
||||
schema,
|
||||
self.schema_type,
|
||||
self.module,
|
||||
self.qualname,
|
||||
'validate_call',
|
||||
core_config,
|
||||
self.config_wrapper.plugin_settings,
|
||||
)
|
||||
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
|
||||
schema = gen_schema.collect_definitions(gen_schema.generate_schema(return_type))
|
||||
schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema))
|
||||
self.__return_pydantic_core_schema__ = schema
|
||||
core_config = config_wrapper.core_config(self)
|
||||
schema = validate_core_schema(schema)
|
||||
validator = pydantic_core.SchemaValidator(schema, core_config)
|
||||
if inspect.iscoroutinefunction(self.raw_function):
|
||||
if inspect.iscoroutinefunction(self.function):
|
||||
|
||||
async def return_val_wrapper(aw: Awaitable[Any]) -> None:
|
||||
return validator.validate_python(await aw)
|
||||
@@ -91,38 +125,16 @@ class ValidateCallWrapper:
|
||||
else:
|
||||
self.__return_pydantic_validator__ = validator.validate_python
|
||||
else:
|
||||
self.__return_pydantic_core_schema__ = None
|
||||
self.__return_pydantic_validator__ = None
|
||||
|
||||
self._name: str | None = None # set by __get__, used to set the instance attribute when decorating methods
|
||||
self.__pydantic_complete__ = True
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
if not self.__pydantic_complete__:
|
||||
self._create_validators()
|
||||
|
||||
res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
|
||||
if self.__return_pydantic_validator__:
|
||||
return self.__return_pydantic_validator__(res)
|
||||
return res
|
||||
|
||||
def __get__(self, obj: Any, objtype: type[Any] | None = None) -> ValidateCallWrapper:
|
||||
"""Bind the raw function and return another ValidateCallWrapper wrapping that."""
|
||||
if obj is None:
|
||||
try:
|
||||
# Handle the case where a method is accessed as a class attribute
|
||||
return objtype.__getattribute__(objtype, self._name) # type: ignore
|
||||
except AttributeError:
|
||||
# This will happen the first time the attribute is accessed
|
||||
pass
|
||||
|
||||
bound_function = self.raw_function.__get__(obj, objtype)
|
||||
result = self.__class__(bound_function, self._config, self._validate_return)
|
||||
if self._name is not None:
|
||||
if obj is not None:
|
||||
object.__setattr__(obj, self._name, result)
|
||||
else:
|
||||
object.__setattr__(objtype, self._name, result)
|
||||
return result
|
||||
|
||||
def __set_name__(self, owner: Any, name: str) -> None:
|
||||
self._name = name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'ValidateCallWrapper({self.raw_function})'
|
||||
else:
|
||||
return res
|
||||
|
||||
@@ -5,22 +5,32 @@ Import of this module is deferred since it contains imports of many standard lib
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
import re
|
||||
import typing
|
||||
from decimal import Decimal
|
||||
from fractions import Fraction
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from typing import Any
|
||||
from typing import Any, Callable, Union, cast, get_origin
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
import typing_extensions
|
||||
from pydantic_core import PydanticCustomError, core_schema
|
||||
from pydantic_core._pydantic_core import PydanticKnownError
|
||||
from typing_inspection import typing_objects
|
||||
|
||||
from pydantic._internal._import_utils import import_cached_field_info
|
||||
from pydantic.errors import PydanticSchemaGenerationError
|
||||
|
||||
|
||||
def sequence_validator(
|
||||
__input_value: typing.Sequence[Any],
|
||||
input_value: typing.Sequence[Any],
|
||||
/,
|
||||
validator: core_schema.ValidatorFunctionWrapHandler,
|
||||
) -> typing.Sequence[Any]:
|
||||
"""Validator for `Sequence` types, isinstance(v, Sequence) has already been called."""
|
||||
value_type = type(__input_value)
|
||||
value_type = type(input_value)
|
||||
|
||||
# We don't accept any plain string as a sequence
|
||||
# Relevant issue: https://github.com/pydantic/pydantic/issues/5595
|
||||
@@ -31,14 +41,24 @@ def sequence_validator(
|
||||
{'type_name': value_type.__name__},
|
||||
)
|
||||
|
||||
v_list = validator(__input_value)
|
||||
# TODO: refactor sequence validation to validate with either a list or a tuple
|
||||
# schema, depending on the type of the value.
|
||||
# Additionally, we should be able to remove one of either this validator or the
|
||||
# SequenceValidator in _std_types_schema.py (preferably this one, while porting over some logic).
|
||||
# Effectively, a refactor for sequence validation is needed.
|
||||
if value_type is tuple:
|
||||
input_value = list(input_value)
|
||||
|
||||
v_list = validator(input_value)
|
||||
|
||||
# the rest of the logic is just re-creating the original type from `v_list`
|
||||
if value_type == list:
|
||||
if value_type is list:
|
||||
return v_list
|
||||
elif issubclass(value_type, range):
|
||||
# return the list as we probably can't re-create the range
|
||||
return v_list
|
||||
elif value_type is tuple:
|
||||
return tuple(v_list)
|
||||
else:
|
||||
# best guess at how to re-create the original type, more custom construction logic might be required
|
||||
return value_type(v_list) # type: ignore[call-arg]
|
||||
@@ -49,7 +69,7 @@ def import_string(value: Any) -> Any:
|
||||
try:
|
||||
return _import_string_logic(value)
|
||||
except ImportError as e:
|
||||
raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)})
|
||||
raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e
|
||||
else:
|
||||
# otherwise we just return the value and let the next validator do the rest of the work
|
||||
return value
|
||||
@@ -106,39 +126,39 @@ def _import_string_logic(dotted_path: str) -> Any:
    return module


def pattern_either_validator(__input_value: Any) -> typing.Pattern[Any]:
    if isinstance(__input_value, typing.Pattern):
        return __input_value
    elif isinstance(__input_value, (str, bytes)):
def pattern_either_validator(input_value: Any, /) -> typing.Pattern[Any]:
    if isinstance(input_value, typing.Pattern):
        return input_value
    elif isinstance(input_value, (str, bytes)):
        # todo strict mode
        return compile_pattern(__input_value)  # type: ignore
        return compile_pattern(input_value)  # type: ignore
    else:
        raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')


def pattern_str_validator(__input_value: Any) -> typing.Pattern[str]:
    if isinstance(__input_value, typing.Pattern):
        if isinstance(__input_value.pattern, str):
            return __input_value
def pattern_str_validator(input_value: Any, /) -> typing.Pattern[str]:
    if isinstance(input_value, typing.Pattern):
        if isinstance(input_value.pattern, str):
            return input_value
        else:
            raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
    elif isinstance(__input_value, str):
        return compile_pattern(__input_value)
    elif isinstance(__input_value, bytes):
    elif isinstance(input_value, str):
        return compile_pattern(input_value)
    elif isinstance(input_value, bytes):
        raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
    else:
        raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')


def pattern_bytes_validator(__input_value: Any) -> typing.Pattern[bytes]:
    if isinstance(__input_value, typing.Pattern):
        if isinstance(__input_value.pattern, bytes):
            return __input_value
def pattern_bytes_validator(input_value: Any, /) -> typing.Pattern[bytes]:
    if isinstance(input_value, typing.Pattern):
        if isinstance(input_value.pattern, bytes):
            return input_value
        else:
            raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
    elif isinstance(__input_value, bytes):
        return compile_pattern(__input_value)
    elif isinstance(__input_value, str):
    elif isinstance(input_value, bytes):
        return compile_pattern(input_value)
    elif isinstance(input_value, str):
        raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
    else:
        raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
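# Illustrative sketch (standard library only): the str/bytes split these validators
# enforce keys off the compiled pattern's `.pattern` attribute.
import re

assert isinstance(re.compile(r'\d+').pattern, str)     # accepted by the str validator
assert isinstance(re.compile(rb'\d+').pattern, bytes)  # accepted by the bytes validator
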
@@ -154,125 +174,359 @@ def compile_pattern(pattern: PatternType) -> typing.Pattern[PatternType]:
        raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression')


def ip_v4_address_validator(__input_value: Any) -> IPv4Address:
    if isinstance(__input_value, IPv4Address):
        return __input_value
def ip_v4_address_validator(input_value: Any, /) -> IPv4Address:
    if isinstance(input_value, IPv4Address):
        return input_value

    try:
        return IPv4Address(__input_value)
        return IPv4Address(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address')


def ip_v6_address_validator(__input_value: Any) -> IPv6Address:
    if isinstance(__input_value, IPv6Address):
        return __input_value
def ip_v6_address_validator(input_value: Any, /) -> IPv6Address:
    if isinstance(input_value, IPv6Address):
        return input_value

    try:
        return IPv6Address(__input_value)
        return IPv6Address(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address')


def ip_v4_network_validator(__input_value: Any) -> IPv4Network:
def ip_v4_network_validator(input_value: Any, /) -> IPv4Network:
    """Assume IPv4Network initialised with a default `strict` argument.

    See more:
    https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
    """
    if isinstance(__input_value, IPv4Network):
        return __input_value
    if isinstance(input_value, IPv4Network):
        return input_value

    try:
        return IPv4Network(__input_value)
        return IPv4Network(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network')


def ip_v6_network_validator(__input_value: Any) -> IPv6Network:
def ip_v6_network_validator(input_value: Any, /) -> IPv6Network:
    """Assume IPv6Network initialised with a default `strict` argument.

    See more:
    https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
    """
    if isinstance(__input_value, IPv6Network):
        return __input_value
    if isinstance(input_value, IPv6Network):
        return input_value

    try:
        return IPv6Network(__input_value)
        return IPv6Network(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network')


def ip_v4_interface_validator(__input_value: Any) -> IPv4Interface:
    if isinstance(__input_value, IPv4Interface):
        return __input_value
def ip_v4_interface_validator(input_value: Any, /) -> IPv4Interface:
    if isinstance(input_value, IPv4Interface):
        return input_value

    try:
        return IPv4Interface(__input_value)
        return IPv4Interface(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface')


def ip_v6_interface_validator(__input_value: Any) -> IPv6Interface:
    if isinstance(__input_value, IPv6Interface):
        return __input_value
def ip_v6_interface_validator(input_value: Any, /) -> IPv6Interface:
    if isinstance(input_value, IPv6Interface):
        return input_value

    try:
        return IPv6Interface(__input_value)
        return IPv6Interface(input_value)
    except ValueError:
        raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface')

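# Illustrative sketch (standard library only): the network validators rely on
# ipaddress's default strict=True, so a network with host bits set is rejected.
from ipaddress import IPv4Network

IPv4Network('192.168.0.0/24')       # a valid network
try:
    IPv4Network('192.168.0.1/24')   # host bits set -> ValueError under the strict default
except ValueError:
    pass
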
def greater_than_validator(x: Any, gt: Any) -> Any:
    if not (x > gt):
        raise PydanticKnownError('greater_than', {'gt': gt})
    return x
def fraction_validator(input_value: Any, /) -> Fraction:
    if isinstance(input_value, Fraction):
        return input_value


def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
    if not (x >= ge):
        raise PydanticKnownError('greater_than_equal', {'ge': ge})
    return x


def less_than_validator(x: Any, lt: Any) -> Any:
    if not (x < lt):
        raise PydanticKnownError('less_than', {'lt': lt})
    return x


def less_than_or_equal_validator(x: Any, le: Any) -> Any:
    if not (x <= le):
        raise PydanticKnownError('less_than_equal', {'le': le})
    return x


def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
    if not (x % multiple_of == 0):
        raise PydanticKnownError('multiple_of', {'multiple_of': multiple_of})
    return x


def min_length_validator(x: Any, min_length: Any) -> Any:
    if not (len(x) >= min_length):
        raise PydanticKnownError(
            'too_short',
            {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)},
        )
    return x


def max_length_validator(x: Any, max_length: Any) -> Any:
    if len(x) > max_length:
        raise PydanticKnownError(
            'too_long',
            {'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
        )
    return x
    try:
        return Fraction(input_value)
    except ValueError:
        raise PydanticCustomError('fraction_parsing', 'Input is not a valid fraction')

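# Illustrative sketch (standard library only) of the parsing behind fraction_validator:
# Fraction() accepts strings such as '3/7' and raises ValueError on malformed input.
from fractions import Fraction

assert Fraction('3/7') == Fraction(3, 7)
try:
    Fraction('not-a-fraction')
except ValueError:
    pass
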
def forbid_inf_nan_check(x: Any) -> Any:
    if not math.isfinite(x):
        raise PydanticKnownError('finite_number')
    return x


def _safe_repr(v: Any) -> int | float | str:
    """The context argument for `PydanticKnownError` requires a number or str type, so we do a simple repr() coercion for types like timedelta.

    See tests/test_types.py::test_annotated_metadata_any_order for some context.
    """
    if isinstance(v, (int, float, str)):
        return v
    return repr(v)

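# Illustrative sketch of the coercion above: values that are not int/float/str (the
# docstring's timedelta case) are passed to the error context as their repr() string.
from datetime import timedelta

_bound = timedelta(days=1)
_context_value = _bound if isinstance(_bound, (int, float, str)) else repr(_bound)
assert isinstance(_context_value, str)
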
def greater_than_validator(x: Any, gt: Any) -> Any:
    try:
        if not (x > gt):
            raise PydanticKnownError('greater_than', {'gt': _safe_repr(gt)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'gt' to supplied value {x}")


def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
    try:
        if not (x >= ge):
            raise PydanticKnownError('greater_than_equal', {'ge': _safe_repr(ge)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'ge' to supplied value {x}")


def less_than_validator(x: Any, lt: Any) -> Any:
    try:
        if not (x < lt):
            raise PydanticKnownError('less_than', {'lt': _safe_repr(lt)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'lt' to supplied value {x}")


def less_than_or_equal_validator(x: Any, le: Any) -> Any:
    try:
        if not (x <= le):
            raise PydanticKnownError('less_than_equal', {'le': _safe_repr(le)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'le' to supplied value {x}")


def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
    try:
        if x % multiple_of:
            raise PydanticKnownError('multiple_of', {'multiple_of': _safe_repr(multiple_of)})
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'multiple_of' to supplied value {x}")


def min_length_validator(x: Any, min_length: Any) -> Any:
    try:
        if not (len(x) >= min_length):
            raise PydanticKnownError(
                'too_short', {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)}
            )
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'min_length' to supplied value {x}")


def max_length_validator(x: Any, max_length: Any) -> Any:
    try:
        if len(x) > max_length:
            raise PydanticKnownError(
                'too_long',
                {'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
            )
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'max_length' to supplied value {x}")

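# Hedged usage sketch (assumes pydantic is installed; the model is illustrative): at the
# public API level these constraint checks surface through `Field(...)` keyword
# arguments and failures are reported as a ValidationError.
from pydantic import BaseModel, Field, ValidationError


class _Price(BaseModel):
    amount: int = Field(gt=0, le=1_000)


_Price(amount=10)
try:
    _Price(amount=-1)
except ValidationError:
    pass
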
def _extract_decimal_digits_info(decimal: Decimal) -> tuple[int, int]:
    """Compute the total number of digits and decimal places for a given [`Decimal`][decimal.Decimal] instance.

    This function handles both normalized and non-normalized Decimal instances.
    Example: Decimal('1.230') -> 4 digits, 3 decimal places

    Args:
        decimal (Decimal): The decimal number to analyze.

    Returns:
        tuple[int, int]: A tuple containing the number of decimal places and total digits.

    Though this could be divided into two separate functions, the logic is easier to follow if we couple the computation
    of the number of decimals and digits together.
    """
    try:
        decimal_tuple = decimal.as_tuple()

        assert isinstance(decimal_tuple.exponent, int)

        exponent = decimal_tuple.exponent
        num_digits = len(decimal_tuple.digits)

        if exponent >= 0:
            # A positive exponent adds that many trailing zeros
            # Ex: digit_tuple=(1, 2, 3), exponent=2 -> 12300 -> 0 decimal places, 5 digits
            num_digits += exponent
            decimal_places = 0
        else:
            # If the absolute value of the negative exponent is larger than the
            # number of digits, then it's the same as the number of digits,
            # because it'll consume all the digits in digit_tuple and then
            # add abs(exponent) - len(digit_tuple) leading zeros after the decimal point.
            # Ex: digit_tuple=(1, 2, 3), exponent=-2 -> 1.23 -> 2 decimal places, 3 digits
            # Ex: digit_tuple=(1, 2, 3), exponent=-4 -> 0.0123 -> 4 decimal places, 4 digits
            decimal_places = abs(exponent)
            num_digits = max(num_digits, decimal_places)

        return decimal_places, num_digits
    except (AssertionError, AttributeError):
        raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}')

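# Worked example (standard library only) of the digit/place computation above,
# matching the docstring's Decimal('1.230') case.
from decimal import Decimal

_t = Decimal('1.230').as_tuple()
assert _t.digits == (1, 2, 3, 0) and _t.exponent == -3
_decimal_places = abs(_t.exponent)                       # 3
_num_digits = max(len(_t.digits), _decimal_places)       # 4
assert (_decimal_places, _num_digits) == (3, 4)
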
def max_digits_validator(x: Any, max_digits: Any) -> Any:
    try:
        _, num_digits = _extract_decimal_digits_info(x)
        _, normalized_num_digits = _extract_decimal_digits_info(x.normalize())
        if (num_digits > max_digits) and (normalized_num_digits > max_digits):
            raise PydanticKnownError(
                'decimal_max_digits',
                {'max_digits': max_digits},
            )
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'max_digits' to supplied value {x}")


def decimal_places_validator(x: Any, decimal_places: Any) -> Any:
    try:
        decimal_places_, _ = _extract_decimal_digits_info(x)
        if decimal_places_ > decimal_places:
            normalized_decimal_places, _ = _extract_decimal_digits_info(x.normalize())
            if normalized_decimal_places > decimal_places:
                raise PydanticKnownError(
                    'decimal_max_places',
                    {'decimal_places': decimal_places},
                )
        return x
    except TypeError:
        raise TypeError(f"Unable to apply constraint 'decimal_places' to supplied value {x}")


def deque_validator(input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> collections.deque[Any]:
    return collections.deque(handler(input_value), maxlen=getattr(input_value, 'maxlen', None))


def defaultdict_validator(
    input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
) -> collections.defaultdict[Any, Any]:
    if isinstance(input_value, collections.defaultdict):
        default_factory = input_value.default_factory
        return collections.defaultdict(default_factory, handler(input_value))
    else:
        return collections.defaultdict(default_default_factory, handler(input_value))

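# Illustrative sketch (standard library only) of the property the defaultdict
# validator preserves: an existing default_factory on the input survives re-creation.
import collections

_source = collections.defaultdict(list, {'a': [1]})
_rebuilt = collections.defaultdict(_source.default_factory, dict(_source))
assert _rebuilt['missing'] == []  # the factory carried over
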
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
    FieldInfo = import_cached_field_info()

    values_type_origin = get_origin(values_source_type)

    def infer_default() -> Callable[[], Any]:
        allowed_default_types: dict[Any, Any] = {
            tuple: tuple,
            collections.abc.Sequence: tuple,
            collections.abc.MutableSequence: list,
            list: list,
            typing.Sequence: list,
            set: set,
            typing.MutableSet: set,
            collections.abc.MutableSet: set,
            collections.abc.Set: frozenset,
            typing.MutableMapping: dict,
            typing.Mapping: dict,
            collections.abc.Mapping: dict,
            collections.abc.MutableMapping: dict,
            float: float,
            int: int,
            str: str,
            bool: bool,
        }
        values_type = values_type_origin or values_source_type
        instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
        if typing_objects.is_typevar(values_type):

            def type_var_default_factory() -> None:
                raise RuntimeError(
                    'Generic defaultdict cannot be used without a concrete value type or an'
                    ' explicit default factory, ' + instructions
                )

            return type_var_default_factory
        elif values_type not in allowed_default_types:
            # a somewhat subjective set of types that have reasonable default values
            allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
            raise PydanticSchemaGenerationError(
                f'Unable to infer a default factory for keys of type {values_source_type}.'
                f' Only {allowed_msg} are supported, other types require an explicit default factory'
                ' ' + instructions
            )
        return allowed_default_types[values_type]

    # Assume Annotated[..., Field(...)]
    if typing_objects.is_annotated(values_type_origin):
        field_info = next((v for v in typing_extensions.get_args(values_source_type) if isinstance(v, FieldInfo)), None)
    else:
        field_info = None
    if field_info and field_info.default_factory:
        # Assume the default factory does not take any argument:
        default_default_factory = cast(Callable[[], Any], field_info.default_factory)
    else:
        default_default_factory = infer_default()
    return default_default_factory


def validate_str_is_valid_iana_tz(value: Any, /) -> ZoneInfo:
    if isinstance(value, ZoneInfo):
        return value
    try:
        return ZoneInfo(value)
    except (ZoneInfoNotFoundError, ValueError, TypeError):
        raise PydanticCustomError('zoneinfo_str', 'invalid timezone: {value}', {'value': value})

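# Illustrative sketch (standard library only; requires system tz data or the tzdata
# package) of the zoneinfo lookup this validator wraps.
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

assert ZoneInfo('Europe/Paris').key == 'Europe/Paris'
try:
    ZoneInfo('Not/AZone')
except (ZoneInfoNotFoundError, ValueError):
    pass
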
NUMERIC_VALIDATOR_LOOKUP: dict[str, Callable] = {
    'gt': greater_than_validator,
    'ge': greater_than_or_equal_validator,
    'lt': less_than_validator,
    'le': less_than_or_equal_validator,
    'multiple_of': multiple_of_validator,
    'min_length': min_length_validator,
    'max_length': max_length_validator,
    'max_digits': max_digits_validator,
    'decimal_places': decimal_places_validator,
}

IpType = Union[IPv4Address, IPv6Address, IPv4Network, IPv6Network, IPv4Interface, IPv6Interface]

IP_VALIDATOR_LOOKUP: dict[type[IpType], Callable] = {
    IPv4Address: ip_v4_address_validator,
    IPv6Address: ip_v6_address_validator,
    IPv4Network: ip_v4_network_validator,
    IPv6Network: ip_v6_network_validator,
    IPv4Interface: ip_v4_interface_validator,
    IPv6Interface: ip_v6_interface_validator,
}

MAPPING_ORIGIN_MAP: dict[Any, Any] = {
    typing.DefaultDict: collections.defaultdict,  # noqa: UP006
    collections.defaultdict: collections.defaultdict,
    typing.OrderedDict: collections.OrderedDict,  # noqa: UP006
    collections.OrderedDict: collections.OrderedDict,
    typing_extensions.OrderedDict: collections.OrderedDict,
    typing.Counter: collections.Counter,
    collections.Counter: collections.Counter,
    # this doesn't handle subclasses of these
    typing.Mapping: dict,
    typing.MutableMapping: dict,
    # parametrized typing.{Mutable}Mapping creates one of these
    collections.abc.Mapping: dict,
    collections.abc.MutableMapping: dict,
}